score.py

import sys
import yaml
import json

sys.path.append('/home/kevin/kevin/zhaoyuan/evaluate_zhaoyuan/')
from config import config
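
# Score evaluates a three-level safety hierarchy (indicator -> metric group ->
# leaf metric): each leaf metric is range-checked against its [min, max]
# bounds, and failures cascade upward according to the per-priority thresholds
# config.T0, config.T1 and config.T2. (Summary inferred from the code below;
# the exact hierarchy is defined by the safety_config.yaml that drives it.)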


class Score:
    def __init__(self, config_path, calculated_metrics=None):
        self.config_path = config_path
        self.calculated_metrics = calculated_metrics
        self.safety_config = self.load_config()
        self.level_3_metrics = self._extract_level_3_metrics(self.safety_config)
        self.result = {}

    def load_config(self):
        with open(self.config_path, 'r') as file:
            return yaml.safe_load(file)
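
    # The config is expected to nest dicts down to leaf metrics that carry
    # 'name', 'priority', 'min' and 'max' keys (shape inferred from the
    # accessors below). Illustrative sketch only, not the actual file:
    #
    #   name: safety
    #   priority: 0
    #   time_based:
    #     name: time_based
    #     priority: 1
    #     ttc:
    #       name: TTC
    #       priority: 0
    #       min: 0.0
    #       max: 2.0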

    def _extract_level_3_metrics(self, d):
        names = []
        for key, value in d.items():
            if isinstance(value, dict):
                # If the value is a dict, keep traversing and collect the
                # metric names found in the subtree.
                names.extend(self._extract_level_3_metrics(value))
            elif key == 'name':
                # A 'name' key marks a metric; record its value.
                names.append(value)
        return names

    def is_within_range(self, value, min_val, max_val):
        return min_val <= value <= max_val

    def evaluate_level_3(self, metrics):
        result3 = {}
        name = metrics.get('name')
        priority = metrics.get('priority')
        max_val = metrics.get('max')
        min_val = metrics.get('min')
        if name not in self.level_3_metrics:
            self.level_3_metrics.append(name)
        metric_value = self.calculated_metrics.get(name)
        result3[name] = {
            'result': True,
            'priority': priority
        }
        # A metric without a calculated value passes by default.
        if metric_value is None:
            return result3
        if not self.is_within_range(metric_value, min_val, max_val):
            # An out-of-range value fails the leaf metric regardless of
            # priority; the T0/T1/T2 thresholds are applied one level up,
            # where failures per priority can actually be counted.
            result3[name]['result'] = False
        return result3

    def evaluate_level_2(self, metrics):
        result2 = {}
        name = metrics.get('name')
        priority = metrics.get('priority')
        result2[name] = {}
        for metric, sub_metrics in metrics.items():
            if metric not in ['name', 'priority']:
                result2[name].update(self.evaluate_level_3(sub_metrics))
        # Aggregate level-3 failures per priority and compare against the
        # thresholds config.T0, config.T1 and config.T2.
        priority_0_count = sum(1 for v in result2[name].values() if v['priority'] == 0 and not v['result'])
        priority_1_count = sum(1 for v in result2[name].values() if v['priority'] == 1 and not v['result'])
        priority_2_count = sum(1 for v in result2[name].values() if v['priority'] == 2 and not v['result'])
        if priority_0_count > config.T0:
            result2[name]['result'] = False
        elif priority_1_count > config.T1 or priority_2_count > config.T2:
            # Too many lower-priority failures: fail every sub-metric as
            # well as the group itself.
            for metric in result2[name].values():
                metric['result'] = False
            result2[name]['result'] = False
        else:
            result2[name]['result'] = True  # Default to True unless overridden
        result2[name]['priority'] = priority
        result2[name]['priority_0_count'] = priority_0_count
        result2[name]['priority_1_count'] = priority_1_count
        result2[name]['priority_2_count'] = priority_2_count
        return result2

    def evaluate_level_1(self):
        name = self.safety_config.get('name')
        priority = self.safety_config.get('priority')
        result1 = {}
        result1[name] = {}
        for metric, metrics in self.safety_config.items():
            if metric not in ['name', 'priority']:
                result1[name].update(self.evaluate_level_2(metrics))
        # Aggregate level-2 failures per priority against the same
        # config.T0/T1/T2 thresholds.
        priority_0_count = sum(1 for v in result1[name].values() if v['priority'] == 0 and not v['result'])
        priority_1_count = sum(1 for v in result1[name].values() if v['priority'] == 1 and not v['result'])
        priority_2_count = sum(1 for v in result1[name].values() if v['priority'] == 2 and not v['result'])
        if priority_0_count > config.T0:
            result1[name]['result'] = False
        elif priority_1_count > config.T1 or priority_2_count > config.T2:
            for metric in result1[name].values():
                metric['result'] = False
            result1[name]['result'] = False
        else:
            result1[name]['result'] = True  # Default to True unless overridden
        result1[name]['priority'] = priority
        result1[name]['priority_0_count'] = priority_0_count
        result1[name]['priority_1_count'] = priority_1_count
        result1[name]['priority_2_count'] = priority_2_count
        return result1

    def evaluate(self, calculated_metrics):
        self.calculated_metrics = calculated_metrics
        self.result = self.evaluate_level_1()
        return self.result


def main():
    config_path = r'/home/kevin/kevin/zhaoyuan/evaluate_zhaoyuan/models/safety/safety_config.yaml'
    result_path = r'/home/kevin/kevin/zhaoyuan/evaluate_zhaoyuan/models/safety/safety_config.json'
    calculated_metrics = {
        'TTC': 1.0,
        'MTTC': 1.0,
        'THW': 1.0,
        'LonSD': 50.0,
        'LatSD': 3.0,
        'DRAC': 3.0,
        'BTN': -1000.0,
        'STN': 0.5,
        'collisionRisk': 5.0,
        'collisionSeverity': 2.0,
    }
    evaluator = Score(config_path)
    result = evaluator.evaluate(calculated_metrics)
    with open(result_path, 'w') as json_file:
        json.dump(result, json_file, indent=4)  # `indent` pretty-prints the output


if __name__ == '__main__':
    main()
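
# Illustrative shape of the JSON written by main(). The group and metric names
# here are hypothetical; the actual keys come from safety_config.yaml:
#
# {
#     "safety": {
#         "time_based": {
#             "TTC": {"result": true, "priority": 0},
#             "result": true,
#             "priority": 1,
#             "priority_0_count": 0,
#             "priority_1_count": 0,
#             "priority_2_count": 0
#         },
#         "result": true,
#         "priority": 0,
#         "priority_0_count": 0,
#         "priority_1_count": 0,
#         "priority_2_count": 0
#     }
# }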