浏览代码

将项目运行节点拆分成GPU节点和CPU节点

LingxinMeng 2 年之前
父节点
当前提交
989ac709c9

+ 1 - 1
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/configuration/constant/ConstantConfiguration.java

@@ -6,7 +6,7 @@ import org.springframework.context.annotation.Configuration;
 
 @Data
 @Configuration
-@ConfigurationProperties(prefix = "prefix")
+@ConfigurationProperties(prefix = "constant")
 public class ConstantConfiguration {
     private String temporaryDirectory;
     private String uploadOsgbUrl;

+ 6 - 5
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/consumer/ProjectConsumer.java

@@ -100,6 +100,7 @@ public class ProjectConsumer {
         //* -------------------------------- 0 读取消息,创建临时目录 --------------------------------
         String projectId = projectMessageDTO.getProjectId();                // 手动执行项目 id 或 自动执行子项目 id
         String projectType = projectMessageDTO.getType();                   // 项目类型
+        String isChoiceGpu = projectUtil.getIsChoiceGpuByProjectId(projectId);
         try {
             String modelType = projectMessageDTO.getModelType();                // 模型类型,1 动力学模型 2 carsim模型
             String packageId = projectMessageDTO.getScenePackageId();           // 场景测试包 id
@@ -283,7 +284,7 @@ public class ProjectConsumer {
             cacheProject(projectMessageDTO);
         } catch (Exception e) {
             log.error("项目报错。", e);
-            projectService.stopProject(projectId, projectType, e.getMessage());
+            projectService.stopProject(isChoiceGpu, projectType, projectId, e.getMessage()); // argument order: isChoiceGpu, projectType, projectId, errorMessage
             throw new RuntimeException(e);
         }
 
@@ -377,8 +378,7 @@ public class ProjectConsumer {
      */
     public void run(ProjectMessageDTO projectMessageDTO, String userId, String modelType, String clusterId, String projectRunningKey, String projectWaitingKey) {
         String projectId = projectMessageDTO.getProjectId();    // 项目 id
-        ProjectEntity projectEntity = projectUtil.getProjectByProjectId(projectId);
-        String isChoiceGpu = projectEntity.getIsChoiceGpu();
+        String isChoiceGpu = projectUtil.getIsChoiceGpuByProjectId(projectId);
         int parallelism = projectMessageDTO.getParallelism();  // 期望并行度
         //1 获取集群剩余可用并行度
         int restParallelism = projectUtil.getRestParallelism(isChoiceGpu);
@@ -419,7 +419,7 @@ public class ProjectConsumer {
         Map<String, Integer> nodeMap0 = projectUtil.getNodeMap(isChoiceGpu);
         Map<String, Integer> nodeMap = projectUtil.getNodeMapToUse(isChoiceGpu, Math.min(currentParallelism, taskTotal));
         //2 将指定 node 的并行度减少
-        nodeMap.keySet().forEach(nodeName -> projectUtil.decrementParallelismOfGpuNode(nodeName, nodeMap.get(nodeName)));
+        nodeMap.keySet().forEach(nodeName -> projectUtil.decrementParallelism(isChoiceGpu, nodeName, nodeMap.get(nodeName)));
         // 重新设置实际使用的并行度并保存到 redis
         int realCurrentParallelism = nodeMap.values().stream().mapToInt(parallelism -> parallelism).sum();
         projectMessageDTO.setCurrentParallelism(realCurrentParallelism);
@@ -519,7 +519,8 @@ public class ProjectConsumer {
         JsonNode jsonNode = new ObjectMapper().readTree(stopRecord.value());
         String projectId = jsonNode.path("projectId").asText();
         String projectType = jsonNode.path("type").asText();
-        projectService.stopProject(projectType, projectId);
+        String isChoiceGpu = projectUtil.getIsChoiceGpuByProjectId(projectId);
+        projectService.stopProject(isChoiceGpu, projectType, projectId);
     }
 
 

+ 4 - 4
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/service/ProjectService.java

@@ -278,7 +278,7 @@ public class ProjectService {
         return dockerImage;
     }
 
-    public void stopProject(String projectType, String projectId, String errorMessage) {
+    public void stopProject(String isChoiceGpu, String projectType, String projectId, String errorMessage) {
         Optional.ofNullable(errorMessage).ifPresent(em -> {
             if (DictConstants.PROJECT_TYPE_MANUAL.equals(projectType)) {
                 manualProjectMapper.saveErrorMessage(SchedulerProjectPO.builder().id(projectId).errorMessage(em).modifyUserId(DictConstants.SCHEDULER_USER_ID).modifyTime(TimeUtil.getNowForMysql()).build());
@@ -286,7 +286,7 @@ public class ProjectService {
                 autoSubProjectMapper.saveErrorMessage(SchedulerProjectPO.builder().id(projectId).errorMessage(em).modifyUserId(DictConstants.SCHEDULER_USER_ID).modifyTime(TimeUtil.getNowForMysql()).build());
             }
         });
-        stopProject(projectType, projectId);
+        stopProject(isChoiceGpu, projectType, projectId);
     }
 
     /**
@@ -294,7 +294,7 @@ public class ProjectService {
      * @param projectType 项目类型
      */
     @SneakyThrows
-    public void stopProject(String projectType, String projectId) {
+    public void stopProject(String isChoiceGpu, String projectType, String projectId) {
         // 将项目状态修改为终止中
         if (DictConstants.PROJECT_TYPE_MANUAL.equals(projectType)) {
             manualProjectMapper.updateProjectState(projectId, DictConstants.PROJECT_TERMINATING, TimeUtil.getNowForMysql());
@@ -323,7 +323,7 @@ public class ProjectService {
                 // 删除 pod
                 projectUtil.deletePod(podName);
                 // 节点并行度加一
-                projectUtil.incrementOneParallelismOfGpuNode(nodeName);
+                projectUtil.incrementOneParallelism(isChoiceGpu, nodeName);
             }
         }
 

+ 48 - 19
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/util/ProjectUtil.java

@@ -56,6 +56,10 @@ public class ProjectUtil {
     @Resource
     private CustomRedisClient customRedisClient;
 
+    public String getIsChoiceGpuByProjectId(String projectId) { // Looks the project up by id and returns its isChoiceGpu value.
+        return getProjectByProjectId(projectId).getIsChoiceGpu(); // NOTE(review): lookup result is dereferenced without a null check — confirm getProjectByProjectId never returns null
+    }
+
 
     @SneakyThrows
     public void deleteYamlByProjectId(String projectId) {
@@ -133,6 +137,7 @@ public class ProjectUtil {
      */
     @SneakyThrows
     public void createNextPod(String userId, String projectId, String projectType, String nodeName, String lastPodName) {
+        final String isChoiceGpu = getIsChoiceGpuByProjectId(projectId);
         log.info("删除上一个 pod:projectId={},nodeName={},lastPodName={}", projectId, nodeName, lastPodName);
         String cpuOrderString = stringRedisTemplate.opsForValue().get("project:" + projectId + ":pod:" + lastPodName + ":cpu");
         deletePod(lastPodName);
@@ -142,7 +147,7 @@ public class ProjectUtil {
         if (CollectionUtil.isEmpty(yamlPathCacheKeySet)) {
             // 如果当前节点没有下一个yaml,则返回一个并行度。
             log.info("节点 " + nodeName + " 已经执行完被分配的项目 " + projectId + " 的所有 pod。");
-            incrementOneParallelismOfGpuNode(nodeName);
+            incrementOneParallelism(isChoiceGpu, nodeName);
             releaseLicense(userId, getModelTypeByProjectIdAndProjectType(projectId, projectType), 1);
         } else {
             final String yamlPathCacheKey = new ArrayList<>(yamlPathCacheKeySet).get(0);
@@ -549,29 +554,53 @@ public class ProjectUtil {
     }
 
 
-    public void incrementOneParallelismOfGpuNode(String nodeName) {
-        incrementParallelismOfGpuNode(nodeName, 1L);
+    public void incrementOneParallelism(String isChoiceGpu, String nodeName) { // Convenience wrapper: gives exactly one unit of parallelism back to the named node.
+        incrementParallelism(isChoiceGpu, nodeName, 1L); // isChoiceGpu is passed through unchanged; the delegate picks the GPU or CPU counter from it
     }
 
-    public void incrementParallelismOfGpuNode(String nodeName, long number) {
-        //1 先检查缓存中的并行度是否超过,超过了就不加缓存的并行度了,常用于测试
-        String key = "gpu-node:" + nodeName + ":parallelism";
-        final int currentRestParallelism = Integer.parseInt(customRedisClient.get(key));
-        final List<NodeModel> nodeList = kubernetesConfiguration.getGpuNodeList();
-        nodeList.forEach(node -> {
-            if (nodeName.equals(node.getHostname())) {
-                if (currentRestParallelism + 1 < node.getParallelism()) {
-                    customRedisClient.increment(key, number);
+    public void incrementParallelism(String isChoiceGpu, String nodeName, long number) { // Gives `number` units of parallelism back to the GPU or CPU counter of node `nodeName`, selected by isChoiceGpu.
+        if (DictConstants.USE_GPU.equals(isChoiceGpu)) {
+            // 1. Only give parallelism back while the cached value stays within the node's configured parallelism (guards against over-returning, mostly seen in testing).
+            String key = "gpu-node:" + nodeName + ":parallelism";
+            final int currentRestParallelism = Integer.parseInt(customRedisClient.get(key));
+            final List<NodeModel> nodeList = kubernetesConfiguration.getGpuNodeList();
+            nodeList.forEach(node -> {
+                if (nodeName.equals(node.getHostname())) {
+                    if (currentRestParallelism + number <= node.getParallelism()) { // the value after the increment may reach, but not exceed, the configured parallelism
+                        customRedisClient.increment(key, number);
+                    }
                 }
-            }
-        });
-        log.info("归还节点 {} 的 {} 个 GPU 并行度。", nodeName, number);
+            });
+            log.info("归还 GPU 节点 {} 的 {} 个并行度。", nodeName, number);
+        } else if (DictConstants.USE_CPU.equals(isChoiceGpu)) {
+            // 1. Same guard as the GPU branch, applied to the CPU counter and the CPU node list.
+            String key = "cpu-node:" + nodeName + ":parallelism";
+            final int currentRestParallelism = Integer.parseInt(customRedisClient.get(key));
+            final List<NodeModel> nodeList = kubernetesConfiguration.getCpuNodeList();
+            nodeList.forEach(node -> {
+                if (nodeName.equals(node.getHostname())) {
+                    if (currentRestParallelism + number <= node.getParallelism()) { // the value after the increment may reach, but not exceed, the configured parallelism
+                        customRedisClient.increment(key, number);
+                    }
+                }
+            });
+            log.info("归还 CPU 节点 {} 的 {} 个并行度。", nodeName, number);
+        }
+        // Any other isChoiceGpu value: no counter is touched and nothing is logged. The log lines above are emitted even when the guard skipped the increment.
+
     }
 
-    public void decrementParallelismOfGpuNode(String nodeName, long number) {
-        String key = "gpu-node:" + nodeName + ":parallelism";
-        customRedisClient.decrement(key, number);
-        log.info("获取节点 {} 的 {} 个 GPU 并行度。", nodeName, number);
+    public void decrementParallelism(String isChoiceGpu, String nodeName, long number) { // Takes `number` units of parallelism from the GPU or CPU counter of node `nodeName`, selected by isChoiceGpu.
+        if (DictConstants.USE_GPU.equals(isChoiceGpu)) {
+            String key = "gpu-node:" + nodeName + ":parallelism"; // per-node GPU counter key
+            customRedisClient.decrement(key, number);
+            log.info("获取节点 {} 的 {} 个 GPU 并行度。", nodeName, number);
+        } else if (DictConstants.USE_CPU.equals(isChoiceGpu)) {
+            String key = "cpu-node:" + nodeName + ":parallelism"; // per-node CPU counter key
+            customRedisClient.decrement(key, number);
+            log.info("获取节点 {} 的 {} 个 CPU 并行度。", nodeName, number);
+        }
+        // Any other isChoiceGpu value: nothing is decremented and nothing is logged. No lower-bound check is made here before decrementing.
     }
 
 

+ 3 - 3
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/util/TaskUtil.java

@@ -106,6 +106,7 @@ public class TaskUtil {
     public void isProjectCompleted(PrefixEntity redisPrefix, String userId, String projectId, String projectType, String maxSimulationTime, String taskId, String state, String podName) {
         boolean isCompleted;
         String nodeName = projectUtil.getNodeNameOfPod(projectId, podName);
+        final String isChoiceGpu = projectUtil.getIsChoiceGpuByProjectId(projectId);
         if (DictConstants.TASK_RUNNING.equals(state)) {  // 运行中的 pod 无需删除
             // 将运行中的任务的 pod 名称放入 redis
             stringRedisTemplate.opsForValue().set(redisPrefix.getTaskPodKey(), podName);
@@ -143,12 +144,11 @@ public class TaskUtil {
             } else if (DictConstants.TASK_ANALYSIS.equals(state)) { // 该状态只会获得一次
                 taskMapper.updateSuccessStateWithStopTime(taskId, state, TimeUtil.getNowForMysql());
                 // 查询项目是否使用 CPU 生成视频
-                String isChoiceGpu = projectUtil.getProjectByProjectId(projectId).getIsChoiceGpu();
                 if (DictConstants.VIDEO_CPU.equals(isChoiceGpu)) {
                     log.info("项目 {} 使用 CPU 生成视频。", projectId);
                     String generateVideoKey = "task:" + taskId + ":generateVideo";
                     customRedisClient.set(generateVideoKey, "0");
-                    HttpUtil.get(constantConfiguration.getGenerateVideoUrl().replace("simulation-resource-video", nodeName) + "?generateVideoKey=" + generateVideoKey + "&nodeName=" + nodeName + "&projectId" + projectId + "&projectType" + projectType + "&maxSimulationTime" + maxSimulationTime + "&taskId" + taskId);
+                    HttpUtil.get(constantConfiguration.getGenerateVideoUrl().replace("simulation-resource-video", nodeName) + "?generateVideoKey=" + generateVideoKey + "&nodeName=" + nodeName + "&projectId=" + projectId + "&projectType=" + projectType + "&maxSimulationTime=" + maxSimulationTime + "&taskId=" + taskId);
 //                    HttpUtil.get("http://" + nodeName + ":8007//simulation/resource/video/generate" + "?generateVideoKey=" + generateVideoKey + "&nodeName=" + nodeName + "&projectId" + projectId + "&projectType" + projectType + "&maxSimulationTime" + maxSimulationTime + "&taskId" + taskId);
 //                    videoFeignClient.generateVideo(generateVideoKey, nodeName, projectId, projectType, maxSimulationTime, taskId);
                     log.info("任务 {} 使用 CPU 生成视频中>>>>>>>", taskId);
@@ -167,7 +167,7 @@ public class TaskUtil {
             if (isCompleted) {
                 //如果项目已完成先把 pod 删除,并归还并行度
                 KubernetesUtil.deletePod2(apiClient, kubernetesConfiguration.getNamespace(), podName);
-                projectUtil.incrementOneParallelismOfGpuNode(nodeName);
+                projectUtil.incrementOneParallelism(isChoiceGpu, nodeName);
                 projectUtil.releaseLicense(userId, projectUtil.getModelTypeByProjectIdAndProjectType(projectId, projectType), 1);
             } else {
                 log.info("项目 " + projectId + " 还未运行完成。");