瀏覽代碼

归还节点并行度

martin 2 年之前
父節點
當前提交
c596e57d31

+ 36 - 35
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/manager/TaskManager.java

@@ -89,6 +89,7 @@ public class TaskManager {
     @SneakyThrows
     @Transactional
     public boolean isProjectCompleted(PrefixTO redisPrefix, String projectId, String taskId, String state, String podName) {
+        String nodeName = projectUtil.getNodeNameOfPod(podName);
         if ("Running".equals(state)) {  // 运行中的 pod 无需删除
             // 将运行中的任务的 pod 名称放入 redis
             stringRedisTemplate.opsForValue().set(redisPrefix.getTaskPodKey(), podName);
@@ -96,11 +97,10 @@ public class TaskManager {
             log.info("TaskManager--state 修改任务 " + taskId + " 的状态为 " + state + ",pod 名称为:" + podName);
             taskMapper.updateStateWithStartTime(taskId, state, TimeUtil.getNowForMysql());
             return false;
-        } else { // 结束的 pod 都直接删除
+        } else { // 结束的 pod 都直接删除,并判断项目是否完成
             // -------------------------------- 处理状态 --------------------------------
             //TODO 暂时不用重试操作
             try {
-                KubernetesUtil.deletePod(apiClient, kubernetesNamespace, podName);
                 log.info("TaskManager--state 修改任务 " + taskId + "的状态为 " + state + ",pod 名称为:" + podName + ",并删除 pod。");
                 if ("Aborted".equals(state)) {
                     String minioPathOfErrorLog = resultPathMinio + projectId + "/" + taskId + "error.log";
@@ -129,21 +129,21 @@ public class TaskManager {
                 } else if ("PendingAnalysis".equals(state)) {
                     taskMapper.updateSuccessStateWithStopTime(taskId, state, TimeUtil.getNowForMysql());
                 }
-            } catch (io.kubernetes.client.openapi.ApiException apiException) {
-                log.error("TaskManager--isCompleted pod " + podName + " 已经被手动删除,该项目可能已经失败或删除。");
-                return false;
-            }
-            // -------------------------------- 判断项目是否结束 --------------------------------
-            ProjectMessageDTO projectMessageDTO = JsonUtil.jsonToBean(stringRedisTemplate.opsForValue().get(redisPrefix.getProjectRunningKey()), ProjectMessageDTO.class);
-            int taskTotal = projectMessageDTO.getTaskTotal();
-            int taskCompleted = projectMessageDTO.getTaskCompleted();
-            log.info("TaskManager--isProjectCompleted 项目 " + projectId + " 完成进度为:" + (taskCompleted + 1) + "/" + taskTotal);
-            if (taskCompleted + 1 == taskTotal) {
-                return true;
-            } else {
-                projectMessageDTO.setTaskCompleted(taskCompleted + 1);
-                stringRedisTemplate.opsForValue().set(redisPrefix.getProjectRunningKey(), JsonUtil.beanToJson(projectMessageDTO));
-                createNextPod(projectId, podName);  // 项目没有完成则启动下一个 pod
+                // -------------------------------- 判断项目是否结束 --------------------------------
+                ProjectMessageDTO projectMessageDTO = JsonUtil.jsonToBean(stringRedisTemplate.opsForValue().get(redisPrefix.getProjectRunningKey()), ProjectMessageDTO.class);
+                int taskTotal = projectMessageDTO.getTaskTotal();
+                int taskCompleted = projectMessageDTO.getTaskCompleted();
+                log.info("TaskManager--isProjectCompleted 项目 " + projectId + " 完成进度为:" + (taskCompleted + 1) + "/" + taskTotal);
+                if (taskCompleted + 1 == taskTotal) {
+                    return true;
+                } else {
+                    projectMessageDTO.setTaskCompleted(taskCompleted + 1);
+                    stringRedisTemplate.opsForValue().set(redisPrefix.getProjectRunningKey(), JsonUtil.beanToJson(projectMessageDTO));
+                    createNextPod(nodeName, projectId, podName);  // 项目没有完成则启动下一个 pod,同时删除上一个 pod
+                    return false;
+                }
+            } catch (Exception exception) {
+                log.error("TaskManager--isCompleted pod " + podName + " 已经被手动删除,该项目可能已经失败或删除。", exception);
                 return false;
             }
         }
@@ -152,17 +152,21 @@ public class TaskManager {
     /**
      * Starts the next pod of a project: deletes the previous pod, then re-applies the previous
      * pod's YAML under a freshly generated pod name.
      *
-     * @param projectId 项目 id
-     * @param lastPodName   项目名称
+     * NOTE(review): the call site visible in isProjectCompleted passes
+     * (nodeName, projectId, podName); that order does not match this signature and should be
+     * checked, since all three parameters are Strings and the compiler will not catch it.
+     *
+     * @param projectId   project id, embedded in the generated pod name
+     * @param nodeName    name of the node recorded in redis for the new pod
+     * @param lastPodName name of the previous pod; it is deleted and its YAML file is the template
      */
     @SneakyThrows
-    public void createNextPod(String projectId, String lastPodName) {
+    public void createNextPod(String projectId, String nodeName, String lastPodName) {
+        // 1. Delete the previous pod and its "pod:<podName>:node" redis entry.
+        KubernetesUtil.deletePod(apiClient, kubernetesNamespace, lastPodName);
+        // Delete by pod name: the value stored under that key is a node name, not a redis key.
+        projectUtil.deleteNodeNameOfPod(lastPodName);
+        // 2. Build the next pod's YAML from the previous pod's file by swapping in the new name.
         String lastPodString = FileUtil.read(podYamlDirectory + lastPodName + ".yaml");
         String nextPodName = "project-" + projectId + "-" + StringUtil.getRandomUUID();
         String nextPodString = lastPodString.replace(lastPodName, nextPodName); // pod name = projectId + random string
-        String nextPodFileName = nextPodName + ".yaml";     // 实际执行 pod 的文件名称
         log.info("TaskManager--createNextPod 创建项目 " + projectId + " 的下一个 pod。");
-        projectUtil.createPod(nextPodString, nextPodFileName);
+        // createPod expects (nodeName, podName, podYamlContent) and derives the ".yaml" file name itself.
+        projectUtil.createPod(nodeName, nextPodName, nextPodString);
     }
 
     public void prepareScore(String projectRunningKey) {
@@ -436,21 +440,18 @@ public class TaskManager {
 //        SshUtil.stop(clientKafka, sessionKafka);
 
 
-        Map<String, Integer> nodeMap = projectUtil.getNodeMap();
-        List<String> podList = KubernetesUtil.getPodByPrefix(apiClient, kubernetesNamespace, "project-" + projectId);
-        for (String tempPodName : podList) {
-            // 删除该 project 下的所有 pod
-            KubernetesUtil.deletePod(apiClient, kubernetesNamespace, tempPodName);
-            // 归还并行度
-            String tempNodeName = stringRedisTemplate.opsForValue().get("pod:" + tempPodName + ":node");
-            stringRedisTemplate.delete("pod:" + tempPodName + ":node");
-            int restParallelism = nodeMap.get(tempNodeName);
-            nodeMap.put(tempNodeName, restParallelism + 1);
+        // 归还并行度
+        Set<String> nodeOfPodKeySet = stringRedisTemplate.keys("pod:project-" + projectId);
+        for (String nodeOfPodKey : nodeOfPodKeySet) {
+            String podName = nodeOfPodKey.split(":")[1];
+            String nodeName = projectUtil.getNodeNameOfPod(podName);
+            // 删除 pod
+            KubernetesUtil.deletePod(apiClient, kubernetesNamespace, podName);
+            // 删除 redis key
+            projectUtil.deleteNodeNameOfPod(podName);
+            // 节点并行度加一
+            projectUtil.addOneParallelismToNode(nodeName);
         }
-        nodeMap.forEach((tempNodeName, tempParallelism) -> {
-            String restParallelismKey = "node:" + tempNodeName + ":parallelism";
-            stringRedisTemplate.opsForValue().set(restParallelismKey, tempParallelism + "");
-        });
 
         // 删除 redis 中的键值对
         Set<String> keys = stringRedisTemplate.keys(redisPrefix.getProjectRunningKey() + "*");

+ 2 - 7
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/service/ProjectService.java

@@ -89,7 +89,6 @@ public class ProjectService {
     @Transactional
     public void prepare(Map<String, Integer> nodeMap, ProjectMessageDTO projectMessageDTO, String clusterPrefix) {
         String projectId = projectMessageDTO.getProjectId();
-        String projectType = projectMessageDTO.getType();
         //1 将指定 node 的并行度减少
         nodeMap.keySet().forEach(nodeName -> {
             int parallelismToUse = nodeMap.get(nodeName);
@@ -278,7 +277,7 @@ public class ProjectService {
     /**
      * 将 master 节点设置成镜像仓库,导入镜像的同时 commit 到仓库当中,供其他节点 pull
      *
-     * @param projectId 项目 id
+     * @param projectId   项目 id
      * @param algorithmId 算法 id
      * @return 镜像名称
      */
@@ -391,15 +390,11 @@ public class ProjectService {
 
         nodeMap.forEach((nodeName, parallelism) -> {
             String podName = "project-" + projectId + "-" + StringUtil.getRandomUUID();
-            stringRedisTemplate.opsForValue().set("pod:" + podName + ":node", nodeName);    // 将 pod 运行在哪个 node 上记录到 redis
-            String tempPodFileNameOfProject = podName + ".yaml";     // 模板文件名称
             // -------------------------------- Comment --------------------------------
             String tempReplace4 = podTemplateStringOfProject.replace("pod-name", podName); // pod 名称包括 projectId 和 随机字符串
             String tempPodString = tempReplace4.replace("node-name", nodeName);     // 指定 pod 运行节点
             log.info("ProjectService--createPod 在节点 " + nodeName + " 开始执行 pod:" + tempPodString);
-            projectUtil.createPod(tempPodString,tempPodFileNameOfProject);
-
-
+            projectUtil.createPod(nodeName, podName, tempPodString);
 //            V1Pod v1Pod;
 //            try {
 //                FileUtil.writeStringToLocalFile(tempPodString, podYamlDirectory + tempPodFileNameOfProject);

+ 30 - 8
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/util/ProjectUtil.java

@@ -58,8 +58,24 @@ public class ProjectUtil {
     @Resource
     StringRedisTemplate stringRedisTemplate;
 
+
+    /**
+     * Reads the value stored in redis under "node:<nodeName>:parallelism".
+     *
+     * @param nodeName name of the node
+     * @return the stored string as returned by redis for that key
+     */
+    public String getParallelismOfNode(String nodeName) {
+        final String parallelismKey = "node:" + nodeName + ":parallelism";
+        return stringRedisTemplate.opsForValue().get(parallelismKey);
+    }
+
+    /**
+     * Removes the redis entry "pod:<podName>:node".
+     *
+     * @param podName name of the pod whose entry is deleted
+     */
+    public void deleteNodeNameOfPod(String podName) {
+        final String nodeOfPodKey = "pod:" + podName + ":node";
+        stringRedisTemplate.delete(nodeOfPodKey);
+    }
+
+    /**
+     * Reads the value stored in redis under "pod:<podName>:node".
+     *
+     * @param podName name of the pod
+     * @return the stored string as returned by redis for that key
+     */
+    public String getNodeNameOfPod(String podName) {
+        final String nodeOfPodKey = "pod:" + podName + ":node";
+        return stringRedisTemplate.opsForValue().get(nodeOfPodKey);
+    }
+
+
     @SneakyThrows
-    public void createPod(String podYamlContent, String podYamlName) {
+    public void createPod(String nodeName, String podName, String podYamlContent) {
+        stringRedisTemplate.opsForValue().set("pod:" + podName + ":node", nodeName);    // 将 pod 运行在哪个 node 上记录到 redis
+        String podYamlName = podName + ".yaml";     // 模板文件名称
         String podYamlPath = podYamlDirectory + podYamlName;
         FileUtil.writeStringToLocalFile(podYamlContent, podYamlPath);
         KubernetesUtil.applyYaml(hostname, username, password, podYamlPath);
@@ -68,6 +84,7 @@ public class ProjectUtil {
 //        KubernetesUtil.createPod(apiClient, kubernetesNamespace, v1Pod);
     }
 
+
     public String getProjectTypeByProjectId(String projectId) {
         String projectType = null;
         ProjectPO manualProjectPO = manualProjectMapper.selectById(projectId);
@@ -116,10 +133,12 @@ public class ProjectUtil {
      */
     public Map<String, Integer> getNodeMap() {
         List<KubernetesNodeTO> initialNodeList = kubernetesConfiguration.getNodeList(); // 预设并行度的节点列表
+        log.info("ProjectUtil--getNodeMap 预设并行度的节点列表为:" + initialNodeList);
         Map<String, Integer> resultNodeMap = new HashMap<>();    // 用于执行的节点映射(节点名,并行度)
-        for (KubernetesNodeTO kubernetesNodeTO : initialNodeList) {
-            String nodeName = kubernetesNodeTO.getName();
-            int maxParallelism = kubernetesNodeTO.getMaxParallelism();
+        for (KubernetesNodeTO kubernetesNodeSource : initialNodeList) {
+            KubernetesNodeTO kubernetesNodeCopy = kubernetesNodeSource.clone();
+            String nodeName = kubernetesNodeCopy.getName();
+            int maxParallelism = kubernetesNodeCopy.getMaxParallelism();
             String restParallelismKey = "node:" + nodeName + ":parallelism";
             String restParallelismString = stringRedisTemplate.opsForValue().get(restParallelismKey);
             int restParallelism;
@@ -128,10 +147,10 @@ public class ProjectUtil {
                 stringRedisTemplate.opsForValue().set(restParallelismKey, restParallelism + "");
             } else {
                 restParallelism = Integer.parseInt(restParallelismString);
-                kubernetesNodeTO.setMaxParallelism(restParallelism);
             }
             resultNodeMap.put(nodeName, restParallelism);
         }
+        log.info("ProjectUtil--getNodeMapToUse 剩余并行度的节点列表为:" + resultNodeMap);
         return resultNodeMap;
     }
 
@@ -147,7 +166,6 @@ public class ProjectUtil {
         log.info("ProjectUtil--getNodeMapToUse 预设并行度的节点列表为:" + initialNodeList);
         // 遍历所有节点,获取还有剩余并行度的节点
         List<KubernetesNodeTO> restNodeList = new ArrayList<>();    // 剩余并行度的节点列表
-        int restParallelismSum = 0;
         for (KubernetesNodeTO kubernetesNodeSource : initialNodeList) {
             KubernetesNodeTO kubernetesNodeCopy = kubernetesNodeSource.clone();
             String nodeName = kubernetesNodeCopy.getName();   // 节点名称
@@ -164,14 +182,12 @@ public class ProjectUtil {
                 kubernetesNodeCopy.setMaxParallelism(restParallelism);
             }
             if (restParallelism > 0) {
-                restParallelismSum += restParallelism;
                 restNodeList.add(kubernetesNodeCopy);
             }
         }
         log.info("ProjectUtil--getNodeMapToUse 剩余并行度的节点列表为:" + restNodeList);
         Map<String, Integer> resultNodeMap = new HashMap<>();    // 用于执行的节点映射(节点名,并行度)
         if (!CollectionUtil.isEmpty(restNodeList)) {
-
             if (restNodeList.size() == 1) {
                 KubernetesNodeTO tempNode = restNodeList.get(0);
                 String tempNodeName = tempNode.getName();
@@ -357,4 +373,10 @@ public class ProjectUtil {
                 .build();
     }
 
+    /**
+     * Gives one unit of parallelism back to a node by incrementing the counter stored in redis
+     * under "node:<nodeName>:parallelism".
+     *
+     * @param nodeName name of the node; when no counter is stored for it, nothing is changed
+     */
+    public void addOneParallelismToNode(String nodeName) {
+        String key = "node:" + nodeName + ":parallelism";
+        // Without a stored counter the old get + parseInt threw NumberFormatException.
+        // Skip rather than let INCR create the counter at 1.
+        // NOTE(review): confirm that a missing counter should be treated as "nothing to give back".
+        if (!Boolean.TRUE.equals(stringRedisTemplate.hasKey(key))) {
+            log.warn("ProjectUtil--addOneParallelismToNode no parallelism counter found for key " + key + ", skipped.");
+            return;
+        }
+        // One redis INCR instead of get + set, so simultaneous updates cannot overwrite each other.
+        stringRedisTemplate.opsForValue().increment(key, 1L);
+    }
 }