فهرست منبع

并行度不超过任务数量

root 2 سال پیش
والد
کامیت
9c44096945

+ 29 - 23
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/consumer/ProjectConsumer.java

@@ -3,7 +3,6 @@ package com.css.simulation.resource.scheduler.consumer;
 
 import api.common.pojo.constants.DictConstants;
 import api.common.pojo.dto.ProjectMessageDTO;
-import api.common.util.CollectionUtil;
 import api.common.util.JsonUtil;
 import api.common.util.StringUtil;
 import com.css.simulation.resource.scheduler.mapper.*;
@@ -22,10 +21,7 @@ import org.springframework.kafka.annotation.KafkaListener;
 import org.springframework.stereotype.Component;
 
 import javax.annotation.Resource;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
 
 @Component
 @Slf4j
@@ -66,8 +62,6 @@ public class ProjectConsumer {
     ProjectUtil projectUtil;
 
 
-
-
     /**
      * 任务运行前首先判断用户是否拥有可分配资源
      *
@@ -149,19 +143,14 @@ public class ProjectConsumer {
 
         String projectId = projectMessageDTO.getProjectId();
         int parallelism = projectMessageDTO.getParallelism();  // 期望并行度
-        int parallelismSum; //实际可用并行度
-        //1 获取所有节点的剩余可用并行度
-        Map<String, Integer> nodeMap = projectUtil.getNodeMapToUse(parallelism);
-        if (CollectionUtil.isEmpty(nodeMap)) {
-            parallelismSum = 0;
-        } else {
-            parallelismSum = nodeMap.keySet().stream().mapToInt(nodeMap::get).sum();
-        }
+        //1 获取集群剩余可用并行度
+        int restParallelism = projectUtil.getRestParallelism();
         //2 判断剩余可用并行度是否大于项目并行度,否则加入扩充队列
-        if (parallelismSum > 0L) {
-            log.info("ProjectConsumer--run 集群 " + clusterId + " 将项目 " + projectId + "在节点 " + nodeMap + " 上执行!");
-            projectMessageDTO.setCurrentParallelism(parallelismSum);    // 设置实际的并行度
-            parseProject(nodeMap, projectMessageDTO, projectWaitingKey, projectRunningKey);
+        if (restParallelism > 0L) {
+            log.info("ProjectConsumer--run 集群 " + clusterId + " 执行项目 " + projectId);
+            // 设置实际的并行度
+            projectMessageDTO.setCurrentParallelism(Math.min(restParallelism, parallelism));   // 设置实际的并行度
+            parseProject(projectMessageDTO, projectWaitingKey, projectRunningKey);
         } else {
             log.info("ProjectConsumer--cacheManualProject 服务器资源不够,项目 " + projectId + " 暂时加入等待队列。");
             wait(projectWaitingKey, projectMessageDTO);
@@ -179,21 +168,21 @@ public class ProjectConsumer {
 
 
     /**
-     * @param nodeMap           节点列表以及剩余可用并行度
      * @param projectMessageDTO 初始接收到的项目启动信息
      * @param projectWaitingKey projectWaitingKey
      * @param projectRunningKey projectRunningKey
      */
     @SneakyThrows
-    public void parseProject(Map<String, Integer> nodeMap, ProjectMessageDTO projectMessageDTO, String projectWaitingKey, String projectRunningKey) {
+    public void parseProject(ProjectMessageDTO projectMessageDTO, String projectWaitingKey, String projectRunningKey) {
         String projectId = projectMessageDTO.getProjectId();    // 项目 id
         String projectType = projectMessageDTO.getType();   // 项目类型
+        int currentParallelism = projectMessageDTO.getCurrentParallelism();   // 当前并行度
         String packageId = projectMessageDTO.getScenePackageId();   // 场景测试包 id
         long videoTime = projectMessageDTO.getMaxSimulationTime(); // 结果视频的时长
         String vehicleConfigId = projectMessageDTO.getVehicleConfigId();// 模型配置 id
         String algorithmId = projectMessageDTO.getAlgorithmId();    // 算法 id
         // -------------------------------- 0 准备 --------------------------------
-        projectService.prepare(nodeMap, projectMessageDTO, projectWaitingKey, projectRunningKey);
+        projectService.prepare(projectMessageDTO, projectWaitingKey, projectRunningKey);
         String userId = null;
         if (DictConstants.PROJECT_TYPE_MANUAL.equals(projectType)) {
             userId = manualProjectMapper.selectCreateUserById(projectId);
@@ -206,8 +195,25 @@ public class ProjectConsumer {
         int taskTotal = scenePOList.size();
         projectMessageDTO.setTaskTotal(taskTotal);
         projectMessageDTO.setTaskCompleted(0);
-        // 设置任务数量之后将项目运行信息放入 redis
+        // 设置任务数量之后,获取运行节点,并将项目运行信息放入 redis
+        Map<String, Integer> nodeMap;
+        if (currentParallelism < taskTotal) {
+            nodeMap = projectUtil.getNodeMapToUse(currentParallelism);
+        } else {
+            nodeMap = projectUtil.getNodeMapToUse(taskTotal);
+        }
+        // 将指定 node 的并行度减少
+        nodeMap.keySet().forEach(nodeName -> {
+            int parallelismToUse = nodeMap.get(nodeName);
+            String restParallelismKey = "node:" + nodeName + ":parallelism";
+            int restParallelism = Integer.parseInt(Objects.requireNonNull(stringRedisTemplate.opsForValue().get(restParallelismKey)));// 剩余可用并行度
+            stringRedisTemplate.opsForValue().set(restParallelismKey, (restParallelism - parallelismToUse) + "");
+        });
+        // 重新设置实际使用的并行度并保存到 redis
+        projectMessageDTO.setCurrentParallelism(nodeMap.values().stream().mapToInt(parallelism -> parallelism).sum());
+        log.info("ProjectConsume--parseProject 项目 " + projectId + " 运行在:" + nodeMap);
         stringRedisTemplate.opsForValue().set(projectRunningKey, JsonUtil.beanToJson(projectMessageDTO));
+
         Set<ScenePO> scenePOSet = new HashSet<>(scenePOList); // 如果不去重的话会出现多个场景重复关联多个指标
         // -------------------------------- 2 查询模型 --------------------------------
         //2-1 根据车辆配置id vehicleConfigId, 获取 模型信息和传感器信息

+ 7 - 11
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/scheduler/ProjectScheduler.java

@@ -28,7 +28,6 @@ import org.springframework.transaction.annotation.Transactional;
 import javax.annotation.Resource;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
 
 @Component
@@ -165,19 +164,16 @@ public class ProjectScheduler {
             return;
         }
         //1 获取所有节点的剩余可用并行度
-        Map<String, Integer> nodeMap = projectUtil.getNodeMapToUse(parallelism);
-        if (CollectionUtil.isEmpty(nodeMap)) {
+        int restParallelism = projectUtil.getRestParallelism();
+        if (restParallelism == 0) {
             log.info("ProjectScheduler--run 集群中没有可用并行度,项目 " + projectId + " 继续等待。");
             return;
         }
-        //2 计算实际可用并行度
-        int parallelismSum = nodeMap.keySet().stream().mapToInt(nodeMap::get).sum();
-
-        //2 判断剩余可用并行度是否大于项目并行度,否则加入扩充队列
-        if (parallelismSum > 0L) {
-            log.info("ProjectScheduler--run 集群 " + clusterId + " 将项目 " + projectId + "在节点 " + nodeMap + " 上执行。");
-            projectMessageDTO.setCurrentParallelism(parallelismSum);    // 设置实际的并行度
-            projectConsumer.parseProject(nodeMap, projectMessageDTO, projectWaitingKey, projectRunningKey);
+        //2 判断剩余可用并行度是否大于项目并行度,否则继续等待
+        if (restParallelism > 0L) {
+            log.info("ProjectScheduler--run 集群 " + clusterId + " 执行项目项目 " + projectId);
+            projectMessageDTO.setCurrentParallelism(restParallelism);    // 设置实际的并行度
+            projectConsumer.parseProject(projectMessageDTO, projectWaitingKey, projectRunningKey);
         }
     }
 

+ 6 - 10
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/service/ProjectService.java

@@ -103,21 +103,14 @@ public class ProjectService {
     // -------------------------------- Comment --------------------------------
 
     /**
-     * @param nodeMap           节点列表以及剩余可用并行度
      * @param projectMessageDTO 初始接收到的项目启动信息
      * @param projectWaitingKey projectWaitingKey
      * @param projectRunningKey projectRunningKey
      */
     @Transactional
-    public void prepare(Map<String, Integer> nodeMap, ProjectMessageDTO projectMessageDTO, String projectWaitingKey, String projectRunningKey) {
+    public void prepare(ProjectMessageDTO projectMessageDTO, String projectWaitingKey, String projectRunningKey) {
         String projectId = projectMessageDTO.getProjectId();
-        //1 将指定 node 的并行度减少
-        nodeMap.keySet().forEach(nodeName -> {
-            int parallelismToUse = nodeMap.get(nodeName);
-            String restParallelismKey = "node:" + nodeName + ":parallelism";
-            int restParallelism = Integer.parseInt(Objects.requireNonNull(stringRedisTemplate.opsForValue().get(restParallelismKey)));// 剩余可用并行度
-            stringRedisTemplate.opsForValue().set(restParallelismKey, (restParallelism - parallelismToUse) + "");
-        });
+
         //2 将 redis 中该项目旧的信息则直接删除(包括 waitingKey)
         RedisUtil.deleteByPrefix(stringRedisTemplate, projectWaitingKey);
         RedisUtil.deleteByPrefix(stringRedisTemplate, projectRunningKey);
@@ -181,7 +174,9 @@ public class ProjectService {
      * @param cameraPOList
      * @param ogtPOList
      */
-    public void sendTaskMessage(String projectRunningPrefix, String userId, String projectId, String projectType, Long videoTime, Set<ScenePO> scenePOSet, VehiclePO vehiclePO, List<CameraPO> cameraPOList, List<OgtPO> ogtPOList) {
+    @SneakyThrows
+    public void sendTaskMessage( String projectRunningPrefix, String userId, String projectId, String projectType, Long videoTime, Set<ScenePO> scenePOSet, VehiclePO vehiclePO, List<CameraPO> cameraPOList, List<OgtPO> ogtPOList) {
+
         final int[] messageNumber = {0};
         log.info("ProjectService--sendTaskMessage 项目 " + projectId + " 获得的包括的场景信息为:" + scenePOSet);
         for (ScenePO scenePO : scenePOSet) {
@@ -288,6 +283,7 @@ public class ProjectService {
             });
         }
         log.info("ProjectService--sendTaskMessage 共发送了 " + messageNumber[0] + " 条消息!");
+
     }
 
 

+ 31 - 0
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/util/ProjectUtil.java

@@ -198,6 +198,37 @@ public class ProjectUtil {
     }
 
 
+    /**
+     * Computes the remaining parallelism of the whole cluster.
+     *
+     * For every node returned by {@code kubernetesConfiguration.getNodeList()} the remaining
+     * parallelism is read from the Redis key {@code node:<nodeName>:parallelism}. When the key
+     * has no value, the node's configured maximum parallelism is used and written back to
+     * Redis under that key. Nodes whose remaining parallelism is not greater than 0 are left
+     * out of the total.
+     *
+     * NOTE(review): the read and the follow-up write are two separate Redis calls; confirm
+     * whether concurrent callers can race on the initial seeding.
+     *
+     * @return the sum of the remaining parallelism over all nodes that still have capacity,
+     *         or 0 when no node has any capacity left
+     * @throws NumberFormatException if a stored Redis value cannot be parsed as an integer
+     */
+    public int getRestParallelism() {
+        List<KubernetesNodeTO> initialNodeList = kubernetesConfiguration.getNodeList(); // nodes with their preset (configured) parallelism
+        // Walk over all nodes and collect those that still have remaining parallelism
+        List<KubernetesNodeTO> restNodeList = new ArrayList<>();    // nodes that still have remaining parallelism
+        for (KubernetesNodeTO kubernetesNodeSource : initialNodeList) {
+            KubernetesNodeTO kubernetesNodeCopy = kubernetesNodeSource.clone(); // presumably an independent copy, so setMaxParallelism below leaves the configured node untouched -- TODO confirm clone() semantics
+            String nodeName = kubernetesNodeCopy.getName();   // node name
+            int maxParallelism = kubernetesNodeCopy.getMaxParallelism();
+            String restParallelismString = stringRedisTemplate.opsForValue().get("node:" + nodeName + ":parallelism");// read the node's remaining parallelism from Redis
+            // ---------------- Resolve the node's remaining parallelism ----------------
+            int restParallelism;
+            if (restParallelismString == null) {    // no value stored yet: treated as the first query, so seed Redis with the preset maximum
+                restParallelism = maxParallelism;
+                stringRedisTemplate.opsForValue().set("node:" + nodeName + ":parallelism", restParallelism + "");
+            } else {
+                restParallelism = Integer.parseInt(restParallelismString);
+                kubernetesNodeCopy.setMaxParallelism(restParallelism); // the copy's maxParallelism field is reused to carry the remaining value
+            }
+            if (restParallelism > 0) {
+                restNodeList.add(kubernetesNodeCopy);
+            }
+        }
+        log.info("ProjectUtil--getRestParallelism 集群剩余并行度为:" + restNodeList); // logs the list of nodes with capacity, not the summed value
+        return restNodeList.size() == 0 ? 0 : restNodeList.stream().mapToInt(KubernetesNodeTO::getMaxParallelism).sum(); // sums the copies' maxParallelism, which holds each node's remaining value here
+    }
+
     /**
      * 根据并行度获取用于执行的节点列表
      * 根据剩余可用并行度降序排序