Przeglądaj źródła

将项目运行节点拆分成GPU节点和CPU节点

LingxinMeng 2 lat temu
rodzic
commit
0284848d55

+ 6 - 7
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/consumer/ProjectConsumer.java

@@ -376,11 +376,12 @@ public class ProjectConsumer {
      * @param projectWaitingKey projectWaitingKey
      */
     public void run(ProjectMessageDTO projectMessageDTO, String userId, String modelType, String clusterId, String projectRunningKey, String projectWaitingKey) {
-
-        String projectId = projectMessageDTO.getProjectId();
+        String projectId = projectMessageDTO.getProjectId();    // 项目 id
+        ProjectEntity projectEntity = projectUtil.getProjectByProjectId(projectId);
+        String isChoiceGpu = projectEntity.getIsChoiceGpu();
         int parallelism = projectMessageDTO.getParallelism();  // 期望并行度
         //1 获取集群剩余可用并行度
-        int restParallelism = projectUtil.getRestParallelism();
+        int restParallelism = projectUtil.getRestParallelism(isChoiceGpu);
         //2 判断剩余可用并行度是否大于项目并行度,否则加入扩充队列
         if (restParallelism > 0L) {
             log.info("集群 " + clusterId + " 执行项目 " + projectId);
@@ -389,7 +390,7 @@ public class ProjectConsumer {
             }
             // 设置实际的并行度
             projectMessageDTO.setCurrentParallelism(Math.min(restParallelism, parallelism));   // 设置实际的并行度
-            parseProject(projectMessageDTO, projectRunningKey);
+            parseProject(projectMessageDTO, projectRunningKey, isChoiceGpu);
         } else {
             log.info("服务器资源不够,项目 " + projectId + " 暂时加入等待队列。");
             wait(projectWaitingKey, projectMessageDTO);
@@ -401,12 +402,10 @@ public class ProjectConsumer {
      * @param projectRunningKey projectRunningKey
      */
     @SneakyThrows
-    public void parseProject(ProjectMessageDTO projectMessageDTO, String projectRunningKey) {
+    public void parseProject(ProjectMessageDTO projectMessageDTO, String projectRunningKey, String isChoiceGpu) {
         String projectId = projectMessageDTO.getProjectId();    // 项目 id
         String modelType = projectMessageDTO.getModelType();
         String vehicleConfigId = projectMessageDTO.getVehicleConfigId();
-        ProjectEntity projectEntity = projectUtil.getProjectByProjectId(projectId);
-        String isChoiceGpu = projectEntity.getIsChoiceGpu();
         int currentParallelism = projectMessageDTO.getCurrentParallelism();   // 当前并行度
         String algorithmId = projectMessageDTO.getAlgorithmId();    // 算法 id
         String projectPath = linuxTempPath + "project/" + projectId + "/";

+ 89 - 58
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/util/ProjectUtil.java

@@ -66,6 +66,7 @@ public class ProjectUtil {
         for (String absolutePath : absolutePathList) {
             if (absolutePath.contains(projectId)) {
                 boolean delete = new File(absolutePath).delete();
+                log.debug("删除结果:" + delete);
             }
         }
     }
@@ -205,31 +206,31 @@ public class ProjectUtil {
     }
 
 
-    /**
-     * 获取正在运行的项目的并行度总和
-     *
-     * @param clusterRunningPrefix 集群 key 前缀
-     * @return 正在运行的项目的并行度总和
-     */
-    @SneakyThrows
-    public int getCurrentParallelismSum(String clusterRunningPrefix) {
-        int result = 0;
-        Set<String> clusterRunningKeySet = stringRedisTemplate.keys(clusterRunningPrefix + "*");
-        List<String> runningProjectSet; // 运行中的 projectId 列表
-        if (CollectionUtil.isEmpty(clusterRunningKeySet)) {
-            return 0;
-        }
-        runningProjectSet = getRunningProjectList(clusterRunningKeySet);
-        if (CollectionUtil.isEmpty(runningProjectSet)) {
-            return 0;
-        }
-        for (String projectKey : runningProjectSet) {
-            String projectJsonTemp = stringRedisTemplate.opsForValue().get(projectKey);
-            ProjectMessageDTO projectMessageTemp = JsonUtil.jsonToBean(projectJsonTemp, ProjectMessageDTO.class);
-            result += projectMessageTemp.getCurrentParallelism();   // 获取当前正在使用的并行度
-        }
-        return result;
-    }
+//    /**
+//     * 获取正在运行的项目的并行度总和
+//     *
+//     * @param clusterRunningPrefix 集群 key 前缀
+//     * @return 正在运行的项目的并行度总和
+//     */
+//    @SneakyThrows
+//    public int getCurrentParallelismSum(String clusterRunningPrefix) {
+//        int result = 0;
+//        Set<String> clusterRunningKeySet = stringRedisTemplate.keys(clusterRunningPrefix + "*");
+//        List<String> runningProjectSet; // 运行中的 projectId 列表
+//        if (CollectionUtil.isEmpty(clusterRunningKeySet)) {
+//            return 0;
+//        }
+//        runningProjectSet = getRunningProjectList(clusterRunningKeySet);
+//        if (CollectionUtil.isEmpty(runningProjectSet)) {
+//            return 0;
+//        }
+//        for (String projectKey : runningProjectSet) {
+//            String projectJsonTemp = stringRedisTemplate.opsForValue().get(projectKey);
+//            ProjectMessageDTO projectMessageTemp = JsonUtil.jsonToBean(projectJsonTemp, ProjectMessageDTO.class);
+//            result += projectMessageTemp.getCurrentParallelism();   // 获取当前正在使用的并行度
+//        }
+//        return result;
+//    }
 
 
     /**
@@ -397,30 +398,60 @@ public class ProjectUtil {
      *
      * @return 集群剩余并行度
      */
-    public int getRestParallelism() {
-        List<NodeModel> initialNodeList = kubernetesConfiguration.getNodeList(); // 预设并行度的节点列表
-        // 遍历所有节点,获取还有剩余并行度的节点
-        List<NodeModel> restNodeList = new ArrayList<>();    // 剩余并行度的节点列表
-        for (NodeModel kubernetesNodeSource : initialNodeList) {
-            NodeModel kubernetesNodeCopy = kubernetesNodeSource.clone();
-            String nodeName = kubernetesNodeCopy.getHostname();   // 节点名称
-            int maxParallelism = kubernetesNodeCopy.getParallelism();
-            String restParallelismString = stringRedisTemplate.opsForValue().get("gpu-node:" + nodeName + ":parallelism");// 获取节点剩余并行度的 key
-            // -------------------------------- Comment --------------------------------
-            int restParallelism;
-            if (restParallelismString == null || Integer.parseInt(restParallelismString) > maxParallelism) {    // 如果剩余可用并行度没有值,说明是第一次查询,则重置成最大并行度的预设值
-                restParallelism = maxParallelism;
-                stringRedisTemplate.opsForValue().set("gpu-node:" + nodeName + ":parallelism", restParallelism + "");
-            } else {
-                restParallelism = Integer.parseInt(restParallelismString);
-                kubernetesNodeCopy.setParallelism(restParallelism);
+    public int getRestParallelism(String isChoiceGpu) {
+        List<NodeModel> initialNodeList; // 预设并行度的节点列表
+        if (DictConstants.USE_GPU.equals(isChoiceGpu)) {
+            initialNodeList = kubernetesConfiguration.getGpuNodeList(); // 预设并行度的节点列表
+            // 遍历所有节点,获取还有剩余并行度的节点
+            List<NodeModel> restNodeList = new ArrayList<>();    // 剩余并行度的节点列表
+            for (NodeModel kubernetesNodeSource : initialNodeList) {
+                NodeModel kubernetesNodeCopy = kubernetesNodeSource.clone();
+                String nodeName = kubernetesNodeCopy.getHostname();   // 节点名称
+                int maxParallelism = kubernetesNodeCopy.getParallelism();
+                String restParallelismString = stringRedisTemplate.opsForValue().get("gpu-node:" + nodeName + ":parallelism");// 获取节点剩余并行度的 key
+                // -------------------------------- Comment --------------------------------
+                int restParallelism;
+                if (restParallelismString == null || Integer.parseInt(restParallelismString) > maxParallelism) {    // 如果剩余可用并行度没有值,说明是第一次查询,则重置成最大并行度的预设值
+                    restParallelism = maxParallelism;
+                    stringRedisTemplate.opsForValue().set("gpu-node:" + nodeName + ":parallelism", String.valueOf(restParallelism));
+                } else {
+                    restParallelism = Integer.parseInt(restParallelismString);
+                    kubernetesNodeCopy.setParallelism(restParallelism);
+                }
+                if (restParallelism > 0) {
+                    restNodeList.add(kubernetesNodeCopy);
+                }
             }
-            if (restParallelism > 0) {
-                restNodeList.add(kubernetesNodeCopy);
+            log.info("集群剩余并行度为:" + restNodeList);
+            return restNodeList.size() == 0 ? 0 : restNodeList.stream().mapToInt(NodeModel::getParallelism).sum();
+        } else if (DictConstants.USE_CPU.equals(isChoiceGpu)) {
+            initialNodeList = kubernetesConfiguration.getCpuNodeList(); // 预设并行度的节点列表
+            // 遍历所有节点,获取还有剩余并行度的节点
+            List<NodeModel> restNodeList = new ArrayList<>();    // 剩余并行度的节点列表
+            for (NodeModel kubernetesNodeSource : initialNodeList) {
+                NodeModel kubernetesNodeCopy = kubernetesNodeSource.clone();
+                String nodeName = kubernetesNodeCopy.getHostname();   // 节点名称
+                int maxParallelism = kubernetesNodeCopy.getParallelism();
+                String restParallelismString = stringRedisTemplate.opsForValue().get("cpu-node:" + nodeName + ":parallelism");// 获取节点剩余并行度的 key
+                // -------------------------------- Comment --------------------------------
+                int restParallelism;
+                if (restParallelismString == null || Integer.parseInt(restParallelismString) > maxParallelism) {    // 如果剩余可用并行度没有值,说明是第一次查询,则重置成最大并行度的预设值
+                    restParallelism = maxParallelism;
+                    stringRedisTemplate.opsForValue().set("cpu-node:" + nodeName + ":parallelism", String.valueOf(restParallelism));
+                } else {
+                    restParallelism = Integer.parseInt(restParallelismString);
+                    kubernetesNodeCopy.setParallelism(restParallelism);
+                }
+                if (restParallelism > 0) {
+                    restNodeList.add(kubernetesNodeCopy);
+                }
             }
+            log.info("集群剩余并行度为:" + restNodeList);
+            return restNodeList.size() == 0 ? 0 : restNodeList.stream().mapToInt(NodeModel::getParallelism).sum();
+        } else {
+            throw new RuntimeException("未知是否使用 GPU:" + isChoiceGpu);
         }
-        log.info(" 集群剩余并行度为:" + restNodeList);
-        return restNodeList.size() == 0 ? 0 : restNodeList.stream().mapToInt(NodeModel::getParallelism).sum();
+
     }
 
 
@@ -470,15 +501,15 @@ public class ProjectUtil {
     }
 
 
-    /**
-     * 获取 projectId 列表
-     *
-     * @param clusterRunningKeySet 集群下的所有键值对(包括运行中的项目和等待中的项目)
-     * @return projectId 列表
-     */
-    public List<String> getRunningProjectList(Set<String> clusterRunningKeySet) {
-        return clusterRunningKeySet.stream().filter(key -> StringUtil.countSubString(key, ":") == 3).collect(Collectors.toList());
-    }
+//    /**
+//     * 获取 projectId 列表
+//     *
+//     * @param clusterRunningKeySet 集群下的所有键值对(包括运行中的项目和等待中的项目)
+//     * @return projectId 列表
+//     */
+//    public List<String> getRunningProjectList(Set<String> clusterRunningKeySet) {
+//        return clusterRunningKeySet.stream().filter(key -> StringUtil.countSubString(key, ":") == 3).collect(Collectors.toList());
+//    }
 
 
     public PrefixEntity getRedisPrefixByProjectIdAndProjectType(String projectId, String projectType) {
@@ -529,7 +560,7 @@ public class ProjectUtil {
         //1 先检查缓存中的并行度是否超过,超过了就不加缓存的并行度了,常用于测试
         String key = "gpu-node:" + nodeName + ":parallelism";
         final int currentRestParallelism = Integer.parseInt(customRedisClient.get(key));
-        final List<NodeModel> nodeList = kubernetesConfiguration.getNodeList();
+        final List<NodeModel> nodeList = kubernetesConfiguration.getGpuNodeList();
         nodeList.forEach(node -> {
             if (nodeName.equals(node.getHostname())) {
                 if (currentRestParallelism + 1 < node.getParallelism()) {
@@ -557,8 +588,8 @@ public class ProjectUtil {
     }
 
     public void resetNodeParallelism() {
-        kubernetesConfiguration.getNodeList().forEach((node) -> customRedisClient.set("gpu-node:" + node.getHostname() + ":parallelism", node.getParallelism() + ""));
-        esminiConfiguration.getNodeList().forEach((node) -> customRedisClient.set("cpu-node:" + node.getHostname() + ":parallelism", node.getParallelism() + ""));
+        kubernetesConfiguration.getGpuNodeList().forEach((node) -> customRedisClient.set("gpu-node:" + node.getHostname() + ":parallelism", String.valueOf(node.getParallelism())));
+        kubernetesConfiguration.getCpuNodeList().forEach((node) -> customRedisClient.set("cpu-node:" + node.getHostname() + ":parallelism", String.valueOf(node.getParallelism())));
     }
 
     /**