Prechádzať zdrojové kódy

根据模型类型判断证书占用

LingxinMeng 2 rokov pred
rodič
commit
7cddbae4d5

+ 3 - 0
api-common/src/main/java/api/common/pojo/constants/DictConstants.java

@@ -2,6 +2,8 @@ package api.common.pojo.constants;
 
 public class DictConstants {
 
+    public static final String MODEL_TYPE_VTD = "1";
+    public static final String MODEL_TYPE_CARSIM = "2";
     public static final String SCHEDULER_USER_ID = "simulation-resource-scheduler";
 
     public static final String SENSOR_CAMERA = "camera"; // 摄像头
@@ -137,6 +139,7 @@ public class DictConstants {
 
     // 集群id
     public static final String SYSTEM_CLUSTER_ID = "system"; // 超管使用此集群id执行项目
+    public static final String SYSTEM_USER_ID = "admin"; // 超管使用此集群id执行项目
 
     // 评测等级
     public static final String EVALUATION_LEVEL_G = "G";

+ 41 - 31
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/consumer/ProjectConsumer.java

@@ -133,14 +133,14 @@ public class ProjectConsumer {
             String algorithmDockerImage = projectService.handleAlgorithm(projectId, algorithmId);
             log.info("项目 " + projectId + " 算法已导入:" + algorithmDockerImage);
             // -------------------------------- 3 查询模型 --------------------------------
-            if ("1".equals(modelType)) {
-                log.info("项目 " + projectId + " 开始查询模型。");
+            if (DictConstants.MODEL_TYPE_VTD.equals(modelType)) {
+                log.debug("项目 " + projectId + " 开始查询模型。");
                 //2-1 根据车辆配置id vehicleConfigId, 获取 模型信息和传感器信息
                 VehicleEntity vehicleEntity = vehicleMapper.selectByVehicleConfigId(vehicleConfigId);   // 车辆
                 List<CameraEntity> cameraEntityList = sensorCameraMapper.selectCameraByVehicleConfigId(vehicleConfigId);    // 摄像头
                 List<OgtEntity> ogtEntityList = sensorOgtMapper.selectOgtByVehicleId(vehicleConfigId); // 完美传感器
                 // -------------------------------- 4 保存任务消息 --------------------------------
-                log.info("项目 " + projectId + " 开始保存任务消息。");
+                log.debug("项目 " + projectId + " 开始保存任务消息。");
                 List<TaskEntity> taskList = new ArrayList<>();
                 for (SceneEntity sceneEntity : sceneEntitySet) {
                     String sceneId = sceneEntity.getId();
@@ -207,14 +207,13 @@ public class ProjectConsumer {
                 }
                 taskUtil.batchInsertTask(taskList);
                 log.info("项目 " + projectId + " 共有 " + taskList.size() + " 个任务,已保存到数据库");
-            } else if ("2".equals(modelType)) {
-                log.info("项目 " + projectId + " 开始查询模型。");
-
+            } else if (DictConstants.MODEL_TYPE_CARSIM.equals(modelType)) {
+                log.debug("项目 " + projectId + " 开始查询模型。");
                 VehicleEntity vehicleEntity = vehicleMapper.selectByVehicleConfigId(vehicleConfigId);   // 车辆
                 List<CameraEntity> cameraEntityList = sensorCameraMapper.selectCameraByVehicleConfigId(vehicleConfigId);    // 摄像头
                 List<OgtEntity> ogtEntityList = sensorOgtMapper.selectOgtByVehicleId(vehicleConfigId); // 完美传感器
                 // -------------------------------- 4 保存任务消息 --------------------------------
-                log.info("项目 " + projectId + " 开始保存任务消息。");
+                log.debug("项目 " + projectId + " 开始保存任务消息。");
                 List<TaskEntity> taskList = new ArrayList<>();
                 for (SceneEntity sceneEntity : sceneEntitySet) {
                     String sceneId = sceneEntity.getId();
@@ -317,12 +316,7 @@ public class ProjectConsumer {
         } else if (DictConstants.PROJECT_TYPE_AUTO_SUB.equals(projectType)) {
             userId = autoSubProjectMapper.selectCreateUserById(projectId);
         } else {
-            log.error("项目类型错误:" + projectMessageDTO);
-            return;
-        }
-        if (StringUtil.isEmpty(userId)) {
-            log.error("未查询到项目创建人:" + projectMessageDTO);
-            return;
+            throw new RuntimeException("未知项目类型:" + projectType);
         }
         //3 获取用户类型(管理员账户、管理员子账户、普通账户、普通子账户)(独占、共享)
         UserEntity userEntity = userMapper.selectById(userId);
@@ -333,7 +327,7 @@ public class ProjectConsumer {
         if (DictConstants.ROLE_CODE_SYSADMIN.equals(roleCode) || DictConstants.ROLE_CODE_ADMIN.equals(roleCode)) {  //3-1 管理员账户和管理员子账户直接执行
             log.info("项目 " + projectId + " 的创建人 " + userId + " 为管理员账户或管理员子账户,直接判断服务器能否执行。");
             PrefixEntity redisPrefix = projectUtil.getRedisPrefixByClusterIdAndProjectId(DictConstants.SYSTEM_CLUSTER_ID, projectId);
-            run(projectMessageDTO, DictConstants.SYSTEM_CLUSTER_ID, redisPrefix.getProjectRunningKey(), redisPrefix.getProjectWaitingKey());
+            run(projectMessageDTO, DictConstants.SYSTEM_CLUSTER_ID, modelType, DictConstants.SYSTEM_USER_ID, redisPrefix.getProjectRunningKey(), redisPrefix.getProjectWaitingKey());
             return;
         } else if (DictConstants.ROLE_CODE_UESR.equals(roleCode)) { //3-2 普通账户,不管是独占还是共享,都在自己的集群里排队,根据自己的独占节点排队
             clusterEntity = clusterMapper.selectByUserId(userId);
@@ -348,22 +342,34 @@ public class ProjectConsumer {
                 log.info("项目 " + projectId + " 的创建人 " + userId + " 为普通共享子账户(父账户的集群),集群为:" + clusterEntity);
             }
         } else {
-            log.error("项目 " + projectId + " 的创建人 " + userId + " 为未知账户类型,不予执行!");
-            return;
+            throw new RuntimeException("未知角色类型:" + roleCode);
         }
-        // 获取拥有的节点数量,即仿真软件证书数量
-        String clusterId = clusterEntity.getId();
-        int simulationLicenseNumber = clusterEntity.getNumSimulationLicense();
-        // 获取该集群中正在运行的项目,如果没有则立即执行
-        PrefixEntity redisPrefix = projectUtil.getRedisPrefixByClusterIdAndProjectId(clusterId, projectId);
-        // 获取正在运行的项目的并行度总和
-        int currentParallelismSum = projectUtil.getCurrentParallelismSum(redisPrefix.getClusterRunningPrefix());
-        // 如果执行后的并行度总和小于最大节点数则执行,否则不执行
-        if (currentParallelismSum + parallelism <= simulationLicenseNumber) {
-            run(projectMessageDTO, clusterId, redisPrefix.getProjectRunningKey(), redisPrefix.getProjectWaitingKey());
+        PrefixEntity redisPrefix = projectUtil.getRedisPrefixByClusterIdAndProjectId(clusterEntity.getId(), projectId);
+        if (DictConstants.MODEL_TYPE_VTD.equals(modelType)) {
+            // 获取仿真软件证书数量和动力学软件证书数量(vtd占一个仿真证书,carsim各占一个)
+//            // 获取正在运行的项目的并行度总和
+//            int currentParallelismSum = projectUtil.getCurrentParallelismSum(redisPrefix.getClusterRunningPrefix());
+            // 如果执行后的并行度总和小于最大节点数则执行,否则不执行
+            if (projectUtil.getUsingLicenseNumber(userId, DictConstants.MODEL_TYPE_VTD) + parallelism <= clusterEntity.getNumSimulationLicense()) {
+                run(projectMessageDTO, userId, modelType, clusterEntity.getId(), redisPrefix.getProjectRunningKey(), redisPrefix.getProjectWaitingKey());
+            } else {
+                log.info("项目 " + projectId + " 并行度超出账户允许,加入等待队列,暂不执行。 ");
+                wait(redisPrefix.getProjectWaitingKey(), projectMessageDTO);
+            }
+        } else if (DictConstants.MODEL_TYPE_CARSIM.equals(modelType)) {
+            // 获取仿真软件证书数量和动力学软件证书数量(vtd占一个仿真证书,carsim各占一个)
+//                // 获取正在运行的项目的并行度总和
+//                int currentParallelismSum = projectUtil.getCurrentParallelismSum(redisPrefix.getClusterRunningPrefix());
+            // 如果执行后的并行度总和小于最大节点数则执行,否则不执行
+            if (projectUtil.getUsingLicenseNumber(userId, DictConstants.MODEL_TYPE_VTD) + parallelism <= clusterEntity.getNumSimulationLicense()
+                    && projectUtil.getUsingLicenseNumber(userId, DictConstants.MODEL_TYPE_CARSIM) + parallelism <= clusterEntity.getNumDynamicLicense()) {
+                run(projectMessageDTO, userId, modelType, clusterEntity.getId(), redisPrefix.getProjectRunningKey(), redisPrefix.getProjectWaitingKey());
+            } else {
+                log.info("项目 " + projectId + " 并行度超出账户允许,加入等待队列,暂不执行。 ");
+                wait(redisPrefix.getProjectWaitingKey(), projectMessageDTO);
+            }
         } else {
-            log.info("项目 " + projectId + " 并行度超出账户允许,加入等待队列,暂不执行。 ");
-            wait(redisPrefix.getProjectWaitingKey(), projectMessageDTO);
+            throw new RuntimeException("未知模型类型:" + modelType);
         }
     }
 
@@ -371,11 +377,12 @@ public class ProjectConsumer {
 
     /**
      * @param projectMessageDTO 初始接收到的项目启动信息
-     * @param clusterId         集群 id
+     * @param userId            用户ID
+     * @param clusterId         集群ID
      * @param projectRunningKey projectRunningKey
      * @param projectWaitingKey projectWaitingKey
      */
-    public void run(ProjectMessageDTO projectMessageDTO, String clusterId, String projectRunningKey, String projectWaitingKey) {
+    public void run(ProjectMessageDTO projectMessageDTO, String userId, String modelType, String clusterId, String projectRunningKey, String projectWaitingKey) {
 
         String projectId = projectMessageDTO.getProjectId();
         int parallelism = projectMessageDTO.getParallelism();  // 期望并行度
@@ -384,6 +391,9 @@ public class ProjectConsumer {
         //2 判断剩余可用并行度是否大于项目并行度,否则加入扩充队列
         if (restParallelism > 0L) {
             log.info("集群 " + clusterId + " 执行项目 " + projectId);
+            if (!DictConstants.SYSTEM_USER_ID.equals(userId)) {
+                projectUtil.useLicense(userId, modelType, parallelism);
+            }
             // 设置实际的并行度
             projectMessageDTO.setCurrentParallelism(Math.min(restParallelism, parallelism));   // 设置实际的并行度
             parseProject(projectMessageDTO, projectRunningKey);
@@ -408,7 +418,7 @@ public class ProjectConsumer {
         String algorithmId = projectMessageDTO.getAlgorithmId();    // 算法 id
         String projectPath = linuxTempPath + "project/" + projectId + "/";
         // -------------------------------- 1 获取任务 json 列表 --------------------------------
-        List<String> taskJsonList = FileUtil.listAbsolutePathByTypeAndLength(projectPath, "json", (StringUtil.getRandomUUID()+".json").length());
+        List<String> taskJsonList = FileUtil.listAbsolutePathByTypeAndLength(projectPath, "json", (StringUtil.getRandomUUID() + ".json").length());
         int taskTotal = taskJsonList.size();
         projectMessageDTO.setTaskTotal(taskTotal);
         projectMessageDTO.setTaskCompleted(0);

+ 8 - 0
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/mapper/AutoSubProjectMapper.java

@@ -48,4 +48,12 @@ public interface AutoSubProjectMapper {
             "set error_message = #{errorMessage}\n" +
             "where project_id = #{id}\n")
     void saveErrorMessage(SchedulerProjectPO schedulerProjectPO);
+
+    @Select("select parameter_type\n" +
+            "from model_vehicle t1\n" +
+            "         left join model_config t2 on t1.id = t2.vehicle_id\n" +
+            "         left join simulation_automatic_project t3 on t2.id = t3.vehicle\n" +
+            "         left join simulation_automatic_subproject t4 on t3.id = t4.parent_id\n" +
+            "where t4.id = #{projectId}\n")
+    String selectModelTypeByProjectId(String projectId);
 }

+ 7 - 0
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/mapper/ManualProjectMapper.java

@@ -46,4 +46,11 @@ public interface ManualProjectMapper {
             "set error_message = #{errorMessage}\n" +
             "where project_id = #{id}\n")
     void saveErrorMessage(SchedulerProjectPO schedulerProjectPO);
+
+    @Select("select parameter_type\n" +
+            "from model_vehicle t1\n" +
+            "         left join model_config t2 on t1.id = t2.vehicle_id\n" +
+            "         left join simulation_manual_project t3 on t2.id = t3.vehicle\n" +
+            "where t3.id = #{projectId}")
+    String selectModelTypeByProjectId(String projectId);
 }

+ 69 - 2
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/util/ProjectUtil.java

@@ -6,7 +6,10 @@ import api.common.util.*;
 import com.css.simulation.resource.scheduler.configuration.esmini.EsminiConfiguration;
 import com.css.simulation.resource.scheduler.configuration.kubernetes.KubernetesConfiguration;
 import com.css.simulation.resource.scheduler.configuration.redis.CustomRedisClient;
-import com.css.simulation.resource.scheduler.data.entity.*;
+import com.css.simulation.resource.scheduler.data.entity.NodeEntity;
+import com.css.simulation.resource.scheduler.data.entity.PrefixEntity;
+import com.css.simulation.resource.scheduler.data.entity.ProjectEntity;
+import com.css.simulation.resource.scheduler.data.entity.UserEntity;
 import com.css.simulation.resource.scheduler.data.model.NodeModel;
 import com.css.simulation.resource.scheduler.mapper.AutoSubProjectMapper;
 import com.css.simulation.resource.scheduler.mapper.ClusterMapper;
@@ -131,7 +134,7 @@ public class ProjectUtil {
      * @param lastPodName 即将删除的 pod 名称
      */
     @SneakyThrows
-    public void createNextPod(String projectId, String nodeName, String lastPodName) {
+    public void createNextPod(String userId, String projectId, String projectType, String nodeName, String lastPodName) {
         log.info("删除上一个 pod:projectId={},nodeName={},lastPodName={}", projectId, nodeName, lastPodName);
         String cpuOrderString = stringRedisTemplate.opsForValue().get("project:" + projectId + ":pod:" + lastPodName + ":cpu");
         deletePod(lastPodName);
@@ -142,6 +145,7 @@ public class ProjectUtil {
             // 如果当前节点没有下一个yaml,则返回一个并行度。
             log.info("节点 " + nodeName + " 已经执行完被分配的项目 " + projectId + " 的所有 pod。");
             incrementOneParallelismOfGpuNode(nodeName);
+            releaseLicense(userId, getModelTypeByProjectIdAndProjectType(projectId, projectType), 1);
         } else {
             final String yamlPathCacheKey = new ArrayList<>(yamlPathCacheKeySet).get(0);
             final String absolutePath = stringRedisTemplate.opsForValue().get(yamlPathCacheKey);
@@ -500,4 +504,67 @@ public class ProjectUtil {
         }
         return result;
     }
+
+    public Integer getUsingLicenseNumber(String userId, String modelType) {
+        String key;
+        if (DictConstants.MODEL_TYPE_VTD.equals(modelType)) {
+            key = "user:" + userId + ":using-license:vtd";
+        } else if (DictConstants.MODEL_TYPE_CARSIM.equals(modelType)) {
+            key = "user:" + userId + ":using-license:carsim";
+        } else {
+            throw new RuntimeException("未知模型类型:" + modelType);
+        }
+        final String usingLicense = customRedisClient.get(key);
+        if (StringUtil.isEmpty(usingLicense)) {
+            customRedisClient.set(key, "0");
+            return 0;
+        } else {
+            return Integer.parseInt(usingLicense);
+        }
+    }
+
+    public void useLicense(String userId, String modelType, int parallelism) {
+        String key;
+        if (DictConstants.MODEL_TYPE_VTD.equals(modelType)) {
+            key = "user:" + userId + ":using-license:vtd";
+        } else if (DictConstants.MODEL_TYPE_CARSIM.equals(modelType)) {
+            key = "user:" + userId + ":using-license:carsim";
+        } else {
+            throw new RuntimeException("未知模型类型:" + modelType);
+        }
+        final String usingLicense = customRedisClient.get(key);
+        if (StringUtil.isEmpty(usingLicense)) {
+            customRedisClient.set(key, String.valueOf(parallelism));
+        } else {
+            customRedisClient.increment(key, parallelism);
+        }
+    }
+
+    public void releaseLicense(String userId, String modelType, int parallelism) {
+        String key;
+        if (DictConstants.MODEL_TYPE_VTD.equals(modelType)) {
+            key = "user:" + userId + ":using-license:vtd";
+        } else if (DictConstants.MODEL_TYPE_CARSIM.equals(modelType)) {
+            key = "user:" + userId + ":using-license:carsim";
+        } else {
+            throw new RuntimeException("未知模型类型:" + modelType);
+        }
+        final String usingLicense = customRedisClient.get(key);
+        if (StringUtil.isEmpty(usingLicense)) {
+            customRedisClient.set(key, "0");
+        } else {
+            customRedisClient.decrement(key, parallelism);
+        }
+    }
+
+
+    public String getModelTypeByProjectIdAndProjectType(String projectId, String projectType) {
+        if (DictConstants.PROJECT_TYPE_MANUAL.equals(projectType)) {
+            return manualProjectMapper.selectModelTypeByProjectId(projectId);
+        } else if (DictConstants.PROJECT_TYPE_AUTO_SUB.equals(projectType)) {
+            return autoSubProjectMapper.selectModelTypeByProjectId(projectId);
+        } else {
+            throw new RuntimeException("未知项目类型:" + projectType);
+        }
+    }
 }

+ 2 - 1
simulation-resource-scheduler/src/main/java/com/css/simulation/resource/scheduler/util/TaskUtil.java

@@ -166,9 +166,10 @@ public class TaskUtil {
                 //如果项目已完成先把 pod 删除,并归还并行度
                 KubernetesUtil.deletePod2(apiClient, kubernetesConfiguration.getNamespace(), podName);
                 projectUtil.incrementOneParallelismOfGpuNode(nodeName);
+                projectUtil.releaseLicense(userId, projectUtil.getModelTypeByProjectIdAndProjectType(projectId, projectType), 1);
             } else {
                 log.info("项目 " + projectId + " 还未运行完成。");
-                projectUtil.createNextPod(projectId, nodeName, podName);
+                projectUtil.createNextPod(userId, projectId, projectType, nodeName, podName);
             }
             RedisUtil.deleteByPrefix(stringRedisTemplate, redisPrefix.getTaskMessageKey());
             RedisUtil.deleteByPrefix(stringRedisTemplate, redisPrefix.getTaskPodKey());