moss
class Engine:
    """
    Moss Engine

    NOTE: Cannot create multiple Engines on different devices. For that purpose, use `moss.parallel.ParallelEngine` instead.
    """

    __version__ = _moss.__version__

    def __init__(
        self,
        name: str,
        map_file: str,
        person_file: str,
        start_step: int = 0,
        step_interval: float = 1,
        seed: int = 43,
        verbose_level=Verbosity.NO_OUTPUT,
        person_limit: int = -1,
        junction_yellow_time: float = 0,
        phase_pressure_coeff: float = 1.5,
        speed_stat_interval: int = 0,
        output_dir: str = "",
        out_xmin: float = -1e999,
        out_ymin: float = -1e999,
        out_xmax: float = 1e999,
        out_ymax: float = 1e999,
        device: int = 0,
        device_mem: float = 0,
    ):
        """
        Args:
        - name: The name of the task (for directory naming in output)
        - map_file: The path to the map file (Protobuf format)
        - person_file: The path to the person file (Protobuf format)
        - start_step: The starting step of the simulation
        - step_interval: The interval of each step (unit: seconds)
        - seed: The random seed
        - verbose_level: The verbosity level
        - person_limit: The maximum number of persons to simulate (-1 means no limit)
        - junction_yellow_time: The yellow time of the junction traffic light
        - phase_pressure_coeff: The coefficient of the phase pressure
        - speed_stat_interval: The interval of speed statistics. Set to `0` to disable speed statistics.
        - output_dir: The AVRO output directory
        - out_xmin: The minimum x coordinate of the output bounding box
        - out_ymin: The minimum y coordinate of the output bounding box
        - out_xmax: The maximum x coordinate of the output bounding box
        - out_ymax: The maximum y coordinate of the output bounding box
        - device: The CUDA device index
        - device_mem: The memory limit of the CUDA device (unit: GiB). Set to `0` to use the self-adaptive mode (80% of the free GPU memory).

        Raises:
        - AssertionError: If `junction_yellow_time` is negative.
        - ValueError: If `speed_stat_interval` is negative.
        - RuntimeError: If `_thread_local` already recorded a different `device` (an earlier Engine in this thread).
        """

        self._fetched_persons = None
        self._fetched_lanes = None

        # Validate plain arguments before claiming the thread-local device slot,
        # so that a rejected call does not pin the device for later Engines.
        assert junction_yellow_time >= 0
        if speed_stat_interval < 0:
            raise ValueError("Cannot set speed_stat_interval to be less than 0")
        # The first Engine created in this thread fixes the device for all later ones.
        if not hasattr(_thread_local, "device"):
            _thread_local.device = device
        elif _thread_local.device != device:
            raise RuntimeError(
                "Cannot create multiple Engines on different device! Use moss.parallel.ParallelEngine instead."
            )
        self.speed_stat_interval = speed_stat_interval
        """
        The interval of speed statistics. Set to `0` to disable speed statistics.
        """
        self._e = _moss.Engine(
            name,
            map_file,
            person_file,
            start_step,
            step_interval,
            seed,
            verbose_level.value,
            person_limit,
            junction_yellow_time,
            phase_pressure_coeff,
            speed_stat_interval,
            output_dir,
            out_xmin,
            out_ymin,
            out_xmax,
            out_ymax,
            device,
            device_mem,
        )
        self._map = Map()
        with open(map_file, "rb") as f:
            self._map.ParseFromString(f.read())
        self.id2lanes = {lane.id: lane for lane in self._map.lanes}
        """
        Dictionary of lanes (Protobuf format) indexed by lane id
        """
        self.id2roads = {road.id: road for road in self._map.roads}
        """
        Dictionary of roads (Protobuf format) indexed by road id
        """
        self.id2junctions = {junction.id: junction for junction in self._map.junctions}
        """
        Dictionary of junctions (Protobuf format) indexed by junction id
        """
        self.id2aois = {aoi.id: aoi for aoi in self._map.aois}
        """
        Dictionary of AOIs (Protobuf format) indexed by AOI id
        """

        self.lane_index2id = self.fetch_lanes()["id"]
        """
        Numpy array of lane ids indexed by lane index
        """
        self.junc_index2id = self._e.get_junction_ids()
        """
        Numpy array of junction ids indexed by junction index
        """
        self.road_index2id = self._e.get_road_ids()
        """
        Numpy array of road ids indexed by road index
        """
        self.lane_id2index = {v.item(): k for k, v in enumerate(self.lane_index2id)}
        """
        Dictionary of lane index indexed by lane id
        """
        self.junc_id2index = {v.item(): k for k, v in enumerate(self.junc_index2id)}
        """
        Dictionary of junction index indexed by junction id
        """
        self.road_id2index = {v.item(): k for k, v in enumerate(self.road_index2id)}
        """
        Dictionary of road index indexed by road id
        """

        self._persons = Persons()
        with open(person_file, "rb") as f:
            self._persons.ParseFromString(f.read())
        self._map_bbox = (out_xmin, out_ymin, out_xmax, out_ymax)
        self.start_step = start_step
        """
        The starting step of the simulation
        """

        self.device = device
        """
        The CUDA device index
        """

    @property
    def person_count(self) -> int:
        """
        The number of persons loaded from the person file
        """
        return len(self._persons.persons)

    @property
    def lane_count(self) -> int:
        """
        The number of lanes
        """
        return len(self.id2lanes)

    @property
    def road_count(self) -> int:
        """
        The number of roads
        """
        return len(self.id2roads)

    @property
    def junction_count(self) -> int:
        """
        The number of junctions
        """
        return len(self.id2junctions)

    def get_map(self, dict_return: bool = True) -> Union[Map, Dict]:
        """
        Get the Map object.
        Map is a protobuf message defined in `pycityproto.city.map.v2.map_pb2` in the `pycityproto` package.
        The documentation url is https://docs.fiblab.net/cityproto#city.map.v2.Map

        Args:
        - dict_return: Whether to return the object as a dictionary

        Returns:
        - The Map object or the dictionary
        """
        if dict_return:
            return pb2dict(self._map)
        return self._map

    def get_persons(self, dict_return: bool = True) -> Union[Persons, Dict]:
        """
        Get the Persons object.
        Persons is a protobuf message defined in `pycityproto.city.person.v2.person_pb2` in the `pycityproto` package.
        The documentation url is https://docs.fiblab.net/cityproto#city.person.v2.Persons

        Args:
        - dict_return: Whether to return the object as a dictionary

        Returns:
        - The Persons object or the dictionary
        """
        if dict_return:
            return pb2dict(self._persons)
        return self._persons

    def get_current_time(self) -> float:
        """
        Get the current time
        """
        return self._e.get_current_time()

    def fetch_persons(self) -> Dict[str, NDArray]:
        """
        Fetch the persons' information.

        The result is a dictionary with the following keys:
        - id: The id of the person
        - status: The status of the person
        - lane_id: The id of the lane the person is on
        - lane_parent_id: The id of the road the lane belongs to
        - s: The s value of the person
        - aoi_id: The id of the AOI the person is in
        - v: The velocity of the person
        - shadow_lane_id: The id of the shadow lane the person is on
        - shadow_s: The s value of the shadow lane
        - lc_yaw: The yaw of the lane change
        - lc_completed_ratio: The completed ratio of the lane change
        - is_forward: Whether the person is moving forward
        - x: The x coordinate of the person
        - y: The y coordinate of the person
        - dir: The direction of the person
        - schedule_index: The index of the schedule
        - trip_index: The index of the trip
        - departure_time: The departure time of the person
        - traveling_time: The traveling time of the person
        - total_distance: The total distance of the person

        The result is cached until the next call to `next_step`.

        We strongly recommend using `pd.DataFrame(e.fetch_persons())` to convert the result to a DataFrame for better visualization and analysis.
        """
        if self._fetched_persons is None:
            (
                ids,
                statuses,
                lane_ids,
                lane_parent_ids,
                ss,
                aoi_ids,
                vs,
                shadow_lane_ids,
                shadow_ss,
                lc_yaws,
                lc_completed_ratios,
                is_forwards,
                xs,
                ys,
                dirs,
                schedule_indexs,
                trip_indexs,
                departure_times,
                traveling_times,
                total_distances,
            ) = self._e.fetch_persons()
            self._fetched_persons = {
                "id": ids,
                "status": statuses,
                "lane_id": lane_ids,
                "lane_parent_id": lane_parent_ids,
                "s": ss,
                "aoi_id": aoi_ids,
                "v": vs,
                "shadow_lane_id": shadow_lane_ids,
                "shadow_s": shadow_ss,
                "lc_yaw": lc_yaws,
                "lc_completed_ratio": lc_completed_ratios,
                "is_forward": is_forwards,
                "x": xs,
                "y": ys,
                "dir": dirs,
                "schedule_index": schedule_indexs,
                "trip_index": trip_indexs,
                "departure_time": departure_times,
                "traveling_time": traveling_times,
                "total_distance": total_distances,
            }
        return self._fetched_persons

    def fetch_lanes(self) -> Dict[str, NDArray]:
        """
        Fetch the lanes' information.

        The result is a dictionary with the following keys:
        - id: The id of the lane
        - status: The status of the lane
        - v_avg: The average speed of the lane

        The result is cached until the next call to `next_step`.

        We strongly recommend using `pd.DataFrame(e.fetch_lanes())` to convert the result to a DataFrame for better visualization and analysis.
        """
        if self._fetched_lanes is None:
            ids, statuses, v_avgs = self._e.fetch_lanes()
            self._fetched_lanes = {
                "id": ids,
                "status": statuses,
                "v_avg": v_avgs,
            }
        return self._fetched_lanes

    def get_running_person_count(self) -> int:
        """
        Get the total number of running persons (including driving and walking)
        """
        persons = self.fetch_persons()
        status: NDArray[np.uint8] = persons["status"]
        # Each comparison must be parenthesized: `|` binds tighter than `==`,
        # so the unparenthesized form becomes a chained comparison on arrays.
        running = (status == PersonStatus.DRIVING.value) | (
            status == PersonStatus.WALKING.value
        )
        return int(running.sum())

    def get_lane_statuses(self) -> NDArray[np.int8]:
        """
        Get the traffic light status of each lane, `0`-green / `1`-yellow / `2`-red / `3`-restriction.
        The lane id of the entry `i` can be obtained by `e.lane_index2id[i]`.
        """
        lanes = self.fetch_lanes()
        return lanes["status"]

    def get_lane_waiting_vehicle_counts(
        self, speed_threshold: float = 0.1
    ) -> Dict[int, int]:
        """
        Get the number of vehicles of each lane with speed lower than `speed_threshold`

        Returns:
        - Dict: lane id -> number of vehicles (lanes with no such vehicle are omitted)
        """

        persons = self.fetch_persons()
        lane_id = persons["lane_id"]
        status = persons["status"]
        v = persons["v"]
        mask = (status == PersonStatus.DRIVING.value) & (v < speed_threshold)
        # count for the lane id
        unique, counts = np.unique(lane_id[mask], return_counts=True)
        return dict(zip(unique, counts))

    def get_lane_waiting_at_end_vehicle_counts(
        self, speed_threshold: float = 0.1, distance_to_end: float = 100
    ) -> Dict[int, int]:
        """
        Get the number of vehicles of each lane with speed lower than `speed_threshold` and distance to end lower than `distance_to_end`

        Returns:
        - Dict: lane id -> number of vehicles (lanes with no such vehicle are omitted)
        """

        persons = self.fetch_persons()
        lane_id = persons["lane_id"]
        status = persons["status"]
        v = persons["v"]
        s = persons["s"]
        mask = (status == PersonStatus.DRIVING.value) & (v < speed_threshold)
        # distance to the end of the lane = lane length - s
        lane_ids_for_count = [
            lid
            for lid, pos in zip(lane_id[mask], s[mask])
            if self.id2lanes[lid].length - pos < distance_to_end
        ]
        # count for the lane id
        unique, counts = np.unique(lane_ids_for_count, return_counts=True)
        return dict(zip(unique, counts))

    def get_lane_ids(self) -> NDArray[np.int32]:
        """
        Get the ids of the lanes as a numpy array
        """
        return self.lane_index2id

    def get_lane_average_vehicle_speed(self, lane_index: int) -> float:
        """
        Get the average speed of the vehicles on the lane `lane_index`

        Raises:
        - RuntimeError: If speed statistics are disabled (`speed_stat_interval == 0`)
        """
        if self.speed_stat_interval == 0:
            raise RuntimeError(
                "Please set speed_stat_interval to enable speed statistics"
            )
        lanes = self.fetch_lanes()
        v_avgs: NDArray[np.float32] = lanes["v_avg"]
        return v_avgs[lane_index].item()

    def get_junction_ids(self) -> NDArray[np.int32]:
        """
        Get the ids of the junctions
        """
        return self.junc_index2id

    def get_junction_phase_lanes(self) -> List[List[Tuple[List[int], List[int]]]]:
        """
        Get the `index` of the `in` and `out` lanes of each phase of each junction

        Examples: TODO
        """
        return self._e.get_junction_phase_lanes()

    def get_junction_phase_ids(self) -> NDArray[np.int32]:
        """
        Get the phase id of each junction, `-1` if it has no traffic lights.
        The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.
        """
        return self._e.get_junction_phase_ids()

    def get_junction_phase_counts(self) -> NDArray[np.int32]:
        """
        Get the number of available phases of each junction.
        The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.
        """
        return self._e.get_junction_phase_counts()

    def get_junction_dynamic_roads(self) -> List[List[int]]:
        """
        Get the ids of the dynamic roads connected to each junction.
        The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.
        """
        return self._e.get_junction_dynamic_roads()

    def get_road_lane_plans(self, road_index: int) -> List[List[slice]]:
        """
        Get the dynamic lane plan of the road `road_index`,
        represented as list of lane groups:
        ```
        [
            [slice(lane_start, lane_end), ...]
        ]
        ```
        """
        return [
            [slice(a, b) for a, b in i] for i in self._e.get_road_lane_plans(road_index)
        ]

    def get_road_average_vehicle_speed(self, road_index: int) -> float:
        """
        Get the average speed of the vehicles on the road `road_index`
        (the mean of the per-lane average speeds of the road's lanes)

        Raises:
        - RuntimeError: If speed statistics are disabled (`speed_stat_interval == 0`)
        """
        if self.speed_stat_interval == 0:
            raise RuntimeError(
                "Please set speed_stat_interval to enable speed statistics"
            )
        lanes = self.fetch_lanes()
        road_id = self.road_index2id[road_index]
        lane_ids = self.id2roads[road_id].lane_ids
        lane_indexes = [self.lane_id2index[i] for i in lane_ids]
        v_avgs: NDArray[np.float32] = lanes["v_avg"]
        return v_avgs[lane_indexes].mean().item()

    def get_finished_person_count(self) -> int:
        """
        Get the number of the finished persons
        """
        persons = self.fetch_persons()
        status: NDArray[np.uint8] = persons["status"]
        return int((status == PersonStatus.FINISHED.value).sum())

    def get_finished_person_average_traveling_time(self) -> float:
        """
        Get the average traveling time of the finished persons
        """
        persons = self.fetch_persons()
        status: NDArray[np.uint8] = persons["status"]
        traveling_time = persons["traveling_time"]
        return traveling_time[status == PersonStatus.FINISHED.value].mean()

    def get_running_person_average_traveling_time(self) -> float:
        """
        Get the average traveling time of the running persons

        NOTE: only persons whose status is `PersonStatus.DRIVING` are included here
        (walking persons are not), unlike `get_running_person_count`.
        """
        persons = self.fetch_persons()
        status: NDArray[np.uint8] = persons["status"]
        traveling_time = persons["traveling_time"]
        return traveling_time[status == PersonStatus.DRIVING.value].mean()

    def get_departed_person_average_traveling_time(self) -> float:
        """
        Get the average traveling time of the departed persons (every status except `SLEEP`)
        """
        persons = self.fetch_persons()
        status: NDArray[np.uint8] = persons["status"]
        traveling_time = persons["traveling_time"]
        return traveling_time[status != PersonStatus.SLEEP.value].mean()

    def get_road_lane_plan_index(self, road_index: int) -> int:
        """
        Get the lane plan index of road `road_index`
        """
        return self._e.get_road_lane_plan_index(road_index)

    def get_road_vehicle_counts(self) -> Dict[int, int]:
        """
        Get the number of vehicles of each road

        Returns:
        - Dict: road id -> number of vehicles
        """
        persons = self.fetch_persons()
        road_id = persons["lane_parent_id"]
        status = persons["status"]
        mask = status == PersonStatus.DRIVING.value
        # count for the road id
        unique, counts = np.unique(road_id[mask], return_counts=True)
        return dict(zip(unique, counts))

    def get_road_waiting_vehicle_counts(
        self, speed_threshold: float = 0.1
    ) -> Dict[int, int]:
        """
        Get the number of vehicles with speed lower than `speed_threshold` of each road

        Returns:
        - Dict: road id -> number of vehicles
        """

        persons = self.fetch_persons()
        road_id = persons["lane_parent_id"]
        status = persons["status"]
        v = persons["v"]
        mask = (
            (status == PersonStatus.DRIVING.value)
            & (v < speed_threshold)
            & (road_id < 3_0000_0000)  # the road id ranges [2_0000_0000, 3_0000_0000)
        )
        # count for the road id
        unique, counts = np.unique(road_id[mask], return_counts=True)
        return dict(zip(unique, counts))

    def set_tl_policy(self, junction_index: int, policy: TlPolicy):
        """
        Set the traffic light policy of junction `junction_index` to `policy`
        """
        self._e.set_tl_policy(junction_index, policy.value)

    def set_tl_policy_batch(self, junction_indices: List[int], policy: TlPolicy):
        """
        Set the traffic light policy of all junctions in `junction_indices` to `policy`
        """
        self._e.set_tl_policy_batch(junction_indices, policy.value)

    def set_tl_duration(self, junction_index: int, duration: int):
        """
        Set the traffic light switch duration of junction `junction_index` to `duration`

        NOTE: This is only effective for `TlPolicy.FIXED_TIME` and `TlPolicy.MAX_PRESSURE`.

        NOTE: Set duration to `0` to use the predefined duration in the `map_file`
        """
        self._e.set_tl_duration(junction_index, duration)

    def set_tl_duration_batch(self, junction_indices: List[int], duration: int):
        """
        Set the traffic light switch duration of all junctions in `junction_indices` to `duration`

        NOTE: This is only effective for `TlPolicy.FIXED_TIME` and `TlPolicy.MAX_PRESSURE`

        NOTE: Set duration to `0` to use the predefined duration in the `map_file`
        """
        self._e.set_tl_duration_batch(junction_indices, duration)

    def set_tl_phase(self, junction_index: Union[str, int], phase_index: int):
        """
        Set the phase of `junction_index` to `phase_index`
        """
        self._e.set_tl_phase(junction_index, phase_index)

    def set_tl_phase_batch(self, junction_indices: List[int], phase_indices: List[int]):
        """
        Set the phase of each junction in `junction_indices` to the corresponding entry of `phase_indices`
        """
        assert len(junction_indices) == len(phase_indices)
        self._e.set_tl_phase_batch(junction_indices, phase_indices)

    def set_road_lane_plan(self, road_index: int, plan_index: int):
        """
        Set the lane plan of road `road_index`
        """
        self._e.set_road_lane_plan(road_index, plan_index)

    def set_road_lane_plan_batch(
        self, road_indices: List[int], plan_indices: List[int]
    ):
        """
        Set the lane plan of each road in `road_indices` to the corresponding entry of `plan_indices`
        """
        assert len(road_indices) == len(plan_indices)
        self._e.set_road_lane_plan_batch(road_indices, plan_indices)

    def set_lane_restriction(self, lane_index: int, flag: bool):
        """
        Set the restriction state of lane `lane_index`
        """
        self._e.set_lane_restriction(lane_index, flag)

    def set_lane_restriction_batch(self, lane_indices: List[int], flags: List[bool]):
        """
        Set the restriction state of each lane in `lane_indices` to the corresponding entry of `flags`
        """
        assert len(lane_indices) == len(flags)
        self._e.set_lane_restriction_batch(lane_indices, flags)

    def set_lane_max_speed(self, lane_index: int, max_speed: float):
        """
        Set the max_speed of lane `lane_index`
        """
        self._e.set_lane_max_speed(lane_index, max_speed)

    def set_lane_max_speed_batch(
        self, lane_indices: List[int], max_speeds: Union[float, List[float]]
    ):
        """
        Set the max_speed of each lane in `lane_indices`.
        `max_speeds` is either one value per lane or a single value applied to all of them.
        """
        if hasattr(max_speeds, "__len__"):
            assert len(lane_indices) == len(max_speeds)
        else:
            max_speeds = [max_speeds] * len(lane_indices)
        self._e.set_lane_max_speed_batch(lane_indices, max_speeds)

    def next_step(self, n=1):
        """
        Move forward `n` steps
        """
        # Invalidate the caches used by fetch_persons / fetch_lanes.
        self._fetched_persons = None
        self._fetched_lanes = None
        self._e.next_step(n)
Moss Engine
NOTE: Cannot create multiple Engines on different devices. For that purpose, use `moss.parallel.ParallelEngine` instead.
def __init__(
    self,
    name: str,
    map_file: str,
    person_file: str,
    start_step: int = 0,
    step_interval: float = 1,
    seed: int = 43,
    verbose_level=Verbosity.NO_OUTPUT,
    person_limit: int = -1,
    junction_yellow_time: float = 0,
    phase_pressure_coeff: float = 1.5,
    speed_stat_interval: int = 0,
    output_dir: str = "",
    out_xmin: float = -1e999,
    out_ymin: float = -1e999,
    out_xmax: float = 1e999,
    out_ymax: float = 1e999,
    device: int = 0,
    device_mem: float = 0,
):
    """
    Args:
    - name: The name of the task (for directory naming in output)
    - map_file: The path to the map file (Protobuf format)
    - person_file: The path to the person file (Protobuf format)
    - start_step: The starting step of the simulation
    - step_interval: The interval of each step (unit: seconds)
    - seed: The random seed
    - verbose_level: The verbosity level
    - person_limit: The maximum number of persons to simulate (-1 means no limit)
    - junction_yellow_time: The yellow time of the junction traffic light
    - phase_pressure_coeff: The coefficient of the phase pressure
    - speed_stat_interval: The interval of speed statistics. Set to `0` to disable speed statistics.
    - output_dir: The AVRO output directory
    - out_xmin: The minimum x coordinate of the output bounding box
    - out_ymin: The minimum y coordinate of the output bounding box
    - out_xmax: The maximum x coordinate of the output bounding box
    - out_ymax: The maximum y coordinate of the output bounding box
    - device: The CUDA device index
    - device_mem: The memory limit of the CUDA device (unit: GiB). Set to `0` to use the self-adaptive mode (80% of the free GPU memory).

    Raises:
    - AssertionError: If `junction_yellow_time` is negative.
    - ValueError: If `speed_stat_interval` is negative.
    - RuntimeError: If `_thread_local` already recorded a different `device` (an earlier Engine in this thread).
    """

    self._fetched_persons = None
    self._fetched_lanes = None

    # Validate plain arguments before claiming the thread-local device slot,
    # so that a rejected call does not pin the device for later Engines.
    assert junction_yellow_time >= 0
    if speed_stat_interval < 0:
        raise ValueError("Cannot set speed_stat_interval to be less than 0")
    # The first Engine created in this thread fixes the device for all later ones.
    if not hasattr(_thread_local, "device"):
        _thread_local.device = device
    elif _thread_local.device != device:
        raise RuntimeError(
            "Cannot create multiple Engines on different device! Use moss.parallel.ParallelEngine instead."
        )
    self.speed_stat_interval = speed_stat_interval
    """
    The interval of speed statistics. Set to `0` to disable speed statistics.
    """
    self._e = _moss.Engine(
        name,
        map_file,
        person_file,
        start_step,
        step_interval,
        seed,
        verbose_level.value,
        person_limit,
        junction_yellow_time,
        phase_pressure_coeff,
        speed_stat_interval,
        output_dir,
        out_xmin,
        out_ymin,
        out_xmax,
        out_ymax,
        device,
        device_mem,
    )
    self._map = Map()
    with open(map_file, "rb") as f:
        self._map.ParseFromString(f.read())
    self.id2lanes = {lane.id: lane for lane in self._map.lanes}
    """
    Dictionary of lanes (Protobuf format) indexed by lane id
    """
    self.id2roads = {road.id: road for road in self._map.roads}
    """
    Dictionary of roads (Protobuf format) indexed by road id
    """
    self.id2junctions = {junction.id: junction for junction in self._map.junctions}
    """
    Dictionary of junctions (Protobuf format) indexed by junction id
    """
    self.id2aois = {aoi.id: aoi for aoi in self._map.aois}
    """
    Dictionary of AOIs (Protobuf format) indexed by AOI id
    """

    self.lane_index2id = self.fetch_lanes()["id"]
    """
    Numpy array of lane ids indexed by lane index
    """
    self.junc_index2id = self._e.get_junction_ids()
    """
    Numpy array of junction ids indexed by junction index
    """
    self.road_index2id = self._e.get_road_ids()
    """
    Numpy array of road ids indexed by road index
    """
    self.lane_id2index = {v.item(): k for k, v in enumerate(self.lane_index2id)}
    """
    Dictionary of lane index indexed by lane id
    """
    self.junc_id2index = {v.item(): k for k, v in enumerate(self.junc_index2id)}
    """
    Dictionary of junction index indexed by junction id
    """
    self.road_id2index = {v.item(): k for k, v in enumerate(self.road_index2id)}
    """
    Dictionary of road index indexed by road id
    """

    self._persons = Persons()
    with open(person_file, "rb") as f:
        self._persons.ParseFromString(f.read())
    self._map_bbox = (out_xmin, out_ymin, out_xmax, out_ymax)
    self.start_step = start_step
    """
    The starting step of the simulation
    """

    self.device = device
    """
    The CUDA device index
    """
Args:
- name: The name of the task (for directory naming in output)
- map_file: The path to the map file (Protobuf format)
- person_file: The path to the person file (Protobuf format)
- start_step: The starting step of the simulation
- step_interval: The interval of each step (unit: seconds)
- seed: The random seed
- verbose_level: The verbosity level
- person_limit: The maximum number of persons to simulate (-1 means no limit)
- junction_yellow_time: The yellow time of the junction traffic light
- phase_pressure_coeff: The coefficient of the phase pressure
- speed_stat_interval: The interval of speed statistics. Set to `0` to disable speed statistics.
- output_dir: The AVRO output directory
- out_xmin: The minimum x coordinate of the output bounding box
- out_ymin: The minimum y coordinate of the output bounding box
- out_xmax: The maximum x coordinate of the output bounding box
- out_ymax: The maximum y coordinate of the output bounding box
- device: The CUDA device index
- device_mem: The memory limit of the CUDA device (unit: GiB). Set to `0` to use the self-adaptive mode (80% of the free GPU memory).
@property
def person_count(self) -> int:
    """
    The number of persons loaded from the person file
    """
    return len(self._persons.persons)
The number of persons in the person file
@property
def lane_count(self) -> int:
    """
    Total number of lanes in the loaded map (size of `id2lanes`)
    """
    lanes_by_id = self.id2lanes
    return len(lanes_by_id)
The number of lanes
@property
def road_count(self) -> int:
    """
    Total number of roads in the loaded map (size of `id2roads`)
    """
    roads_by_id = self.id2roads
    return len(roads_by_id)
The number of roads
@property
def junction_count(self) -> int:
    """
    Total number of junctions in the loaded map (size of `id2junctions`)
    """
    junctions_by_id = self.id2junctions
    return len(junctions_by_id)
The number of junctions
def get_map(self, dict_return: bool = True) -> Union[Map, Dict]:
    """
    Return the parsed Map protobuf, optionally converted to a dictionary.

    Map is a protobuf message defined in `pycityproto.city.map.v2.map_pb2` in the `pycityproto` package.
    The documentation url is https://docs.fiblab.net/cityproto#city.map.v2.Map

    Args:
    - dict_return: If true, convert the message with `pb2dict` before returning it

    Returns:
    - The dictionary form when `dict_return` is true, otherwise the Map object itself
    """
    return pb2dict(self._map) if dict_return else self._map
Get the Map object.
Map is a protobuf message defined in `pycityproto.city.map.v2.map_pb2` in the `pycityproto` package.
The documentation url is https://docs.fiblab.net/cityproto#city.map.v2.Map
Args:
- dict_return: Whether to return the object as a dictionary
Returns:
- The Map object or the dictionary
def get_persons(self, dict_return: bool = True) -> Union[Persons, Dict]:
    """
    Return the parsed Persons protobuf, optionally converted to a dictionary.

    Persons is a protobuf message defined in `pycityproto.city.person.v2.person_pb2` in the `pycityproto` package.
    The documentation url is https://docs.fiblab.net/cityproto#city.person.v2.Persons

    Args:
    - dict_return: If true, convert the message with `pb2dict` before returning it

    Returns:
    - The dictionary form when `dict_return` is true, otherwise the Persons object itself
    """
    return pb2dict(self._persons) if dict_return else self._persons
Get the Persons object.
Persons is a protobuf message defined in `pycityproto.city.person.v2.person_pb2` in the `pycityproto` package.
The documentation url is https://docs.fiblab.net/cityproto#city.person.v2.Persons
Args:
- dict_return: Whether to return the object as a dictionary
Returns:
- The Persons object or the dictionary
def get_current_time(self) -> float:
    """
    Return the current time as reported by the underlying engine (`self._e`)
    """
    engine = self._e
    return engine.get_current_time()
Get the current time
def fetch_persons(self) -> Dict[str, NDArray]:
    """
    Fetch the persons' information, cached in `self._fetched_persons`.

    The result is a dictionary with the following keys:
    - id: The id of the person
    - status: The status of the person
    - lane_id: The id of the lane the person is on
    - lane_parent_id: The id of the road the lane belongs to
    - s: The s value of the person
    - aoi_id: The id of the AOI the person is in
    - v: The velocity of the person
    - shadow_lane_id: The id of the shadow lane the person is on
    - shadow_s: The s value of the shadow lane
    - lc_yaw: The yaw of the lane change
    - lc_completed_ratio: The completed ratio of the lane change
    - is_forward: Whether the person is moving forward
    - x: The x coordinate of the person
    - y: The y coordinate of the person
    - dir: The direction of the person
    - schedule_index: The index of the schedule
    - trip_index: The index of the trip
    - departure_time: The departure time of the person
    - traveling_time: The traveling time of the person
    - total_distance: The total distance of the person

    We strongly recommend using `pd.DataFrame(e.fetch_persons())` to convert the result to a DataFrame for better visualization and analysis.
    """
    # Serve the cached snapshot if one exists.
    if self._fetched_persons is not None:
        return self._fetched_persons
    (
        person_ids,
        person_statuses,
        person_lane_ids,
        person_road_ids,
        person_s,
        person_aoi_ids,
        person_v,
        person_shadow_lane_ids,
        person_shadow_s,
        person_lc_yaws,
        person_lc_ratios,
        person_is_forward,
        person_x,
        person_y,
        person_dirs,
        person_schedule_indices,
        person_trip_indices,
        person_departure_times,
        person_traveling_times,
        person_total_distances,
    ) = self._e.fetch_persons()
    self._fetched_persons = dict(
        id=person_ids,
        status=person_statuses,
        lane_id=person_lane_ids,
        lane_parent_id=person_road_ids,
        s=person_s,
        aoi_id=person_aoi_ids,
        v=person_v,
        shadow_lane_id=person_shadow_lane_ids,
        shadow_s=person_shadow_s,
        lc_yaw=person_lc_yaws,
        lc_completed_ratio=person_lc_ratios,
        is_forward=person_is_forward,
        x=person_x,
        y=person_y,
        dir=person_dirs,
        schedule_index=person_schedule_indices,
        trip_index=person_trip_indices,
        departure_time=person_departure_times,
        traveling_time=person_traveling_times,
        total_distance=person_total_distances,
    )
    return self._fetched_persons
Fetch the persons' information.
The result is a dictionary with the following keys:
- id: The id of the person
- status: The status of the person
- lane_id: The id of the lane the person is on
- lane_parent_id: The id of the road the lane belongs to
- s: The s value of the person
- aoi_id: The id of the AOI the person is in
- v: The velocity of the person
- shadow_lane_id: The id of the shadow lane the person is on
- shadow_s: The s value of the shadow lane
- lc_yaw: The yaw of the lane change
- lc_completed_ratio: The completed ratio of the lane change
- is_forward: Whether the person is moving forward
- x: The x coordinate of the person
- y: The y coordinate of the person
- dir: The direction of the person
- schedule_index: The index of the schedule
- trip_index: The index of the trip
- departure_time: The departure time of the person
- traveling_time: The traveling time of the person
- total_distance: The total distance of the person
We strongly recommend using `pd.DataFrame(e.fetch_persons())` to convert the result to a DataFrame for better visualization and analysis.
def fetch_lanes(self) -> Dict[str, NDArray]:
    """
    Fetch the lanes' information.

    The result is a dictionary with the following keys:
    - id: The id of the lane
    - status: The status of the lane
    - v_avg: The average speed of the lane

    The dictionary is built from `self._e.fetch_lanes()` on the first call and stored in
    `self._fetched_lanes`; while that attribute is not `None`, the same object is returned.

    We strongly recommend using `pd.DataFrame(e.fetch_lanes())` to convert the result to a DataFrame for better visualization and analysis.
    """
    if self._fetched_lanes is not None:
        return self._fetched_lanes
    lane_ids, lane_statuses, lane_v_avgs = self._e.fetch_lanes()
    self._fetched_lanes = dict(id=lane_ids, status=lane_statuses, v_avg=lane_v_avgs)
    return self._fetched_lanes
Fetch the lanes' information.
The result is a dictionary with the following keys:
- id: The id of the lane
- status: The status of the lane
- v_avg: The average speed of the lane
We strongly recommend using `pd.DataFrame(e.fetch_lanes())`
to convert the result to a DataFrame for better visualization and analysis.
def get_running_person_count(self) -> int:
    """
    Get the total number of running persons (including driving and walking).

    Returns:
    - int: the number of persons whose status is `PersonStatus.DRIVING` or `PersonStatus.WALKING`
    """
    persons = self.fetch_persons()
    status: NDArray[np.uint8] = persons["status"]
    # Both comparisons must be parenthesized: `|` binds tighter than `==`, so the unparenthesized
    # form parsed as the chained comparison `status == (DRIVING | status) == WALKING`, which never
    # computes the element-wise OR (and raises ValueError for multi-element arrays).
    running = (status == PersonStatus.DRIVING.value) | (status == PersonStatus.WALKING.value)
    return int(running.sum())
Get the total number of running persons (including driving and walking)
def get_lane_statuses(self) -> NDArray[np.int8]:
    """
    Return the traffic light status of every lane, ordered by lane index.

    Status codes: `0`-green / `1`-yellow / `2`-red / `3`-restriction.
    Use `e.lane_index2id[i]` to map entry `i` back to its lane id.
    """
    return self.fetch_lanes()["status"]
Get the traffic light status of each lane, `0`-green / `1`-yellow / `2`-red / `3`-restriction.
The lane id of the entry `i` can be obtained by `e.lane_index2id[i]`.
def get_lane_waiting_vehicle_counts(
    self, speed_threshold: float = 0.1
) -> Dict[int, int]:
    """
    Count, per lane, the driving vehicles whose speed is lower than `speed_threshold`.

    Args:
    - speed_threshold: vehicles with `v < speed_threshold` are counted as waiting

    Returns:
    - Dict: lane id -> number of vehicles (lanes without waiting vehicles are absent)
    """
    persons = self.fetch_persons()
    is_waiting = (persons["status"] == PersonStatus.DRIVING.value) & (
        persons["v"] < speed_threshold
    )
    waiting_lane_ids, per_lane = np.unique(
        persons["lane_id"][is_waiting], return_counts=True
    )
    return dict(zip(waiting_lane_ids, per_lane))
Get the number of vehicles of each lane with speed lower than `speed_threshold`
Returns:
- Dict: lane id -> number of vehicles
def get_lane_waiting_at_end_vehicle_counts(
    self, speed_threshold: float = 0.1, distance_to_end: float = 100
) -> Dict[int, int]:
    """
    Count, per lane, the driving vehicles that are slower than `speed_threshold` and
    closer than `distance_to_end` to the end of their lane (`lane.length - s`).

    Returns:
    - Dict: lane id -> number of vehicles
    """
    persons = self.fetch_persons()
    slow_driving = (persons["status"] == PersonStatus.DRIVING.value) & (
        persons["v"] < speed_threshold
    )
    candidate_lanes = persons["lane_id"][slow_driving]
    candidate_positions = persons["s"][slow_driving]
    # remaining distance to the lane end is looked up per vehicle in the map's lane table
    near_end = [
        lane
        for lane, pos in zip(candidate_lanes, candidate_positions)
        if self.id2lanes[lane].length - pos < distance_to_end
    ]
    lanes_found, per_lane = np.unique(near_end, return_counts=True)
    return dict(zip(lanes_found, per_lane))
Get the number of vehicles of each lane with speed lower than `speed_threshold`
and distance to the lane end lower than `distance_to_end`
Returns:
- Dict: lane id -> number of vehicles
def get_lane_ids(self) -> NDArray[np.int32]:
    """
    Return the lane ids as a numpy array; entry `i` is the id of the lane with index `i`.
    """
    lane_ids = self.lane_index2id
    return lane_ids
Get the ids of the lanes as a numpy array
def get_lane_average_vehicle_speed(self, lane_index: int) -> float:
    """
    Return the average vehicle speed on the lane with index `lane_index`.

    Raises:
    - RuntimeError: if speed statistics are disabled (`speed_stat_interval == 0`)
    """
    stats_disabled = self.speed_stat_interval == 0
    if stats_disabled:
        raise RuntimeError(
            "Please set speed_stat_interval to enable speed statistics"
        )
    average_speeds: NDArray[np.float32] = self.fetch_lanes()["v_avg"]
    return average_speeds[lane_index].item()
Get the average speed of the vehicles on the lane `lane_index`
def get_junction_ids(self) -> NDArray[np.int32]:
    """
    Return the junction ids; entry `i` is the id of the junction with index `i`.
    """
    junction_ids = self.junc_index2id
    return junction_ids
Get the ids of the junctions
def get_junction_phase_lanes(self) -> List[List[Tuple[List[int], List[int]]]]:
    """
    Get the `index` of the `in` and `out` lanes of each phase of each junction.

    Per the return annotation, the outer list is per junction and holds one
    `(in_lane_indices, out_lane_indices)` tuple per phase.
    NOTE(review): the lane indices presumably index `e.lane_index2id` — confirm.
    """
    phase_lanes = self._e.get_junction_phase_lanes()
    return phase_lanes
Get the `index` of the `in` and `out` lanes of each phase of each junction
Examples: TODO
def get_junction_phase_ids(self) -> NDArray[np.int32]:
    """
    Return the current phase id of every junction (`-1` for junctions without traffic lights).

    Use `e.junc_index2id[i]` to map entry `i` back to its junction id.
    """
    phase_ids = self._e.get_junction_phase_ids()
    return phase_ids
Get the phase id of each junction, `-1` if it has no traffic lights.
The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.
def get_junction_phase_counts(self) -> NDArray[np.int32]:
    """
    Return how many phases are available at every junction.

    Use `e.junc_index2id[i]` to map entry `i` back to its junction id.
    """
    phase_counts = self._e.get_junction_phase_counts()
    return phase_counts
Get the number of available phases of each junction.
The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.
def get_junction_dynamic_roads(self) -> List[List[int]]:
    """
    Get the ids of the dynamic roads connected to each junction.
    The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.

    Returns:
    - List[List[int]]: one list of road ids per junction
    """
    return self._e.get_junction_dynamic_roads()
Get the ids of the dynamic roads connected to each junction.
The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.
def get_road_lane_plans(self, road_index: int) -> List[List[slice]]:
    """
    Return the dynamic lane plans of the road `road_index`.

    Each plan is a list of lane groups, each group expressed as a `slice`:
    ```
    [
        [slice(lane_start, lane_end), ...]
    ]
    ```
    """
    plans: List[List[slice]] = []
    for raw_plan in self._e.get_road_lane_plans(road_index):
        plans.append([slice(start, end) for start, end in raw_plan])
    return plans
Get the dynamic lane plan of the road `road_index`,
represented as list of lane groups:
[
[slice(lane_start, lane_end), ...]
]
def get_road_average_vehicle_speed(self, road_index: int) -> float:
    """
    Return the average vehicle speed on the road `road_index`, computed as the
    unweighted mean of the per-lane average speeds of the road's lanes.

    Raises:
    - RuntimeError: if speed statistics are disabled (`speed_stat_interval == 0`)
    """
    if self.speed_stat_interval == 0:
        raise RuntimeError(
            "Please set speed_stat_interval to enable speed statistics"
        )
    lane_table = self.fetch_lanes()
    road = self.id2roads[self.road_index2id[road_index]]
    member_indexes = [self.lane_id2index[lane_id] for lane_id in road.lane_ids]
    average_speeds: NDArray[np.float32] = lane_table["v_avg"]
    return average_speeds[member_indexes].mean().item()
Get the average speed of the vehicles on the road `road_index`
def get_finished_person_count(self) -> int:
    """
    Return how many persons have status `PersonStatus.FINISHED`.
    """
    statuses: NDArray[np.uint8] = self.fetch_persons()["status"]
    is_finished = statuses == PersonStatus.FINISHED.value
    return is_finished.sum()
Get the number of the finished persons
def get_finished_person_average_traveling_time(self) -> float:
    """
    Return the mean `traveling_time` over persons with status `PersonStatus.FINISHED`.

    If no person is finished, numpy's `mean` of an empty selection yields `nan` (with a RuntimeWarning).
    """
    persons = self.fetch_persons()
    finished_mask = persons["status"] == PersonStatus.FINISHED.value
    return persons["traveling_time"][finished_mask].mean()
Get the average traveling time of the finished persons
def get_running_person_average_traveling_time(self) -> float:
    """
    Return the mean `traveling_time` over running persons.

    NOTE: only persons with status `PersonStatus.DRIVING` are included here (walking persons are not).
    An empty selection yields `nan` (numpy `mean` of an empty array).
    """
    persons = self.fetch_persons()
    driving_mask = persons["status"] == PersonStatus.DRIVING.value
    return persons["traveling_time"][driving_mask].mean()
Get the average traveling time of the running persons
def get_departed_person_average_traveling_time(self) -> float:
    """
    Return the mean `traveling_time` over departed persons, i.e. every person whose
    status is not `PersonStatus.SLEEP` (running + finished).
    """
    persons = self.fetch_persons()
    departed_mask = persons["status"] != PersonStatus.SLEEP.value
    return persons["traveling_time"][departed_mask].mean()
Get the average traveling time of the departed persons (running+finished)
def get_road_lane_plan_index(self, road_index: int) -> int:
    """
    Return the index of the lane plan currently used by road `road_index`.
    """
    plan_index = self._e.get_road_lane_plan_index(road_index)
    return plan_index
Get the index of the current lane plan of road `road_index`
def get_road_vehicle_counts(self) -> Dict[int, int]:
    """
    Count the driving vehicles on each road, keyed by the parent id of the lane each vehicle is on.

    Returns:
    - Dict: road id -> number of vehicles
    """
    persons = self.fetch_persons()
    is_driving = persons["status"] == PersonStatus.DRIVING.value
    parent_ids, per_parent = np.unique(
        persons["lane_parent_id"][is_driving], return_counts=True
    )
    return dict(zip(parent_ids, per_parent))
Get the number of vehicles of each road
Returns:
- Dict: road id -> number of vehicles
def get_road_waiting_vehicle_counts(
    self, speed_threshold: float = 0.1
) -> Dict[int, int]:
    """
    Count, per road, the driving vehicles with speed lower than `speed_threshold`.

    Returns:
    - Dict: road id -> number of vehicles
    """
    persons = self.fetch_persons()
    parent_ids = persons["lane_parent_id"]
    # road ids range over [2_0000_0000, 3_0000_0000); larger parent ids are not roads
    on_road = parent_ids < 3_0000_0000
    is_slow_driving = (persons["status"] == PersonStatus.DRIVING.value) & (
        persons["v"] < speed_threshold
    )
    roads, per_road = np.unique(
        parent_ids[is_slow_driving & on_road], return_counts=True
    )
    return dict(zip(roads, per_road))
Get the number of vehicles with speed lower than `speed_threshold` of each road
Returns:
- Dict: road id -> number of vehicles
def set_tl_policy(self, junction_index: int, policy: TlPolicy):
    """
    Switch junction `junction_index` to the traffic light policy `policy`.
    """
    policy_code = policy.value
    self._e.set_tl_policy(junction_index, policy_code)
Set the traffic light policy of junction `junction_index` to `policy`
def set_tl_policy_batch(self, junction_indices: List[int], policy: TlPolicy):
    """
    Switch every junction listed in `junction_indices` to the traffic light policy `policy`.
    """
    policy_code = policy.value
    self._e.set_tl_policy_batch(junction_indices, policy_code)
Set the traffic light policy of all junctions in `junction_indices` to `policy`
def set_tl_duration(self, junction_index: int, duration: int):
    """
    Set how long junction `junction_index` stays in a phase before switching to `duration`.

    NOTE: This only takes effect under `TlPolicy.FIXED_TIME` and `TlPolicy.MAX_PRESSURE`.

    NOTE: A `duration` of `0` selects the duration predefined in the `map_file`.
    """
    engine = self._e
    engine.set_tl_duration(junction_index, duration)
Set the traffic light switch duration of junction `junction_index` to `duration`
NOTE: This is only effective for `TlPolicy.FIXED_TIME` and `TlPolicy.MAX_PRESSURE`.
NOTE: Set duration to `0` to use the predefined duration in the `map_file`
def set_tl_duration_batch(self, junction_indices: List[int], duration: int):
    """
    Set the phase switch duration of every junction in `junction_indices` to `duration`.

    NOTE: This only takes effect under `TlPolicy.FIXED_TIME` and `TlPolicy.MAX_PRESSURE`

    NOTE: A `duration` of `0` selects the duration predefined in the `map_file`
    """
    engine = self._e
    engine.set_tl_duration_batch(junction_indices, duration)
Set the traffic light switch duration of all junctions in `junction_indices` to `duration`
NOTE: This is only effective for `TlPolicy.FIXED_TIME` and `TlPolicy.MAX_PRESSURE`
NOTE: Set duration to `0` to use the predefined duration in the `map_file`
def set_tl_phase(self, junction_index: Union[str, int], phase_index: int):
    """
    Put junction `junction_index` into phase `phase_index`.
    """
    engine = self._e
    engine.set_tl_phase(junction_index, phase_index)
Set the phase of `junction_index` to `phase_index`
def set_tl_phase_batch(self, junction_indices: List[int], phase_indices: List[int]):
    """
    Set, in one call, the phase of each junction in `junction_indices` to the entry
    at the same position in `phase_indices`.

    Both lists must have the same length (checked with `assert`).
    """
    assert len(phase_indices) == len(junction_indices)
    self._e.set_tl_phase_batch(junction_indices, phase_indices)
Set the phases of the junctions in `junction_indices` to the corresponding `phase_indices` in batch
def set_road_lane_plan(self, road_index: int, plan_index: int):
    """
    Make road `road_index` use its lane plan number `plan_index`.
    """
    engine = self._e
    engine.set_road_lane_plan(road_index, plan_index)
Set the lane plan of road `road_index` to `plan_index`
def set_road_lane_plan_batch(
    self, road_indices: List[int], plan_indices: List[int]
):
    """
    Set, in one call, the lane plan of each road in `road_indices` to the entry
    at the same position in `plan_indices`.

    Both lists must have the same length (checked with `assert`).
    """
    assert len(plan_indices) == len(road_indices)
    self._e.set_road_lane_plan_batch(road_indices, plan_indices)
Set the lane plans of the roads in `road_indices` to the corresponding `plan_indices` in batch
def set_lane_restriction(self, lane_index: int, flag: bool):
    """
    Turn the restriction of lane `lane_index` on (`flag=True`) or off (`flag=False`).
    """
    engine = self._e
    engine.set_lane_restriction(lane_index, flag)
Set the restriction state of lane `lane_index`
def set_lane_restriction_batch(self, lane_indices: List[int], flags: List[bool]):
    """
    Set, in one call, the restriction state of each lane in `lane_indices` to the
    flag at the same position in `flags`.

    Both lists must have the same length (checked with `assert`).
    """
    assert len(flags) == len(lane_indices)
    self._e.set_lane_restriction_batch(lane_indices, flags)
Set the restriction states of the lanes in `lane_indices` to the corresponding `flags` in batch
def set_lane_max_speed(self, lane_index: int, max_speed: float):
    """
    Set the maximum speed of lane `lane_index` to `max_speed`.
    """
    engine = self._e
    engine.set_lane_max_speed(lane_index, max_speed)
Set the max_speed of lane `lane_index`
def set_lane_max_speed_batch(
    self, lane_indices: List[int], max_speeds: Union[float, List[float]]
):
    """
    Set the maximum speed of every lane in `lane_indices`.

    `max_speeds` is either a sequence with one value per lane (same length as
    `lane_indices`, checked with `assert`) or a single value applied to all lanes.
    """
    if not hasattr(max_speeds, "__len__"):
        # scalar: broadcast to one value per lane
        max_speeds = [max_speeds] * len(lane_indices)
    else:
        assert len(lane_indices) == len(max_speeds)
    self._e.set_lane_max_speed_batch(lane_indices, max_speeds)
Set the max_speed of the lanes in `lane_indices` to `max_speeds` in batch (a single value is applied to every lane)
An enumeration.
An enumeration.
class DBRecorder:
    """
    DBRecorder records simulation data for web visualization and writes it to a Postgres database.

    The table schema is as follows:
    - meta_simple: The metadata of the simulation.
        - name: The name of the simulation.
        - start: The start step of the simulation.
        - steps: The total steps of the simulation.
        - time: The time of the simulation.
        - total_agents: The total agents of the simulation.
        - map: The map of the simulation.
        - min_lng: The minimum longitude of the simulation.
        - min_lat: The minimum latitude of the simulation.
        - max_lng: The maximum longitude of the simulation.
        - max_lat: The maximum latitude of the simulation.
        - road_status_v_min: The minimum speed of the road status.
        - road_status_interval: The interval of the road status.
    - {output_name}_s_cars: The vehicles of the simulation.
        - step: The step of the simulation.
        - id: The id of the vehicle.
        - parent_id: The parent id of the vehicle.
        - direction: The direction of the vehicle.
        - lng: The longitude of the vehicle.
        - lat: The latitude of the vehicle.
        - model: The model of the vehicle.
        - z: The z of the vehicle.
        - pitch: The pitch of the vehicle.
        - v: The speed of the vehicle.
    - {output_name}_s_people: The people of the simulation.
        - step: The step of the simulation.
        - id: The id of the people.
        - parent_id: The parent id of the people.
        - direction: The direction of the people.
        - lng: The longitude of the people.
        - lat: The latitude of the people.
        - z: The z of the people.
        - v: The speed of the people.
        - model: The model of the people.
    - {output_name}_s_traffic_light: The traffic lights of the simulation.
        - step: The step of the simulation.
        - id: The id of the traffic light.
        - state: The state of the traffic light.
        - lng: The longitude of the traffic light.
        - lat: The latitude of the traffic light.
    - {output_name}_s_road: The road status of the simulation.
        - step: The step of the simulation.
        - id: The id of the road.
        - level: The level of the road.
        - v: The speed of the road.
        - in_vehicle_cnt: The in vehicle count of the road.
        - out_vehicle_cnt: The out vehicle count of the road.
        - cnt: The count of the road.

    The indexes of the tables are as follows:
    - {output_name}_s_cars: (step, lng, lat)
    - {output_name}_s_people: (step, lng, lat)
    - {output_name}_s_traffic_light: (step, lng, lat)
    - {output_name}_s_road: (step)
    """

    def __init__(self, eng: Engine):
        """
        Args:
        - eng: The engine to be recorded.
        """
        self.eng = eng
        # one [step, output_vehicles, output_tls] entry per call to `record`
        self.data = []

    def record(self):
        """
        Record the data of the engine.

        Appends `[current_step, output_vehicles, output_tls]`, as returned by the underlying engine, to `self.data`.
        """
        self.data.append([
            self.eng._e.get_current_step(),
            self.eng._e.get_output_vehicles(),
            self.eng._e.get_output_tls(),
        ])

    def save(self, db_url: str, mongo_map: str, output_name: str, use_tqdm=False):
        """
        Save the recorded data to the Postgres database.

        Args:
        - db_url: The URL of the Postgres database.
        - mongo_map: The map path of the simulation in MongoDB (if you use MongoDB), formatted like `{db}.{coll}`.
        - output_name: The name under which the simulation is saved. It is embedded in table and index
          names, so it must be a plain SQL identifier (letters, digits, `_`, `$`; not starting with a digit or `$`).
        - use_tqdm: Whether to show tqdm progress bars while copying rows.

        Raises:
        - ValueError: If `output_name` is not a plain SQL identifier.
        """
        import re
        from contextlib import closing

        # Table/index names cannot be bound as query parameters, so `output_name` is interpolated
        # into the DDL below; restricting it to a plain identifier rules out SQL injection.
        if not re.fullmatch(r"(?![\d$])[\w$]*", output_name):
            raise ValueError(
                f"output_name must be a plain SQL identifier, got {output_name!r}"
            )

        vehs = []
        tls = []
        xs = []
        ys = []
        proj = pyproj.Proj(self.eng._map.header.projection)
        for step, (vs, vx, vy), (ts, tx, ty) in self.data:
            if vs:
                # inverse=True: projected map x/y -> lng/lat
                x, y = proj(vx, vy, True)
                xs.extend(x)
                ys.extend(y)
                vehs.append([step, vs, x, y])
            if ts:
                x, y = proj(tx, ty, True)
                xs.extend(x)
                ys.extend(y)
                tls.append([step, ts, x, y])
        if xs:
            min_lon, max_lon, min_lat, max_lat = min(xs), max(xs), min(ys), max(ys)
        else:
            # nothing recorded: fall back to `_map_bbox` (presumably the map's bounding box)
            x1, y1, x2, y2 = self.eng._map_bbox
            min_lon, min_lat = proj(x1, y1, True)
            max_lon, max_lat = proj(x2, y2, True)

        # `with conn` only scopes the transaction (commit on success, rollback on error) and does
        # NOT close the connection; `closing` guarantees it is closed afterwards.
        with closing(psycopg2.connect(db_url)) as conn:
            with conn, conn.cursor() as cur:
                # create table meta_simple
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS public.meta_simple (
                        "name" text NOT NULL,
                        "start" int4 NOT NULL,
                        steps int4 NOT NULL,
                        "time" float8 NOT NULL,
                        total_agents int4 NOT NULL,
                        "map" text NOT NULL,
                        min_lng float8 NOT NULL,
                        min_lat float8 NOT NULL,
                        max_lng float8 NOT NULL,
                        max_lat float8 NOT NULL,
                        road_status_v_min float8 NULL,
                        road_status_interval int4 NULL,
                        CONSTRAINT meta_simple_pkey PRIMARY KEY (name)
                    );
                """)
                conn.commit()

                # delete any previous metadata record for this simulation
                cur.execute(
                    "DELETE FROM public.meta_simple WHERE name=%s;", (output_name,)
                )
                conn.commit()

                # insert the new metadata record; values are bound as parameters so quotes in
                # output_name / mongo_map cannot break the statement
                cur.execute(
                    "INSERT INTO public.meta_simple VALUES "
                    "(%s, %s, %s, 1, 1, %s, %s, %s, %s, %s, 0, 300);",
                    (
                        output_name,
                        int(self.eng.start_step),
                        len(self.data),
                        mongo_map,
                        float(min_lon),
                        float(min_lat),
                        float(max_lon),
                        float(max_lat),
                    ),
                )
                conn.commit()

                # drop the per-simulation tables
                cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_cars;")
                cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_people;")
                cur.execute(
                    f"DROP TABLE IF EXISTS public.{output_name}_s_traffic_light;"
                )
                cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_road;")
                conn.commit()

                # create table public.{output_name}_s_cars
                cur.execute(
                    f"""
                    CREATE TABLE public.{output_name}_s_cars (
                        step int4 NOT NULL,
                        id int4 NOT NULL,
                        parent_id int4 NOT NULL,
                        direction float8 NOT NULL,
                        lng float8 NOT NULL,
                        lat float8 NOT NULL,
                        model text NOT NULL,
                        z float8 NOT NULL,
                        pitch float8 NOT NULL,
                        v float8 NOT NULL
                    );
                    """
                )
                cur.execute(
                    f"CREATE INDEX {output_name}_s_cars_step_lng_lat_idx ON public.{output_name}_s_cars USING btree (step, lng, lat);"
                )
                conn.commit()

                # create table public.{output_name}_s_people
                cur.execute(
                    f"""
                    CREATE TABLE public.{output_name}_s_people (
                        step int4 NOT NULL,
                        id int4 NOT NULL,
                        parent_id int4 NOT NULL,
                        direction float8 NOT NULL,
                        lng float8 NOT NULL,
                        lat float8 NOT NULL,
                        z float8 NOT NULL,
                        v float8 NOT NULL,
                        model text NOT NULL
                    );
                    """
                )
                cur.execute(
                    f"CREATE INDEX {output_name}_s_people_step_lng_lat_idx ON public.{output_name}_s_people USING btree (step, lng, lat);"
                )
                conn.commit()

                # create table public.{output_name}_s_traffic_light
                cur.execute(
                    f"""
                    CREATE TABLE public.{output_name}_s_traffic_light (
                        step int4 NOT NULL,
                        id int4 NOT NULL,
                        state int4 NOT NULL,
                        lng float8 NOT NULL,
                        lat float8 NOT NULL
                    );
                    """
                )
                cur.execute(
                    f"CREATE INDEX {output_name}_s_traffic_light_step_lng_lat_idx ON public.{output_name}_s_traffic_light USING btree (step, lng, lat);"
                )
                conn.commit()

                # create table public.{output_name}_s_road
                cur.execute(
                    f"""
                    CREATE TABLE public.{output_name}_s_road (
                        step int4 NOT NULL,
                        id int4 NOT NULL,
                        "level" int4 NOT NULL,
                        v float8 NOT NULL,
                        in_vehicle_cnt int4 NOT NULL,
                        out_vehicle_cnt int4 NOT NULL,
                        cnt int4 NOT NULL
                    );
                    """
                )
                cur.execute(
                    f"CREATE INDEX {output_name}_s_road_step_idx ON public.{output_name}_s_road USING btree (step);"
                )
                conn.commit()

                # bulk-load vehicles: one row per (step, vehicle)
                cur.copy_from(
                    StringIteratorIO(
                        f"{step},{p},{l},{round(d,3)},{lng},{lat},'',0,0,{round(v,3)}\n"
                        for step, vs, lngs, lats in tqdm(vehs, ncols=90, disable=not use_tqdm)
                        for (p, l, d, v), lng, lat in zip(vs, lngs, lats)
                    ),
                    f"{output_name}_s_cars",
                    sep=",",
                )
                # bulk-load traffic lights: one row per (step, light)
                cur.copy_from(
                    StringIteratorIO(
                        f"{step},{p},{s},{lng},{lat}\n"
                        for step, ts, lngs, lats in tqdm(tls, ncols=90, disable=not use_tqdm)
                        for (p, s), lng, lat in zip(ts, lngs, lats)
                    ),
                    f"{output_name}_s_traffic_light",
                    sep=",",
                )
                conn.commit()
DBRecorder records simulation data for web visualization and writes it to a Postgres database.
The table schema is as follows:
- meta_simple: The metadata of the simulation.
- name: The name of the simulation.
- start: The start step of the simulation.
- steps: The total steps of the simulation.
- time: The time of the simulation.
- total_agents: The total agents of the simulation.
- map: The map of the simulation.
- min_lng: The minimum longitude of the simulation.
- min_lat: The minimum latitude of the simulation.
- max_lng: The maximum longitude of the simulation.
- max_lat: The maximum latitude of the simulation.
- road_status_v_min: The minimum speed of the road status.
- road_status_interval: The interval of the road status.
- {output_name}_s_cars: The vehicles of the simulation.
- step: The step of the simulation.
- id: The id of the vehicle.
- parent_id: The parent id of the vehicle.
- direction: The direction of the vehicle.
- lng: The longitude of the vehicle.
- lat: The latitude of the vehicle.
- model: The model of the vehicle.
- z: The z of the vehicle.
- pitch: The pitch of the vehicle.
- v: The speed of the vehicle.
- {output_name}_s_people: The people of the simulation.
- step: The step of the simulation.
- id: The id of the people.
- parent_id: The parent id of the people.
- direction: The direction of the people.
- lng: The longitude of the people.
- lat: The latitude of the people.
- z: The z of the people.
- v: The speed of the people.
- model: The model of the people.
- {output_name}_s_traffic_light: The traffic lights of the simulation.
- step: The step of the simulation.
- id: The id of the traffic light.
- state: The state of the traffic light.
- lng: The longitude of the traffic light.
- lat: The latitude of the traffic light.
- {output_name}_s_road: The road status of the simulation.
- step: The step of the simulation.
- id: The id of the road.
- level: The level of the road.
- v: The speed of the road.
- in_vehicle_cnt: The in vehicle count of the road.
- out_vehicle_cnt: The out vehicle count of the road.
- cnt: The count of the road.
The indexes of the tables are as follows:
- {output_name}_s_cars: (step, lng, lat)
- {output_name}_s_people: (step, lng, lat)
- {output_name}_s_traffic_light: (step, lng, lat)
- {output_name}_s_road: (step)
def __init__(self, eng: Engine):
    """
    Args:
    - eng: The engine whose state `record` snapshots.
    """
    # snapshots appended by `record`, consumed by `save`
    self.data = []
    self.eng = eng
Args:
- eng: The engine to be recorded.
def record(self):
    """
    Take one snapshot of the engine and append it to `self.data` as
    `[current_step, output_vehicles, output_tls]`.
    """
    engine = self.eng._e
    snapshot = [
        engine.get_current_step(),
        engine.get_output_vehicles(),
        engine.get_output_tls(),
    ]
    self.data.append(snapshot)
Record the data of the engine.
def save(self, db_url: str, mongo_map: str, output_name: str, use_tqdm=False):
    """
    Save the recorded data to the Postgres database.

    Args:
    - db_url: The URL of the Postgres database.
    - mongo_map: The map path of the simulation in MongoDB (if you use MongoDB), formatted like `{db}.{coll}`.
    - output_name: The name under which the simulation is saved. It is embedded in table and index
      names, so it must be a plain SQL identifier (letters, digits, `_`, `$`; not starting with a digit or `$`).
    - use_tqdm: Whether to show tqdm progress bars while copying rows.

    Raises:
    - ValueError: If `output_name` is not a plain SQL identifier.
    """
    import re
    from contextlib import closing

    # Table/index names cannot be bound as query parameters, so `output_name` is interpolated
    # into the DDL below; restricting it to a plain identifier rules out SQL injection.
    if not re.fullmatch(r"(?![\d$])[\w$]*", output_name):
        raise ValueError(
            f"output_name must be a plain SQL identifier, got {output_name!r}"
        )

    vehs = []
    tls = []
    xs = []
    ys = []
    proj = pyproj.Proj(self.eng._map.header.projection)
    for step, (vs, vx, vy), (ts, tx, ty) in self.data:
        if vs:
            # inverse=True: projected map x/y -> lng/lat
            x, y = proj(vx, vy, True)
            xs.extend(x)
            ys.extend(y)
            vehs.append([step, vs, x, y])
        if ts:
            x, y = proj(tx, ty, True)
            xs.extend(x)
            ys.extend(y)
            tls.append([step, ts, x, y])
    if xs:
        min_lon, max_lon, min_lat, max_lat = min(xs), max(xs), min(ys), max(ys)
    else:
        # nothing recorded: fall back to `_map_bbox` (presumably the map's bounding box)
        x1, y1, x2, y2 = self.eng._map_bbox
        min_lon, min_lat = proj(x1, y1, True)
        max_lon, max_lat = proj(x2, y2, True)

    # `with conn` only scopes the transaction (commit on success, rollback on error) and does
    # NOT close the connection; `closing` guarantees it is closed afterwards.
    with closing(psycopg2.connect(db_url)) as conn:
        with conn, conn.cursor() as cur:
            # create table meta_simple
            cur.execute("""
                CREATE TABLE IF NOT EXISTS public.meta_simple (
                    "name" text NOT NULL,
                    "start" int4 NOT NULL,
                    steps int4 NOT NULL,
                    "time" float8 NOT NULL,
                    total_agents int4 NOT NULL,
                    "map" text NOT NULL,
                    min_lng float8 NOT NULL,
                    min_lat float8 NOT NULL,
                    max_lng float8 NOT NULL,
                    max_lat float8 NOT NULL,
                    road_status_v_min float8 NULL,
                    road_status_interval int4 NULL,
                    CONSTRAINT meta_simple_pkey PRIMARY KEY (name)
                );
            """)
            conn.commit()

            # delete any previous metadata record for this simulation
            cur.execute(
                "DELETE FROM public.meta_simple WHERE name=%s;", (output_name,)
            )
            conn.commit()

            # insert the new metadata record; values are bound as parameters so quotes in
            # output_name / mongo_map cannot break the statement
            cur.execute(
                "INSERT INTO public.meta_simple VALUES "
                "(%s, %s, %s, 1, 1, %s, %s, %s, %s, %s, 0, 300);",
                (
                    output_name,
                    int(self.eng.start_step),
                    len(self.data),
                    mongo_map,
                    float(min_lon),
                    float(min_lat),
                    float(max_lon),
                    float(max_lat),
                ),
            )
            conn.commit()

            # drop the per-simulation tables
            cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_cars;")
            cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_people;")
            cur.execute(
                f"DROP TABLE IF EXISTS public.{output_name}_s_traffic_light;"
            )
            cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_road;")
            conn.commit()

            # create table public.{output_name}_s_cars
            cur.execute(
                f"""
                CREATE TABLE public.{output_name}_s_cars (
                    step int4 NOT NULL,
                    id int4 NOT NULL,
                    parent_id int4 NOT NULL,
                    direction float8 NOT NULL,
                    lng float8 NOT NULL,
                    lat float8 NOT NULL,
                    model text NOT NULL,
                    z float8 NOT NULL,
                    pitch float8 NOT NULL,
                    v float8 NOT NULL
                );
                """
            )
            cur.execute(
                f"CREATE INDEX {output_name}_s_cars_step_lng_lat_idx ON public.{output_name}_s_cars USING btree (step, lng, lat);"
            )
            conn.commit()

            # create table public.{output_name}_s_people
            cur.execute(
                f"""
                CREATE TABLE public.{output_name}_s_people (
                    step int4 NOT NULL,
                    id int4 NOT NULL,
                    parent_id int4 NOT NULL,
                    direction float8 NOT NULL,
                    lng float8 NOT NULL,
                    lat float8 NOT NULL,
                    z float8 NOT NULL,
                    v float8 NOT NULL,
                    model text NOT NULL
                );
                """
            )
            cur.execute(
                f"CREATE INDEX {output_name}_s_people_step_lng_lat_idx ON public.{output_name}_s_people USING btree (step, lng, lat);"
            )
            conn.commit()

            # create table public.{output_name}_s_traffic_light
            cur.execute(
                f"""
                CREATE TABLE public.{output_name}_s_traffic_light (
                    step int4 NOT NULL,
                    id int4 NOT NULL,
                    state int4 NOT NULL,
                    lng float8 NOT NULL,
                    lat float8 NOT NULL
                );
                """
            )
            cur.execute(
                f"CREATE INDEX {output_name}_s_traffic_light_step_lng_lat_idx ON public.{output_name}_s_traffic_light USING btree (step, lng, lat);"
            )
            conn.commit()

            # create table public.{output_name}_s_road
            cur.execute(
                f"""
                CREATE TABLE public.{output_name}_s_road (
                    step int4 NOT NULL,
                    id int4 NOT NULL,
                    "level" int4 NOT NULL,
                    v float8 NOT NULL,
                    in_vehicle_cnt int4 NOT NULL,
                    out_vehicle_cnt int4 NOT NULL,
                    cnt int4 NOT NULL
                );
                """
            )
            cur.execute(
                f"CREATE INDEX {output_name}_s_road_step_idx ON public.{output_name}_s_road USING btree (step);"
            )
            conn.commit()

            # bulk-load vehicles: one row per (step, vehicle)
            cur.copy_from(
                StringIteratorIO(
                    f"{step},{p},{l},{round(d,3)},{lng},{lat},'',0,0,{round(v,3)}\n"
                    for step, vs, lngs, lats in tqdm(vehs, ncols=90, disable=not use_tqdm)
                    for (p, l, d, v), lng, lat in zip(vs, lngs, lats)
                ),
                f"{output_name}_s_cars",
                sep=",",
            )
            # bulk-load traffic lights: one row per (step, light)
            cur.copy_from(
                StringIteratorIO(
                    f"{step},{p},{s},{lng},{lat}\n"
                    for step, ts, lngs, lats in tqdm(tls, ncols=90, disable=not use_tqdm)
                    for (p, s), lng, lat in zip(ts, lngs, lats)
                ),
                f"{output_name}_s_traffic_light",
                sep=",",
            )
            conn.commit()
Save the data to the Postgres Database.
Args:
- db_url: The URL of the Postgres Database.
- mongo_map: The map path of the simulation in mongodb (if you use mongodb). The format is like {db}.{coll}.
- output_name: The name of the simulation that will be saved to the database.
- use_tqdm: Whether to use tqdm or not.