|
@@ -246,7 +246,7 @@ public class ProjectUtil {
|
|
Map<String, Integer> resultNodeMap = new HashMap<>(); // 用于执行的节点映射(节点名,并行度)
|
|
Map<String, Integer> resultNodeMap = new HashMap<>(); // 用于执行的节点映射(节点名,并行度)
|
|
for (GpuNodeEntity kubernetesNodeSource : initialNodeList) {
|
|
for (GpuNodeEntity kubernetesNodeSource : initialNodeList) {
|
|
GpuNodeEntity kubernetesNodeCopy = kubernetesNodeSource.clone();
|
|
GpuNodeEntity kubernetesNodeCopy = kubernetesNodeSource.clone();
|
|
- String nodeName = kubernetesNodeCopy.getName();
|
|
|
|
|
|
+ String nodeName = kubernetesNodeCopy.getHostname();
|
|
int maxParallelism = kubernetesNodeCopy.getParallelism();
|
|
int maxParallelism = kubernetesNodeCopy.getParallelism();
|
|
String restParallelismKey = "gpu-node:" + nodeName + ":parallelism";
|
|
String restParallelismKey = "gpu-node:" + nodeName + ":parallelism";
|
|
String restParallelismString = stringRedisTemplate.opsForValue().get(restParallelismKey);
|
|
String restParallelismString = stringRedisTemplate.opsForValue().get(restParallelismKey);
|
|
@@ -275,7 +275,7 @@ public class ProjectUtil {
|
|
List<GpuNodeEntity> restNodeList = new ArrayList<>(); // 剩余并行度的节点列表
|
|
List<GpuNodeEntity> restNodeList = new ArrayList<>(); // 剩余并行度的节点列表
|
|
for (GpuNodeEntity kubernetesNodeSource : initialNodeList) {
|
|
for (GpuNodeEntity kubernetesNodeSource : initialNodeList) {
|
|
GpuNodeEntity kubernetesNodeCopy = kubernetesNodeSource.clone();
|
|
GpuNodeEntity kubernetesNodeCopy = kubernetesNodeSource.clone();
|
|
- String nodeName = kubernetesNodeCopy.getName(); // 节点名称
|
|
|
|
|
|
+ String nodeName = kubernetesNodeCopy.getHostname(); // 节点名称
|
|
int maxParallelism = kubernetesNodeCopy.getParallelism();
|
|
int maxParallelism = kubernetesNodeCopy.getParallelism();
|
|
String restParallelismString = stringRedisTemplate.opsForValue().get("gpu-node:" + nodeName + ":parallelism");// 获取节点剩余并行度的 key
|
|
String restParallelismString = stringRedisTemplate.opsForValue().get("gpu-node:" + nodeName + ":parallelism");// 获取节点剩余并行度的 key
|
|
// -------------------------------- Comment --------------------------------
|
|
// -------------------------------- Comment --------------------------------
|
|
@@ -308,7 +308,7 @@ public class ProjectUtil {
|
|
List<GpuNodeEntity> restNodeList = new ArrayList<>(); // 剩余并行度的节点列表
|
|
List<GpuNodeEntity> restNodeList = new ArrayList<>(); // 剩余并行度的节点列表
|
|
for (GpuNodeEntity kubernetesNodeSource : initialNodeList) {
|
|
for (GpuNodeEntity kubernetesNodeSource : initialNodeList) {
|
|
GpuNodeEntity kubernetesNodeCopy = kubernetesNodeSource.clone();
|
|
GpuNodeEntity kubernetesNodeCopy = kubernetesNodeSource.clone();
|
|
- String nodeName = kubernetesNodeCopy.getName(); // 节点名称
|
|
|
|
|
|
+ String nodeName = kubernetesNodeCopy.getHostname(); // 节点名称
|
|
int maxParallelism = kubernetesNodeCopy.getParallelism();
|
|
int maxParallelism = kubernetesNodeCopy.getParallelism();
|
|
String restParallelismString = stringRedisTemplate.opsForValue().get("gpu-node:" + nodeName + ":parallelism");// 获取节点剩余并行度的 key
|
|
String restParallelismString = stringRedisTemplate.opsForValue().get("gpu-node:" + nodeName + ":parallelism");// 获取节点剩余并行度的 key
|
|
// -------------------------------- Comment --------------------------------
|
|
// -------------------------------- Comment --------------------------------
|
|
@@ -329,7 +329,7 @@ public class ProjectUtil {
|
|
if (!CollectionUtil.isEmpty(restNodeList)) {
|
|
if (!CollectionUtil.isEmpty(restNodeList)) {
|
|
if (restNodeList.size() == 1) {
|
|
if (restNodeList.size() == 1) {
|
|
GpuNodeEntity tempNode = restNodeList.get(0);
|
|
GpuNodeEntity tempNode = restNodeList.get(0);
|
|
- String tempNodeName = tempNode.getName();
|
|
|
|
|
|
+ String tempNodeName = tempNode.getHostname();
|
|
int tempParallelism = tempNode.getParallelism();
|
|
int tempParallelism = tempNode.getParallelism();
|
|
resultNodeMap.put(tempNodeName, Math.min(tempParallelism, parallelism));
|
|
resultNodeMap.put(tempNodeName, Math.min(tempParallelism, parallelism));
|
|
}
|
|
}
|
|
@@ -338,7 +338,7 @@ public class ProjectUtil {
|
|
// 每次降序排序都取剩余并行度最大的一个。
|
|
// 每次降序排序都取剩余并行度最大的一个。
|
|
restNodeList.sort((o1, o2) -> o2.getParallelism() - o1.getParallelism());
|
|
restNodeList.sort((o1, o2) -> o2.getParallelism() - o1.getParallelism());
|
|
GpuNodeEntity tempNode = restNodeList.get(0);
|
|
GpuNodeEntity tempNode = restNodeList.get(0);
|
|
- String tempNodeName = tempNode.getName();
|
|
|
|
|
|
+ String tempNodeName = tempNode.getHostname();
|
|
int tempParallelism = tempNode.getParallelism();
|
|
int tempParallelism = tempNode.getParallelism();
|
|
if (tempParallelism > 0) {
|
|
if (tempParallelism > 0) {
|
|
tempNode.setParallelism(tempParallelism - 1);
|
|
tempNode.setParallelism(tempParallelism - 1);
|