moss

1from .engine import Engine, TlPolicy, Verbosity
2from .export import DBRecorder
3
4__all__ = ["Engine", "TlPolicy", "Verbosity", "DBRecorder"]
class Engine:
 54class Engine:
 55    """
 56    Moss Engine
 57
 58    NOTE: Cannot create multiple Engines on different device. For that purpose, use `moss.parallel.ParallelEngine` instead.
 59    """
 60
 61    __version__ = _moss.__version__
 62
 63    def __init__(
 64        self,
 65        name: str,
 66        map_file: str,
 67        person_file: str,
 68        start_step: int = 0,
 69        step_interval: float = 1,
 70        seed: int = 43,
 71        verbose_level=Verbosity.NO_OUTPUT,
 72        person_limit: int = -1,
 73        junction_yellow_time: float = 0,
 74        phase_pressure_coeff: float = 1.5,
 75        speed_stat_interval: int = 0,
 76        output_dir: str = "",
 77        out_xmin: float = -1e999,
 78        out_ymin: float = -1e999,
 79        out_xmax: float = 1e999,
 80        out_ymax: float = 1e999,
 81        device: int = 0,
 82        device_mem: float = 0,
 83    ):
 84        """
 85        Args:
 86        - name: The name of the task (for directory naming in output)
 87        - map_file: The path to the map file (Protobuf format)
 88        - person_file: The path to the person file (Protobuf format)
 89        - start_step: The starting step of the simulation
 90        - step_interval: The interval of each step (unit: seconds)
 91        - seed: The random seed
 92        - verbose_level: The verbosity level
 93        - person_limit: The maximum number of persons to simulate (-1 means no limit)
 94        - junction_yellow_time: The yellow time of the junction traffic light
 95        - phase_pressure_coeff: The coefficient of the phase pressure
 96        - speed_stat_interval: The interval of speed statistics. Set to `0` to disable speed statistics.
 97        - output_dir: The AVRO output directory
 98        - out_xmin: The minimum x coordinate of the output bounding box
 99        - out_ymin: The minimum y coordinate of the output bounding box
100        - out_xmax: The maximum x coordinate of the output bounding box
101        - out_ymax: The maximum y coordinate of the output bounding box
102        - device: The CUDA device index
103        - device_mem: The memory limit of the CUDA device (unit: GiB). Set to `0` to use the self-adaptive mode (80% of the free GPU memory).
104        """
105
106        self._fetched_persons = None
107        self._fetched_lanes = None
108
109        assert junction_yellow_time >= 0
110        if not hasattr(_thread_local, "device"):
111            _thread_local.device = device
112        elif _thread_local.device != device:
113            raise RuntimeError(
114                "Cannot create multiple Engines on different device! Use moss.parallel.ParallelEngine instead."
115            )
116        self.speed_stat_interval = speed_stat_interval
117        """
118        The interval of speed statistics. Set to `0` to disable speed statistics.
119        """
120        if speed_stat_interval < 0:
121            raise ValueError("Cannot set speed_stat_interval to be less than 0")
122        self._e = _moss.Engine(
123            name,
124            map_file,
125            person_file,
126            start_step,
127            step_interval,
128            seed,
129            verbose_level.value,
130            person_limit,
131            junction_yellow_time,
132            phase_pressure_coeff,
133            speed_stat_interval,
134            output_dir,
135            out_xmin,
136            out_ymin,
137            out_xmax,
138            out_ymax,
139            device,
140            device_mem,
141        )
142        self._map = Map()
143        with open(map_file, "rb") as f:
144            self._map.ParseFromString(f.read())
145        self.id2lanes = {lane.id: lane for lane in self._map.lanes}
146        """
147        Dictionary of lanes (Protobuf format) indexed by lane id
148        """
149        self.id2roads = {road.id: road for road in self._map.roads}
150        """
151        Dictionary of roads (Protobuf format) indexed by road id
152        """
153        self.id2junctions = {junction.id: junction for junction in self._map.junctions}
154        """
155        Dictionary of junctions (Protobuf format) indexed by junction id
156        """
157        self.id2aois = {aoi.id: aoi for aoi in self._map.aois}
158        """
159        Dictionary of AOIs (Protobuf format) indexed by AOI id
160        """
161
162        self.lane_index2id = self.fetch_lanes()["id"]
163        """
164        Numpy array of lane ids indexed by lane index
165        """
166        self.junc_index2id = self._e.get_junction_ids()
167        """
168        Numpy array of junction ids indexed by junction index
169        """
170        self.road_index2id = self._e.get_road_ids()
171        """
172        Numpy array of road ids indexed by road index
173        """
174        self.lane_id2index = {v.item(): k for k, v in enumerate(self.lane_index2id)}
175        """
176        Dictionary of lane index indexed by lane id
177        """
178        self.junc_id2index = {v.item(): k for k, v in enumerate(self.junc_index2id)}
179        """
180        Dictionary of junction id indexed by junction index
181        """
182        self.road_id2index = {v.item(): k for k, v in enumerate(self.road_index2id)}
183        """
184        Dictionary of road index indexed by road id
185        """
186
187        self._persons = Persons()
188        with open(person_file, "rb") as f:
189            self._persons.ParseFromString(f.read())
190        self._map_bbox = (out_xmin, out_ymin, out_xmax, out_ymax)
191        self.start_step = start_step
192        """
193        The starting step of the simulation
194        """
195
196        self.device = device
197        """
198        The CUDA device index
199        """
200
201    @property
202    def person_count(self) -> int:
203        """
204        The number of vehicles in the agent file
205        """
206        return len(self._persons.persons)
207
208    @property
209    def lane_count(self) -> int:
210        """
211        The number of lanes
212        """
213        return len(self.id2lanes)
214
215    @property
216    def road_count(self) -> int:
217        """
218        The number of roads
219        """
220        return len(self.id2roads)
221
222    @property
223    def junction_count(self) -> int:
224        """
225        The number of junctions
226        """
227        return len(self.id2junctions)
228
229    def get_map(self, dict_return: bool = True) -> Union[Map, Dict]:
230        """
231        Get the Map object.
232        Map is a protobuf message defined in `pycityproto.city.map.v2.map_pb2` in the `pycityproto` package.
233        The documentation url is https://docs.fiblab.net/cityproto#city.map.v2.Map
234
235        Args:
236        - dict_return: Whether to return the object as a dictionary
237
238        Returns:
239        - The Map object or the dictionary
240        """
241        if dict_return:
242            return pb2dict(self._map)
243        else:
244            return self._map
245
246    def get_persons(self, dict_return: bool = True) -> Union[Persons, Dict]:
247        """
248        Get the Persons object.
249        Persons is a protobuf message defined in `pycityproto.city.person.v2.person_pb2` in the `pycityproto` package.
250        The documentation url is https://docs.fiblab.net/cityproto#city.person.v2.Persons
251
252        Args:
253        - dict_return: Whether to return the object as a dictionary
254
255        Returns:
256        - The Persons object or the dictionary
257        """
258        if dict_return:
259            return pb2dict(self._persons)
260        else:
261            return self._persons
262
263    def get_current_time(self) -> float:
264        """
265        Get the current time
266        """
267        return self._e.get_current_time()
268
269    def fetch_persons(self) -> Dict[str, NDArray]:
270        """
271        Fetch the persons' information.
272
273        The result values is a dictionary with the following keys:
274        - id: The id of the person
275        - status: The status of the person
276        - lane_id: The id of the lane the person is on
277        - lane_parent_id: The id of the road the lane belongs to
278        - s: The s value of the person
279        - aoi_id: The id of the AOI the person is in
280        - v: The velocity of the person
281        - shadow_lane_id: The id of the shadow lane the person is on
282        - shadow_s: The s value of the shadow lane
283        - lc_yaw: The yaw of the lane change
284        - lc_completed_ratio: The completed ratio of the lane change
285        - is_forward: Whether the person is moving forward
286        - x: The x coordinate of the person
287        - y: The y coordinate of the person
288        - dir: The direction of the person
289        - schedule_index: The index of the schedule
290        - trip_index: The index of the trip
291        - departure_time: The departure time of the person
292        - traveling_time: The traveling time of the person
293        - total_distance: The total distance of the person
294
295        We strongly recommend using `pd.DataFrame(e.fetch_persons())` to convert the result to a DataFrame for better visualization and analysis.
296        """
297        if self._fetched_persons is None:
298            (
299                ids,
300                statuses,
301                lane_ids,
302                lane_parent_ids,
303                ss,
304                aoi_ids,
305                vs,
306                shadow_lane_ids,
307                shadow_ss,
308                lc_yaws,
309                lc_completed_ratios,
310                is_forwards,
311                xs,
312                ys,
313                dirs,
314                schedule_indexs,
315                trip_indexs,
316                departure_times,
317                traveling_times,
318                total_distances,
319            ) = self._e.fetch_persons()
320            self._fetched_persons = {
321                "id": ids,
322                "status": statuses,
323                "lane_id": lane_ids,
324                "lane_parent_id": lane_parent_ids,
325                "s": ss,
326                "aoi_id": aoi_ids,
327                "v": vs,
328                "shadow_lane_id": shadow_lane_ids,
329                "shadow_s": shadow_ss,
330                "lc_yaw": lc_yaws,
331                "lc_completed_ratio": lc_completed_ratios,
332                "is_forward": is_forwards,
333                "x": xs,
334                "y": ys,
335                "dir": dirs,
336                "schedule_index": schedule_indexs,
337                "trip_index": trip_indexs,
338                "departure_time": departure_times,
339                "traveling_time": traveling_times,
340                "total_distance": total_distances,
341            }
342        return self._fetched_persons
343
344    def fetch_lanes(self) -> Dict[str, NDArray]:
345        """
346        Fetch the lanes' information.
347
348        The result values is a dictionary with the following keys:
349        - id: The id of the lane
350        - status: The status of the lane
351        - v_avg: The average speed of the lane
352
353        We strongly recommend using `pd.DataFrame(e.fetch_lanes())` to convert the result to a DataFrame for better visualization and analysis.
354        """
355        if self._fetched_lanes is None:
356            ids, statuses, v_avgs = self._e.fetch_lanes()
357            self._fetched_lanes = {
358                "id": ids,
359                "status": statuses,
360                "v_avg": v_avgs,
361            }
362        return self._fetched_lanes
363
364    def get_running_person_count(self) -> int:
365        """
366        Get the total number of running persons (including driving and walking)
367        """
368        persons = self.fetch_persons()
369        status: NDArray[np.uint8] = persons["status"]
370        return (
371            status == PersonStatus.DRIVING.value | status == PersonStatus.WALKING.value
372        ).sum()
373
374    def get_lane_statuses(self) -> NDArray[np.int8]:
375        """
376        Get the traffic light status of each lane, `0`-green / `1`-yellow / `2`-red / `3`-restriction.
377        The lane id of the entry `i` can be obtained by `e.lane_index2id[i]`.
378        """
379        lanes = self.fetch_lanes()
380        return lanes["status"]
381
382    def get_lane_waiting_vehicle_counts(
383        self, speed_threshold: float = 0.1
384    ) -> Dict[int, int]:
385        """
386        Get the number of vehicles of each lane with speed lower than `speed_threshold`
387
388        Returns:
389        - Dict: lane id -> number of vehicles
390        """
391
392        persons = self.fetch_persons()
393        lane_id = persons["lane_id"]
394        status = persons["status"]
395        v = persons["v"]
396        filter = (status == PersonStatus.DRIVING.value) & (v < speed_threshold)
397        filtered_lane_id = lane_id[filter]
398        # count for the lane id
399        unique, counts = np.unique(filtered_lane_id, return_counts=True)
400
401        return dict(zip(unique, counts))
402
403    def get_lane_waiting_at_end_vehicle_counts(
404        self, speed_threshold: float = 0.1, distance_to_end: float = 100
405    ) -> Dict[int, int]:
406        """
407        Get the number of vehicles of each lane with speed lower than `speed_threshold` and distance to end lower than `distance_to_end`
408
409        Returns:
410        - Dict: lane id -> number of vehicles
411        """
412
413        persons = self.fetch_persons()
414        lane_id = persons["lane_id"]
415        status = persons["status"]
416        v = persons["v"]
417        s = persons["s"]
418        filter = (status == PersonStatus.DRIVING.value) & (v < speed_threshold)
419        filtered_lane_id = lane_id[filter]
420        filtered_s = s[filter]
421        # find the distance to the end of the lane
422        lane_ids_for_count = []
423        for i, s in zip(filtered_lane_id, filtered_s):
424            if self.id2lanes[i].length - s < distance_to_end:
425                lane_ids_for_count.append(i)
426        # count for the lane id
427        unique, counts = np.unique(lane_ids_for_count, return_counts=True)
428        return dict(zip(unique, counts))
429
430    def get_lane_ids(self) -> NDArray[np.int32]:
431        """
432        Get the ids of the lanes as a numpy array
433        """
434        return self.lane_index2id
435
436    def get_lane_average_vehicle_speed(self, lane_index: int) -> float:
437        """
438        Get the average speed of the vehicles on the lane `lane_index`
439        """
440        if self.speed_stat_interval == 0:
441            raise RuntimeError(
442                "Please set speed_stat_interval to enable speed statistics"
443            )
444        lanes = self.fetch_lanes()
445        v_args: NDArray[np.float32] = lanes["v_avg"]
446        return v_args[lane_index].item()
447
448    def get_junction_ids(self) -> NDArray[np.int32]:
449        """
450        Get the ids of the junctions
451        """
452        return self.junc_index2id
453
454    def get_junction_phase_lanes(self) -> List[List[Tuple[List[int], List[int]]]]:
455        """
456        Get the `index` of the `in` and `out` lanes of each phase of each junction
457
458        Examples: TODO
459        """
460        return self._e.get_junction_phase_lanes()
461
462    def get_junction_phase_ids(self) -> NDArray[np.int32]:
463        """
464        Get the phase id of each junction, `-1` if it has no traffic lights.
465        The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.
466        """
467        return self._e.get_junction_phase_ids()
468
469    def get_junction_phase_counts(self) -> NDArray[np.int32]:
470        """
471        Get the number of available phases of each junction.
472        The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.
473        """
474        return self._e.get_junction_phase_counts()
475
476    def get_junction_dynamic_roads(self) -> List[List[int]]:
477        """
478        Get the ids of the dynamic roads connected to each junction.
479        The junction id of the entry `i` can be obtained by `e.junc_index2id
480        """
481        return self._e.get_junction_dynamic_roads()
482
483    def get_road_lane_plans(self, road_index: int) -> List[List[slice]]:
484        """
485        Get the dynamic lane plan of the road `road_index`,
486        represented as list of lane groups:
487        ```
488        [
489            [slice(lane_start, lane_end), ...]
490        ]
491        ```
492        """
493        return [
494            [slice(a, b) for a, b in i] for i in self._e.get_road_lane_plans(road_index)
495        ]
496
497    def get_road_average_vehicle_speed(self, road_index: int) -> float:
498        """
499        Get the average speed of the vehicles on the road `road_index`
500        """
501        if self.speed_stat_interval == 0:
502            raise RuntimeError(
503                "Please set speed_stat_interval to enable speed statistics"
504            )
505        lanes = self.fetch_lanes()
506        road_id = self.road_index2id[road_index]
507        lane_ids = self.id2roads[road_id].lane_ids
508        lane_indexes = [self.lane_id2index[i] for i in lane_ids]
509        v_args: NDArray[np.float32] = lanes["v_avg"]
510        return v_args[lane_indexes].mean().item()
511
512    def get_finished_person_count(self) -> int:
513        """
514        Get the number of the finished persons
515        """
516        persons = self.fetch_persons()
517        status: NDArray[np.uint8] = persons["status"]
518        return (status == PersonStatus.FINISHED.value).sum()
519
520    def get_finished_person_average_traveling_time(self) -> float:
521        """
522        Get the average traveling time of the finished persons
523        """
524        persons = self.fetch_persons()
525        status: NDArray[np.uint8] = persons["status"]
526        traveling_time = persons["traveling_time"]
527        return traveling_time[status == PersonStatus.FINISHED.value].mean()
528
529    def get_running_person_average_traveling_time(self) -> float:
530        """
531        Get the average traveling time of the running persons
532        """
533        persons = self.fetch_persons()
534        status: NDArray[np.uint8] = persons["status"]
535        traveling_time = persons["traveling_time"]
536        return traveling_time[status == PersonStatus.DRIVING.value].mean()
537
538    def get_departed_person_average_traveling_time(self) -> float:
539        """
540        Get the average traveling time of the departed persons (running+finished)
541        """
542        persons = self.fetch_persons()
543        status: NDArray[np.uint8] = persons["status"]
544        traveling_time = persons["traveling_time"]
545        return traveling_time[status != PersonStatus.SLEEP.value].mean()
546
547    def get_road_lane_plan_index(self, road_index: int) -> int:
548        """
549        Get the lane plan of road `road_index`
550        """
551        return self._e.get_road_lane_plan_index(road_index)
552
553    def get_road_vehicle_counts(self) -> Dict[int, int]:
554        """
555        Get the number of vehicles of each road
556
557        Returns:
558        - Dict: road id -> number of vehicles
559        """
560        persons = self.fetch_persons()
561        road_id = persons["lane_parent_id"]
562        status = persons["status"]
563        filter = status == PersonStatus.DRIVING.value
564        filtered_road_id = road_id[filter]
565        # count for the road id
566        unique, counts = np.unique(filtered_road_id, return_counts=True)
567        return dict(zip(unique, counts))
568
569    def get_road_waiting_vehicle_counts(
570        self, speed_threshold: float = 0.1
571    ) -> Dict[int, int]:
572        """
573        Get the number of vehicles with speed lower than `speed_threshold` of each road
574
575        Returns:
576        - Dict: road id -> number of vehicles
577        """
578
579        persons = self.fetch_persons()
580        road_id = persons["lane_parent_id"]
581        status = persons["status"]
582        v = persons["v"]
583        filter = (
584            (status == PersonStatus.DRIVING.value)
585            & (v < speed_threshold)
586            & (road_id < 3_0000_0000)  # the road id ranges [2_0000_0000, 3_0000_0000)
587        )
588        filtered_road_id = road_id[filter]
589        # count for the road id
590        unique, counts = np.unique(filtered_road_id, return_counts=True)
591        return dict(zip(unique, counts))
592
593    def set_tl_policy(self, junction_index: int, policy: TlPolicy):
594        """
595        Set the traffic light policy of junction `junction_index` to `policy`
596        """
597        self._e.set_tl_policy(junction_index, policy.value)
598
599    def set_tl_policy_batch(self, junction_indices: List[int], policy: TlPolicy):
600        """
601        Set the traffic light policy of all junctions in `junction_indices` to `policy`
602        """
603        self._e.set_tl_policy_batch(junction_indices, policy.value)
604
605    def set_tl_duration(self, junction_index: int, duration: int):
606        """
607        Set the traffic light switch duration of junction `junction_index` to `duration`
608
609        NOTE: This is only effective for `TlPolicy.FIXED_TIME` and `TlPolicy.MAX_PRESSURE`.
610
611        NOTE: Set duration to `0` to use the predefined duration in the `map_file`
612        """
613        self._e.set_tl_duration(junction_index, duration)
614
615    def set_tl_duration_batch(self, junction_indices: List[int], duration: int):
616        """
617        Set the traffic light switch duration of all junctions in `junction_indices` to `duration`
618
619        NOTE: This is only effective for `TlPolicy.FIXED_TIME` and `TlPolicy.MAX_PRESSURE`
620
621        NOTE: Set duration to `0` to use the predefined duration in the `map_file`
622        """
623        self._e.set_tl_duration_batch(junction_indices, duration)
624
625    def set_tl_phase(self, junction_index: Union[str, int], phase_index: int):
626        """
627        Set the phase of `junction_index` to `phase_index`
628        """
629        self._e.set_tl_phase(junction_index, phase_index)
630
631    def set_tl_phase_batch(self, junction_indices: List[int], phase_indices: List[int]):
632        """
633        Set the phase of `junction_index` to `phase_index` in batch
634        """
635        assert len(junction_indices) == len(phase_indices)
636        self._e.set_tl_phase_batch(junction_indices, phase_indices)
637
638    def set_road_lane_plan(self, road_index: int, plan_index: int):
639        """
640        Set the lane plan of road `road_index`
641        """
642        self._e.set_road_lane_plan(road_index, plan_index)
643
644    def set_road_lane_plan_batch(
645        self, road_indices: List[int], plan_indices: List[int]
646    ):
647        """
648        Set the lane plan of road `road_index`
649        """
650        assert len(road_indices) == len(plan_indices)
651        self._e.set_road_lane_plan_batch(road_indices, plan_indices)
652
653    def set_lane_restriction(self, lane_index: int, flag: bool):
654        """
655        Set the restriction state of lane `lane_index`
656        """
657        self._e.set_lane_restriction(lane_index, flag)
658
659    def set_lane_restriction_batch(self, lane_indices: List[int], flags: List[bool]):
660        """
661        Set the restriction state of lane `lane_index`
662        """
663        assert len(lane_indices) == len(flags)
664        self._e.set_lane_restriction_batch(lane_indices, flags)
665
666    def set_lane_max_speed(self, lane_index: int, max_speed: float):
667        """
668        Set the max_speed of lane `lane_index`
669        """
670        self._e.set_lane_max_speed(lane_index, max_speed)
671
672    def set_lane_max_speed_batch(
673        self, lane_indices: List[int], max_speeds: Union[float, List[float]]
674    ):
675        """
676        Set the max_speed of lane `lane_index`
677        """
678        if hasattr(max_speeds, "__len__"):
679            assert len(lane_indices) == len(max_speeds)
680        else:
681            max_speeds = [max_speeds] * len(lane_indices)
682        self._e.set_lane_max_speed_batch(lane_indices, max_speeds)
683
684    def next_step(self, n=1):
685        """
686        Move forward `n` steps
687        """
688        self._fetched_persons = None
689        self._fetched_lanes = None
690        self._e.next_step(n)

Moss Engine

NOTE: Cannot create multiple Engines on different devices. To run Engines on several devices, use moss.parallel.ParallelEngine instead.

Engine( name: str, map_file: str, person_file: str, start_step: int = 0, step_interval: float = 1, seed: int = 43, verbose_level=<Verbosity.NO_OUTPUT: 0>, person_limit: int = -1, junction_yellow_time: float = 0, phase_pressure_coeff: float = 1.5, speed_stat_interval: int = 0, output_dir: str = '', out_xmin: float = -inf, out_ymin: float = -inf, out_xmax: float = inf, out_ymax: float = inf, device: int = 0, device_mem: float = 0)
 63    def __init__(
 64        self,
 65        name: str,
 66        map_file: str,
 67        person_file: str,
 68        start_step: int = 0,
 69        step_interval: float = 1,
 70        seed: int = 43,
 71        verbose_level=Verbosity.NO_OUTPUT,
 72        person_limit: int = -1,
 73        junction_yellow_time: float = 0,
 74        phase_pressure_coeff: float = 1.5,
 75        speed_stat_interval: int = 0,
 76        output_dir: str = "",
 77        out_xmin: float = -1e999,
 78        out_ymin: float = -1e999,
 79        out_xmax: float = 1e999,
 80        out_ymax: float = 1e999,
 81        device: int = 0,
 82        device_mem: float = 0,
 83    ):
 84        """
 85        Args:
 86        - name: The name of the task (for directory naming in output)
 87        - map_file: The path to the map file (Protobuf format)
 88        - person_file: The path to the person file (Protobuf format)
 89        - start_step: The starting step of the simulation
 90        - step_interval: The interval of each step (unit: seconds)
 91        - seed: The random seed
 92        - verbose_level: The verbosity level
 93        - person_limit: The maximum number of persons to simulate (-1 means no limit)
 94        - junction_yellow_time: The yellow time of the junction traffic light
 95        - phase_pressure_coeff: The coefficient of the phase pressure
 96        - speed_stat_interval: The interval of speed statistics. Set to `0` to disable speed statistics.
 97        - output_dir: The AVRO output directory
 98        - out_xmin: The minimum x coordinate of the output bounding box
 99        - out_ymin: The minimum y coordinate of the output bounding box
100        - out_xmax: The maximum x coordinate of the output bounding box
101        - out_ymax: The maximum y coordinate of the output bounding box
102        - device: The CUDA device index
103        - device_mem: The memory limit of the CUDA device (unit: GiB). Set to `0` to use the self-adaptive mode (80% of the free GPU memory).
104        """
105
106        self._fetched_persons = None
107        self._fetched_lanes = None
108
109        assert junction_yellow_time >= 0
110        if not hasattr(_thread_local, "device"):
111            _thread_local.device = device
112        elif _thread_local.device != device:
113            raise RuntimeError(
114                "Cannot create multiple Engines on different device! Use moss.parallel.ParallelEngine instead."
115            )
116        self.speed_stat_interval = speed_stat_interval
117        """
118        The interval of speed statistics. Set to `0` to disable speed statistics.
119        """
120        if speed_stat_interval < 0:
121            raise ValueError("Cannot set speed_stat_interval to be less than 0")
122        self._e = _moss.Engine(
123            name,
124            map_file,
125            person_file,
126            start_step,
127            step_interval,
128            seed,
129            verbose_level.value,
130            person_limit,
131            junction_yellow_time,
132            phase_pressure_coeff,
133            speed_stat_interval,
134            output_dir,
135            out_xmin,
136            out_ymin,
137            out_xmax,
138            out_ymax,
139            device,
140            device_mem,
141        )
142        self._map = Map()
143        with open(map_file, "rb") as f:
144            self._map.ParseFromString(f.read())
145        self.id2lanes = {lane.id: lane for lane in self._map.lanes}
146        """
147        Dictionary of lanes (Protobuf format) indexed by lane id
148        """
149        self.id2roads = {road.id: road for road in self._map.roads}
150        """
151        Dictionary of roads (Protobuf format) indexed by road id
152        """
153        self.id2junctions = {junction.id: junction for junction in self._map.junctions}
154        """
155        Dictionary of junctions (Protobuf format) indexed by junction id
156        """
157        self.id2aois = {aoi.id: aoi for aoi in self._map.aois}
158        """
159        Dictionary of AOIs (Protobuf format) indexed by AOI id
160        """
161
162        self.lane_index2id = self.fetch_lanes()["id"]
163        """
164        Numpy array of lane ids indexed by lane index
165        """
166        self.junc_index2id = self._e.get_junction_ids()
167        """
168        Numpy array of junction ids indexed by junction index
169        """
170        self.road_index2id = self._e.get_road_ids()
171        """
172        Numpy array of road ids indexed by road index
173        """
174        self.lane_id2index = {v.item(): k for k, v in enumerate(self.lane_index2id)}
175        """
176        Dictionary of lane index indexed by lane id
177        """
178        self.junc_id2index = {v.item(): k for k, v in enumerate(self.junc_index2id)}
179        """
180        Dictionary of junction id indexed by junction index
181        """
182        self.road_id2index = {v.item(): k for k, v in enumerate(self.road_index2id)}
183        """
184        Dictionary of road index indexed by road id
185        """
186
187        self._persons = Persons()
188        with open(person_file, "rb") as f:
189            self._persons.ParseFromString(f.read())
190        self._map_bbox = (out_xmin, out_ymin, out_xmax, out_ymax)
191        self.start_step = start_step
192        """
193        The starting step of the simulation
194        """
195
196        self.device = device
197        """
198        The CUDA device index
199        """

Args:

  • name: The name of the task (for directory naming in output)
  • map_file: The path to the map file (Protobuf format)
  • person_file: The path to the person file (Protobuf format)
  • start_step: The starting step of the simulation
  • step_interval: The interval of each step (unit: seconds)
  • seed: The random seed
  • verbose_level: The verbosity level
  • person_limit: The maximum number of persons to simulate (-1 means no limit)
  • junction_yellow_time: The yellow time of the junction traffic light
  • phase_pressure_coeff: The coefficient of the phase pressure
  • speed_stat_interval: The interval of speed statistics. Set to 0 to disable speed statistics.
  • output_dir: The AVRO output directory
  • out_xmin: The minimum x coordinate of the output bounding box
  • out_ymin: The minimum y coordinate of the output bounding box
  • out_xmax: The maximum x coordinate of the output bounding box
  • out_ymax: The maximum y coordinate of the output bounding box
  • device: The CUDA device index
  • device_mem: The memory limit of the CUDA device (unit: GiB). Set to 0 to use the self-adaptive mode (80% of the free GPU memory).
speed_stat_interval

The interval of speed statistics. Set to 0 to disable speed statistics.

id2lanes

Dictionary of lanes (Protobuf format) indexed by lane id

id2roads

Dictionary of roads (Protobuf format) indexed by road id

id2junctions

Dictionary of junctions (Protobuf format) indexed by junction id

id2aois

Dictionary of AOIs (Protobuf format) indexed by AOI id

lane_index2id

Numpy array of lane ids indexed by lane index

junc_index2id

Numpy array of junction ids indexed by junction index

road_index2id

Numpy array of road ids indexed by road index

lane_id2index

Dictionary of lane index indexed by lane id

junc_id2index

Dictionary of junction index indexed by junction id

road_id2index

Dictionary of road index indexed by road id

start_step

The starting step of the simulation

device

The CUDA device index

person_count: int
201    @property
202    def person_count(self) -> int:
203        """
204        The number of vehicles in the agent file
205        """
206        return len(self._persons.persons)

The number of persons in the person file

lane_count: int
208    @property
209    def lane_count(self) -> int:
210        """
211        The number of lanes
212        """
213        return len(self.id2lanes)

The number of lanes

road_count: int
215    @property
216    def road_count(self) -> int:
217        """
218        The number of roads
219        """
220        return len(self.id2roads)

The number of roads

junction_count: int
222    @property
223    def junction_count(self) -> int:
224        """
225        The number of junctions
226        """
227        return len(self.id2junctions)

The number of junctions

def get_map(self, dict_return: bool = True) -> Union[city.map.v2.map_pb2.Map, Dict]:
229    def get_map(self, dict_return: bool = True) -> Union[Map, Dict]:
230        """
231        Get the Map object.
232        Map is a protobuf message defined in `pycityproto.city.map.v2.map_pb2` in the `pycityproto` package.
233        The documentation url is https://docs.fiblab.net/cityproto#city.map.v2.Map
234
235        Args:
236        - dict_return: Whether to return the object as a dictionary
237
238        Returns:
239        - The Map object or the dictionary
240        """
241        if dict_return:
242            return pb2dict(self._map)
243        else:
244            return self._map

Get the Map object. Map is a protobuf message defined in pycityproto.city.map.v2.map_pb2 in the pycityproto package. The documentation url is https://docs.fiblab.net/cityproto#city.map.v2.Map

Args:

  • dict_return: Whether to return the object as a dictionary

Returns:

  • The Map object or the dictionary
def get_persons( self, dict_return: bool = True) -> Union[city.person.v2.person_pb2.Persons, Dict]:
246    def get_persons(self, dict_return: bool = True) -> Union[Persons, Dict]:
247        """
248        Get the Persons object.
249        Persons is a protobuf message defined in `pycityproto.city.person.v2.person_pb2` in the `pycityproto` package.
250        The documentation url is https://docs.fiblab.net/cityproto#city.person.v2.Persons
251
252        Args:
253        - dict_return: Whether to return the object as a dictionary
254
255        Returns:
256        - The Persons object or the dictionary
257        """
258        if dict_return:
259            return pb2dict(self._persons)
260        else:
261            return self._persons

Get the Persons object. Persons is a protobuf message defined in pycityproto.city.person.v2.person_pb2 in the pycityproto package. The documentation url is https://docs.fiblab.net/cityproto#city.person.v2.Persons

Args:

  • dict_return: Whether to return the object as a dictionary

Returns:

  • The Persons object or the dictionary
def get_current_time(self) -> float:
263    def get_current_time(self) -> float:
264        """
265        Get the current time
266        """
267        return self._e.get_current_time()

Get the current time

def fetch_persons(self) -> Dict[str, numpy.ndarray[Any, numpy.dtype[+_ScalarType_co]]]:
269    def fetch_persons(self) -> Dict[str, NDArray]:
270        """
271        Fetch the persons' information.
272
273        The result values is a dictionary with the following keys:
274        - id: The id of the person
275        - status: The status of the person
276        - lane_id: The id of the lane the person is on
277        - lane_parent_id: The id of the road the lane belongs to
278        - s: The s value of the person
279        - aoi_id: The id of the AOI the person is in
280        - v: The velocity of the person
281        - shadow_lane_id: The id of the shadow lane the person is on
282        - shadow_s: The s value of the shadow lane
283        - lc_yaw: The yaw of the lane change
284        - lc_completed_ratio: The completed ratio of the lane change
285        - is_forward: Whether the person is moving forward
286        - x: The x coordinate of the person
287        - y: The y coordinate of the person
288        - dir: The direction of the person
289        - schedule_index: The index of the schedule
290        - trip_index: The index of the trip
291        - departure_time: The departure time of the person
292        - traveling_time: The traveling time of the person
293        - total_distance: The total distance of the person
294
295        We strongly recommend using `pd.DataFrame(e.fetch_persons())` to convert the result to a DataFrame for better visualization and analysis.
296        """
297        if self._fetched_persons is None:
298            (
299                ids,
300                statuses,
301                lane_ids,
302                lane_parent_ids,
303                ss,
304                aoi_ids,
305                vs,
306                shadow_lane_ids,
307                shadow_ss,
308                lc_yaws,
309                lc_completed_ratios,
310                is_forwards,
311                xs,
312                ys,
313                dirs,
314                schedule_indexs,
315                trip_indexs,
316                departure_times,
317                traveling_times,
318                total_distances,
319            ) = self._e.fetch_persons()
320            self._fetched_persons = {
321                "id": ids,
322                "status": statuses,
323                "lane_id": lane_ids,
324                "lane_parent_id": lane_parent_ids,
325                "s": ss,
326                "aoi_id": aoi_ids,
327                "v": vs,
328                "shadow_lane_id": shadow_lane_ids,
329                "shadow_s": shadow_ss,
330                "lc_yaw": lc_yaws,
331                "lc_completed_ratio": lc_completed_ratios,
332                "is_forward": is_forwards,
333                "x": xs,
334                "y": ys,
335                "dir": dirs,
336                "schedule_index": schedule_indexs,
337                "trip_index": trip_indexs,
338                "departure_time": departure_times,
339                "traveling_time": traveling_times,
340                "total_distance": total_distances,
341            }
342        return self._fetched_persons

Fetch the persons' information.

The result is a dictionary of arrays with the following keys:

  • id: The id of the person
  • status: The status of the person
  • lane_id: The id of the lane the person is on
  • lane_parent_id: The id of the road the lane belongs to
  • s: The s value of the person
  • aoi_id: The id of the AOI the person is in
  • v: The velocity of the person
  • shadow_lane_id: The id of the shadow lane the person is on
  • shadow_s: The s value of the shadow lane
  • lc_yaw: The yaw of the lane change
  • lc_completed_ratio: The completed ratio of the lane change
  • is_forward: Whether the person is moving forward
  • x: The x coordinate of the person
  • y: The y coordinate of the person
  • dir: The direction of the person
  • schedule_index: The index of the schedule
  • trip_index: The index of the trip
  • departure_time: The departure time of the person
  • traveling_time: The traveling time of the person
  • total_distance: The total distance of the person

We strongly recommend using pd.DataFrame(e.fetch_persons()) to convert the result to a DataFrame for better visualization and analysis.

def fetch_lanes(self) -> Dict[str, numpy.ndarray[Any, numpy.dtype[+_ScalarType_co]]]:
344    def fetch_lanes(self) -> Dict[str, NDArray]:
345        """
346        Fetch the lanes' information.
347
348        The result values is a dictionary with the following keys:
349        - id: The id of the lane
350        - status: The status of the lane
351        - v_avg: The average speed of the lane
352
353        We strongly recommend using `pd.DataFrame(e.fetch_lanes())` to convert the result to a DataFrame for better visualization and analysis.
354        """
355        if self._fetched_lanes is None:
356            ids, statuses, v_avgs = self._e.fetch_lanes()
357            self._fetched_lanes = {
358                "id": ids,
359                "status": statuses,
360                "v_avg": v_avgs,
361            }
362        return self._fetched_lanes

Fetch the lanes' information.

The result is a dictionary of arrays with the following keys:

  • id: The id of the lane
  • status: The status of the lane
  • v_avg: The average speed of the lane

We strongly recommend using pd.DataFrame(e.fetch_lanes()) to convert the result to a DataFrame for better visualization and analysis.

def get_running_person_count(self) -> int:
364    def get_running_person_count(self) -> int:
365        """
366        Get the total number of running persons (including driving and walking)
367        """
368        persons = self.fetch_persons()
369        status: NDArray[np.uint8] = persons["status"]
370        return (
371            status == PersonStatus.DRIVING.value | status == PersonStatus.WALKING.value
372        ).sum()

Get the total number of running persons (including driving and walking)

def get_lane_statuses(self) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.int8]]:
374    def get_lane_statuses(self) -> NDArray[np.int8]:
375        """
376        Get the traffic light status of each lane, `0`-green / `1`-yellow / `2`-red / `3`-restriction.
377        The lane id of the entry `i` can be obtained by `e.lane_index2id[i]`.
378        """
379        lanes = self.fetch_lanes()
380        return lanes["status"]

Get the traffic light status of each lane, 0-green / 1-yellow / 2-red / 3-restriction. The lane id of the entry i can be obtained by e.lane_index2id[i].

def get_lane_waiting_vehicle_counts(self, speed_threshold: float = 0.1) -> Dict[int, int]:
382    def get_lane_waiting_vehicle_counts(
383        self, speed_threshold: float = 0.1
384    ) -> Dict[int, int]:
385        """
386        Get the number of vehicles of each lane with speed lower than `speed_threshold`
387
388        Returns:
389        - Dict: lane id -> number of vehicles
390        """
391
392        persons = self.fetch_persons()
393        lane_id = persons["lane_id"]
394        status = persons["status"]
395        v = persons["v"]
396        filter = (status == PersonStatus.DRIVING.value) & (v < speed_threshold)
397        filtered_lane_id = lane_id[filter]
398        # count for the lane id
399        unique, counts = np.unique(filtered_lane_id, return_counts=True)
400
401        return dict(zip(unique, counts))

Get, for each lane, the number of driving vehicles whose speed is lower than speed_threshold

Returns:

  • Dict: lane id -> number of vehicles
def get_lane_waiting_at_end_vehicle_counts( self, speed_threshold: float = 0.1, distance_to_end: float = 100) -> Dict[int, int]:
403    def get_lane_waiting_at_end_vehicle_counts(
404        self, speed_threshold: float = 0.1, distance_to_end: float = 100
405    ) -> Dict[int, int]:
406        """
407        Get the number of vehicles of each lane with speed lower than `speed_threshold` and distance to end lower than `distance_to_end`
408
409        Returns:
410        - Dict: lane id -> number of vehicles
411        """
412
413        persons = self.fetch_persons()
414        lane_id = persons["lane_id"]
415        status = persons["status"]
416        v = persons["v"]
417        s = persons["s"]
418        filter = (status == PersonStatus.DRIVING.value) & (v < speed_threshold)
419        filtered_lane_id = lane_id[filter]
420        filtered_s = s[filter]
421        # find the distance to the end of the lane
422        lane_ids_for_count = []
423        for i, s in zip(filtered_lane_id, filtered_s):
424            if self.id2lanes[i].length - s < distance_to_end:
425                lane_ids_for_count.append(i)
426        # count for the lane id
427        unique, counts = np.unique(lane_ids_for_count, return_counts=True)
428        return dict(zip(unique, counts))

Get, for each lane, the number of driving vehicles whose speed is lower than speed_threshold and whose distance to the end of the lane is lower than distance_to_end

Returns:

  • Dict: lane id -> number of vehicles
def get_lane_ids(self) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]]:
430    def get_lane_ids(self) -> NDArray[np.int32]:
431        """
432        Get the ids of the lanes as a numpy array
433        """
434        return self.lane_index2id

Get the ids of the lanes as a numpy array

def get_lane_average_vehicle_speed(self, lane_index: int) -> float:
436    def get_lane_average_vehicle_speed(self, lane_index: int) -> float:
437        """
438        Get the average speed of the vehicles on the lane `lane_index`
439        """
440        if self.speed_stat_interval == 0:
441            raise RuntimeError(
442                "Please set speed_stat_interval to enable speed statistics"
443            )
444        lanes = self.fetch_lanes()
445        v_args: NDArray[np.float32] = lanes["v_avg"]
446        return v_args[lane_index].item()

Get the average speed of the vehicles on the lane lane_index. Raises RuntimeError if speed statistics are disabled (speed_stat_interval is 0).

def get_junction_ids(self) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]]:
448    def get_junction_ids(self) -> NDArray[np.int32]:
449        """
450        Get the ids of the junctions
451        """
452        return self.junc_index2id

Get the ids of the junctions

def get_junction_phase_lanes(self) -> List[List[Tuple[List[int], List[int]]]]:
454    def get_junction_phase_lanes(self) -> List[List[Tuple[List[int], List[int]]]]:
455        """
456        Get the `index` of the `in` and `out` lanes of each phase of each junction
457
458        Examples: TODO
459        """
460        return self._e.get_junction_phase_lanes()

Get the indices of the in and out lanes of each phase of each junction

Examples: TODO

def get_junction_phase_ids(self) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]]:
462    def get_junction_phase_ids(self) -> NDArray[np.int32]:
463        """
464        Get the phase id of each junction, `-1` if it has no traffic lights.
465        The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.
466        """
467        return self._e.get_junction_phase_ids()

Get the phase id of each junction, -1 if it has no traffic lights. The junction id of the entry i can be obtained by e.junc_index2id[i].

def get_junction_phase_counts(self) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.int32]]:
469    def get_junction_phase_counts(self) -> NDArray[np.int32]:
470        """
471        Get the number of available phases of each junction.
472        The junction id of the entry `i` can be obtained by `e.junc_index2id[i]`.
473        """
474        return self._e.get_junction_phase_counts()

Get the number of available phases of each junction. The junction id of the entry i can be obtained by e.junc_index2id[i].

def get_junction_dynamic_roads(self) -> List[List[int]]:
476    def get_junction_dynamic_roads(self) -> List[List[int]]:
477        """
478        Get the ids of the dynamic roads connected to each junction.
479        The junction id of the entry `i` can be obtained by `e.junc_index2id
480        """
481        return self._e.get_junction_dynamic_roads()

Get the ids of the dynamic roads connected to each junction. The junction id of the entry i can be obtained by e.junc_index2id[i].

def get_road_lane_plans(self, road_index: int) -> List[List[slice]]:
483    def get_road_lane_plans(self, road_index: int) -> List[List[slice]]:
484        """
485        Get the dynamic lane plan of the road `road_index`,
486        represented as list of lane groups:
487        ```
488        [
489            [slice(lane_start, lane_end), ...]
490        ]
491        ```
492        """
493        return [
494            [slice(a, b) for a, b in i] for i in self._e.get_road_lane_plans(road_index)
495        ]

Get the dynamic lane plan of the road road_index, represented as list of lane groups:

[
    [slice(lane_start, lane_end), ...]
]
def get_road_average_vehicle_speed(self, road_index: int) -> float:
497    def get_road_average_vehicle_speed(self, road_index: int) -> float:
498        """
499        Get the average speed of the vehicles on the road `road_index`
500        """
501        if self.speed_stat_interval == 0:
502            raise RuntimeError(
503                "Please set speed_stat_interval to enable speed statistics"
504            )
505        lanes = self.fetch_lanes()
506        road_id = self.road_index2id[road_index]
507        lane_ids = self.id2roads[road_id].lane_ids
508        lane_indexes = [self.lane_id2index[i] for i in lane_ids]
509        v_args: NDArray[np.float32] = lanes["v_avg"]
510        return v_args[lane_indexes].mean().item()

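Get the average speed of the vehicles on the road road_index, computed as the mean of the average speeds of the road's lanes. Raises RuntimeError if speed statistics are disabled (speed_stat_interval is 0).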

def get_finished_person_count(self) -> int:
512    def get_finished_person_count(self) -> int:
513        """
514        Get the number of the finished persons
515        """
516        persons = self.fetch_persons()
517        status: NDArray[np.uint8] = persons["status"]
518        return (status == PersonStatus.FINISHED.value).sum()

Get the number of finished persons

def get_finished_person_average_traveling_time(self) -> float:
520    def get_finished_person_average_traveling_time(self) -> float:
521        """
522        Get the average traveling time of the finished persons
523        """
524        persons = self.fetch_persons()
525        status: NDArray[np.uint8] = persons["status"]
526        traveling_time = persons["traveling_time"]
527        return traveling_time[status == PersonStatus.FINISHED.value].mean()

Get the average traveling time of the finished persons

def get_running_person_average_traveling_time(self) -> float:
529    def get_running_person_average_traveling_time(self) -> float:
530        """
531        Get the average traveling time of the running persons
532        """
533        persons = self.fetch_persons()
534        status: NDArray[np.uint8] = persons["status"]
535        traveling_time = persons["traveling_time"]
536        return traveling_time[status == PersonStatus.DRIVING.value].mean()

Get the average traveling time of the running persons

def get_departed_person_average_traveling_time(self) -> float:
538    def get_departed_person_average_traveling_time(self) -> float:
539        """
540        Get the average traveling time of the departed persons (running+finished)
541        """
542        persons = self.fetch_persons()
543        status: NDArray[np.uint8] = persons["status"]
544        traveling_time = persons["traveling_time"]
545        return traveling_time[status != PersonStatus.SLEEP.value].mean()

Get the average traveling time of the departed persons (running+finished)

def get_road_lane_plan_index(self, road_index: int) -> int:
547    def get_road_lane_plan_index(self, road_index: int) -> int:
548        """
549        Get the lane plan of road `road_index`
550        """
551        return self._e.get_road_lane_plan_index(road_index)

Get the lane plan index of road road_index

def get_road_vehicle_counts(self) -> Dict[int, int]:
553    def get_road_vehicle_counts(self) -> Dict[int, int]:
554        """
555        Get the number of vehicles of each road
556
557        Returns:
558        - Dict: road id -> number of vehicles
559        """
560        persons = self.fetch_persons()
561        road_id = persons["lane_parent_id"]
562        status = persons["status"]
563        filter = status == PersonStatus.DRIVING.value
564        filtered_road_id = road_id[filter]
565        # count for the road id
566        unique, counts = np.unique(filtered_road_id, return_counts=True)
567        return dict(zip(unique, counts))

Get the number of vehicles of each road

Returns:

  • Dict: road id -> number of vehicles
def get_road_waiting_vehicle_counts(self, speed_threshold: float = 0.1) -> Dict[int, int]:
569    def get_road_waiting_vehicle_counts(
570        self, speed_threshold: float = 0.1
571    ) -> Dict[int, int]:
572        """
573        Get the number of vehicles with speed lower than `speed_threshold` of each road
574
575        Returns:
576        - Dict: road id -> number of vehicles
577        """
578
579        persons = self.fetch_persons()
580        road_id = persons["lane_parent_id"]
581        status = persons["status"]
582        v = persons["v"]
583        filter = (
584            (status == PersonStatus.DRIVING.value)
585            & (v < speed_threshold)
586            & (road_id < 3_0000_0000)  # the road id ranges [2_0000_0000, 3_0000_0000)
587        )
588        filtered_road_id = road_id[filter]
589        # count for the road id
590        unique, counts = np.unique(filtered_road_id, return_counts=True)
591        return dict(zip(unique, counts))

Get, for each road, the number of driving vehicles whose speed is lower than speed_threshold

Returns:

  • Dict: road id -> number of vehicles
def set_tl_policy(self, junction_index: int, policy: TlPolicy):
593    def set_tl_policy(self, junction_index: int, policy: TlPolicy):
594        """
595        Set the traffic light policy of junction `junction_index` to `policy`
596        """
597        self._e.set_tl_policy(junction_index, policy.value)

Set the traffic light policy of junction junction_index to policy

def set_tl_policy_batch(self, junction_indices: List[int], policy: TlPolicy):
599    def set_tl_policy_batch(self, junction_indices: List[int], policy: TlPolicy):
600        """
601        Set the traffic light policy of all junctions in `junction_indices` to `policy`
602        """
603        self._e.set_tl_policy_batch(junction_indices, policy.value)

Set the traffic light policy of all junctions in junction_indices to policy

def set_tl_duration(self, junction_index: int, duration: int):
605    def set_tl_duration(self, junction_index: int, duration: int):
606        """
607        Set the traffic light switch duration of junction `junction_index` to `duration`
608
609        NOTE: This is only effective for `TlPolicy.FIXED_TIME` and `TlPolicy.MAX_PRESSURE`.
610
611        NOTE: Set duration to `0` to use the predefined duration in the `map_file`
612        """
613        self._e.set_tl_duration(junction_index, duration)

Set the traffic light switch duration of junction junction_index to duration

NOTE: This is only effective for TlPolicy.FIXED_TIME and TlPolicy.MAX_PRESSURE.

NOTE: Set duration to 0 to use the predefined duration in the map_file

def set_tl_duration_batch(self, junction_indices: List[int], duration: int):
615    def set_tl_duration_batch(self, junction_indices: List[int], duration: int):
616        """
617        Set the traffic light switch duration of all junctions in `junction_indices` to `duration`
618
619        NOTE: This is only effective for `TlPolicy.FIXED_TIME` and `TlPolicy.MAX_PRESSURE`
620
621        NOTE: Set duration to `0` to use the predefined duration in the `map_file`
622        """
623        self._e.set_tl_duration_batch(junction_indices, duration)

Set the traffic light switch duration of all junctions in junction_indices to duration

NOTE: This is only effective for TlPolicy.FIXED_TIME and TlPolicy.MAX_PRESSURE

NOTE: Set duration to 0 to use the predefined duration in the map_file

def set_tl_phase(self, junction_index: Union[str, int], phase_index: int):
625    def set_tl_phase(self, junction_index: Union[str, int], phase_index: int):
626        """
627        Set the phase of `junction_index` to `phase_index`
628        """
629        self._e.set_tl_phase(junction_index, phase_index)

Set the phase of junction_index to phase_index

def set_tl_phase_batch(self, junction_indices: List[int], phase_indices: List[int]):
631    def set_tl_phase_batch(self, junction_indices: List[int], phase_indices: List[int]):
632        """
633        Set the phase of `junction_index` to `phase_index` in batch
634        """
635        assert len(junction_indices) == len(phase_indices)
636        self._e.set_tl_phase_batch(junction_indices, phase_indices)

Set the phase of each junction in junction_indices to the corresponding entry of phase_indices (both lists must have the same length)

def set_road_lane_plan(self, road_index: int, plan_index: int):
638    def set_road_lane_plan(self, road_index: int, plan_index: int):
639        """
640        Set the lane plan of road `road_index`
641        """
642        self._e.set_road_lane_plan(road_index, plan_index)

Set the lane plan of road road_index

def set_road_lane_plan_batch(self, road_indices: List[int], plan_indices: List[int]):
644    def set_road_lane_plan_batch(
645        self, road_indices: List[int], plan_indices: List[int]
646    ):
647        """
648        Set the lane plan of road `road_index`
649        """
650        assert len(road_indices) == len(plan_indices)
651        self._e.set_road_lane_plan_batch(road_indices, plan_indices)

Set the lane plan of each road in road_indices to the corresponding entry of plan_indices (both lists must have the same length)

def set_lane_restriction(self, lane_index: int, flag: bool):
653    def set_lane_restriction(self, lane_index: int, flag: bool):
654        """
655        Set the restriction state of lane `lane_index`
656        """
657        self._e.set_lane_restriction(lane_index, flag)

Set the restriction state of lane lane_index

def set_lane_restriction_batch(self, lane_indices: List[int], flags: List[bool]):
659    def set_lane_restriction_batch(self, lane_indices: List[int], flags: List[bool]):
660        """
661        Set the restriction state of lane `lane_index`
662        """
663        assert len(lane_indices) == len(flags)
664        self._e.set_lane_restriction_batch(lane_indices, flags)

Set the restriction state of each lane in lane_indices to the corresponding entry of flags (both lists must have the same length)

def set_lane_max_speed(self, lane_index: int, max_speed: float):
666    def set_lane_max_speed(self, lane_index: int, max_speed: float):
667        """
668        Set the max_speed of lane `lane_index`
669        """
670        self._e.set_lane_max_speed(lane_index, max_speed)

Set the max_speed of lane lane_index

def set_lane_max_speed_batch(self, lane_indices: List[int], max_speeds: Union[float, List[float]]):
672    def set_lane_max_speed_batch(
673        self, lane_indices: List[int], max_speeds: Union[float, List[float]]
674    ):
675        """
676        Set the max_speed of lane `lane_index`
677        """
678        if hasattr(max_speeds, "__len__"):
679            assert len(lane_indices) == len(max_speeds)
680        else:
681            max_speeds = [max_speeds] * len(lane_indices)
682        self._e.set_lane_max_speed_batch(lane_indices, max_speeds)

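Set the max speed of each lane in lane_indices. max_speeds is either a single value applied to all lanes or a list with one value per lane (same length as lane_indices).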

def next_step(self, n=1):
684    def next_step(self, n=1):
685        """
686        Move forward `n` steps
687        """
688        self._fetched_persons = None
689        self._fetched_lanes = None
690        self._e.next_step(n)

Move forward n steps

class TlPolicy(enum.Enum):
34class TlPolicy(Enum):
35    MANUAL = 0
36    FIXED_TIME = 1
37    MAX_PRESSURE = 2
38    NONE = 3

Traffic light control policy of a junction, as passed to set_tl_policy and set_tl_policy_batch.

MANUAL = <TlPolicy.MANUAL: 0>
FIXED_TIME = <TlPolicy.FIXED_TIME: 1>
MAX_PRESSURE = <TlPolicy.MAX_PRESSURE: 2>
NONE = <TlPolicy.NONE: 3>
class Verbosity(enum.Enum):
41class Verbosity(Enum):
42    NO_OUTPUT = 0
43    INIT_ONLY = 1
44    ALL = 2

Verbosity level of the engine, as passed to the verbose_level argument of Engine.

NO_OUTPUT = <Verbosity.NO_OUTPUT: 0>
INIT_ONLY = <Verbosity.INIT_ONLY: 1>
ALL = <Verbosity.ALL: 2>
class DBRecorder:
 49class DBRecorder:
 50    """
 51    DBRecorder is for web visualization and writes to Postgres Database
 52
 53    The table schema is as follows:
 54    - meta_simple: The metadata of the simulation.
 55        - name: The name of the simulation.
 56        - start: The start step of the simulation.
 57        - steps: The total steps of the simulation.
 58        - time: The time of the simulation.
 59        - total_agents: The total agents of the simulation.
 60        - map: The map of the simulation.
 61        - min_lng: The minimum longitude of the simulation.
 62        - min_lat: The minimum latitude of the simulation.
 63        - max_lng: The maximum longitude of the simulation.
 64        - max_lat: The maximum latitude of the simulation.
 65        - road_status_v_min: The minimum speed of the road status.
 66        - road_status_interval: The interval of the road status.
 67    - {output_name}_s_cars: The vehicles of the simulation.
 68        - step: The step of the simulation.
 69        - id: The id of the vehicle.
 70        - parent_id: The parent id of the vehicle.
 71        - direction: The direction of the vehicle.
 72        - lng: The longitude of the vehicle.
 73        - lat: The latitude of the vehicle.
 74        - model: The model of the vehicle.
 75        - z: The z of the vehicle.
 76        - pitch: The pitch of the vehicle.
 77        - v: The speed of the vehicle.
 78    - {output_name}_s_people: The people of the simulation.
 79        - step: The step of the simulation.
 80        - id: The id of the people.
 81        - parent_id: The parent id of the people.
 82        - direction: The direction of the people.
 83        - lng: The longitude of the people.
 84        - lat: The latitude of the people.
 85        - z: The z of the people.
 86        - v: The speed of the people.
 87        - model: The model of the people.
 88    - {output_name}_s_traffic_light: The traffic lights of the simulation.
 89        - step: The step of the simulation.
 90        - id: The id of the traffic light.
 91        - state: The state of the traffic light.
 92        - lng: The longitude of the traffic light.
 93        - lat: The latitude of the traffic light.
 94    - {output_name}_s_road: The road status of the simulation.
 95        - step: The step of the simulation.
 96        - id: The id of the road.
 97        - level: The level of the road.
 98        - v: The speed of the road.
 99        - in_vehicle_cnt: The in vehicle count of the road.
100        - out_vehicle_cnt: The out vehicle count of the road.
101        - cnt: The count of the road.
102
103    The index of the table is as follows:
104    - {output_name}_s_cars: (step, lng, lat)
105    - {output_name}_s_people: (step, lng, lat)
106    - {output_name}_s_traffic_light: (step, lng, lat)
107    - {output_name}_s_road: (step)
108    """
109
110    def __init__(self, eng: Engine):
111        """
112        Args:
113        - eng: The engine to be recorded.
114        """
115        self.eng = eng
116        self.data = []
117
118    def record(self):
119        """
120        Record the data of the engine.
121        """
122        self.data.append([
123            self.eng._e.get_current_step(),
124            self.eng._e.get_output_vehicles(),
125            self.eng._e.get_output_tls(),
126        ])
127
128    def save(self, db_url: str, mongo_map: str, output_name: str, use_tqdm=False):
129        """
130        Save the data to the Postgres Database.
131
132        Args
133        - db_url: The URL of the Postgres Database.
134        - mongo_map: The map path of the simulation in mongodb (if you use mongodb). The format is like {db}.{coll}.
135        - output_name: The name of the simulation that will be saved to the database.
136        - use_tqdm: Whether to use tqdm or not.
137        """
138        vehs = []
139        tls = []
140        xs = []
141        ys = []
142        proj = pyproj.Proj(self.eng._map.header.projection)
143        for step, (vs, vx, vy), (ts, tx, ty) in self.data:
144            if vs:
145                x, y = proj(vx, vy, True)
146                xs.extend(x)
147                ys.extend(y)
148                vehs.append([step, vs, x, y])
149            if ts:
150                x, y = proj(tx, ty, True)
151                xs.extend(x)
152                ys.extend(y)
153                tls.append([step, ts, x, y])
154        if xs:
155            min_lon, max_lon, min_lat, max_lat = min(xs), max(xs), min(ys), max(ys)
156        else:
157            x1, y1, x2, y2 = self.eng._map_bbox
158            min_lon,  min_lat = proj(x1, y1, True)
159            max_lon,  max_lat = proj(x2, y2, True)
160        with psycopg2.connect(db_url) as conn:
161            with conn.cursor() as cur:
162                # create table meta_simple
163                cur.execute("""
164                CREATE TABLE IF NOT EXISTS public.meta_simple (
165                    "name" text NOT NULL,
166                    "start" int4 NOT NULL,
167                    steps int4 NOT NULL,
168                    "time" float8 NOT NULL,
169                    total_agents int4 NOT NULL,
170                    "map" text NOT NULL,
171                    min_lng float8 NOT NULL,
172                    min_lat float8 NOT NULL,
173                    max_lng float8 NOT NULL,
174                    max_lat float8 NOT NULL,
175                    road_status_v_min float8 NULL,
176                    road_status_interval int4 NULL,
177                    CONSTRAINT meta_simple_pkey PRIMARY KEY (name)
178                );
179                """)
180                conn.commit()
181
182                # 删除指定记录
183                # delete from public.meta_simple where name='output_name';
184                cur.execute(f"DELETE FROM public.meta_simple WHERE name='{output_name}';")
185                conn.commit()
186
187                # 插入新记录
188                # insert into public.meta_simple values ('output_name', 0, 1000, 1, 1, 'map', 0, 0, 1, 1, 0, 300);
189                cur.execute(
190                    f"INSERT INTO public.meta_simple VALUES ('{output_name}', {self.eng.start_step}, {len(self.data)}, 1, 1, '{mongo_map}', {min_lon}, {min_lat}, {max_lon}, {max_lat}, 0, 300);")
191                conn.commit()
192
193                # 删除表格
194                # drop table if exists public.output_name_s_cars;
195                cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_cars;")
196                cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_people;")
197                cur.execute(
198                    f"DROP TABLE IF EXISTS public.{output_name}_s_traffic_light;"
199                )
200                cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_road;")
201                conn.commit()
202
203                # 创建表格
204                # create table public.output_name_s_cars
205                cur.execute(
206                    f"""
207                CREATE TABLE public.{output_name}_s_cars (
208                    step int4 NOT NULL,
209                    id int4 NOT NULL,
210                    parent_id int4 NOT NULL,
211                    direction float8 NOT NULL,
212                    lng float8 NOT NULL,
213                    lat float8 NOT NULL,
214                    model text NOT NULL,
215                    z float8 NOT NULL,
216                    pitch float8 NOT NULL,
217                    v float8 NOT NULL
218                );
219                """
220                )
221                cur.execute(
222                    f"CREATE INDEX {output_name}_s_cars_step_lng_lat_idx ON public.{output_name}_s_cars USING btree (step, lng, lat);"
223                )
224                conn.commit()
225
226                # 创建表格
227                # create table public.output_name_s_people
228                cur.execute(
229                    f"""
230                CREATE TABLE public.{output_name}_s_people (
231                    step int4 NOT NULL,
232                    id int4 NOT NULL,
233                    parent_id int4 NOT NULL,
234                    direction float8 NOT NULL,
235                    lng float8 NOT NULL,
236                    lat float8 NOT NULL,
237                    z float8 NOT NULL,
238                    v float8 NOT NULL,
239                    model text NOT NULL
240                );
241                """
242                )
243                cur.execute(
244                    f"CREATE INDEX {output_name}_s_people_step_lng_lat_idx ON public.{output_name}_s_people USING btree (step, lng, lat);"
245                )
246                conn.commit()
247
248                # 创建表格
249                # create table public.output_name_s_traffic_light
250                cur.execute(
251                    f"""
252                CREATE TABLE public.{output_name}_s_traffic_light (
253                    step int4 NOT NULL,
254                    id int4 NOT NULL,
255                    state int4 NOT NULL,
256                    lng float8 NOT NULL,
257                    lat float8 NOT NULL
258                );
259                """
260                )
261                cur.execute(
262                    f"CREATE INDEX {output_name}_s_traffic_light_step_lng_lat_idx ON public.{output_name}_s_traffic_light USING btree (step, lng, lat);"
263                )
264                conn.commit()
265
266                # 创建表格
267                # create table public.output_name_s_road
268                cur.execute(
269                    f"""
270                CREATE TABLE public.{output_name}_s_road (
271                    step int4 NOT NULL,
272                    id int4 NOT NULL,
273                    "level" int4 NOT NULL,
274                    v float8 NOT NULL,
275                    in_vehicle_cnt int4 NOT NULL,
276                    out_vehicle_cnt int4 NOT NULL,
277                    cnt int4 NOT NULL
278                );
279                """
280                )
281                cur.execute(
282                    f"CREATE INDEX {output_name}_s_road_step_idx ON public.{output_name}_s_road USING btree (step);"
283                )
284                conn.commit()
285
286                cur.copy_from(
287                    StringIteratorIO(
288                        f"{step},{p},{l},{round(d,3)},{x},{y},'',0,0,{round(v,3)}\n"
289                        for step, vs, x, y in tqdm(vehs, ncols=90, disable=not use_tqdm)
290                        for (p, l, d, v), x, y in zip(vs, x, y)
291                    ),
292                    f"{output_name}_s_cars",
293                    sep=",",
294                )
295                cur.copy_from(
296                    StringIteratorIO(
297                        f"{step},{p},{s},{x},{y}\n"
298                        for step, ts, x, y in tqdm(tls, ncols=90, disable=not use_tqdm)
299                        for (p, s), x, y in zip(ts, x, y)
300                    ),
301                    f"{output_name}_s_traffic_light",
302                    sep=",",
303                )
304                conn.commit()

DBRecorder records engine output for web visualization and writes it to a Postgres database.

The table schema is as follows:

  • meta_simple: The metadata of the simulation.
    • name: The name of the simulation.
    • start: The start step of the simulation.
    • steps: The total steps of the simulation.
    • time: The time of the simulation.
    • total_agents: The total agents of the simulation.
    • map: The map of the simulation.
    • min_lng: The minimum longitude of the simulation.
    • min_lat: The minimum latitude of the simulation.
    • max_lng: The maximum longitude of the simulation.
    • max_lat: The maximum latitude of the simulation.
    • road_status_v_min: The minimum speed of the road status.
    • road_status_interval: The interval of the road status.
  • {output_name}_s_cars: The vehicles of the simulation.
    • step: The step of the simulation.
    • id: The id of the vehicle.
    • parent_id: The parent id of the vehicle.
    • direction: The direction of the vehicle.
    • lng: The longitude of the vehicle.
    • lat: The latitude of the vehicle.
    • model: The model of the vehicle.
    • z: The z of the vehicle.
    • pitch: The pitch of the vehicle.
    • v: The speed of the vehicle.
  • {output_name}_s_people: The people of the simulation.
    • step: The step of the simulation.
    • id: The id of the person.
    • parent_id: The parent id of the person.
    • direction: The direction of the person.
    • lng: The longitude of the person.
    • lat: The latitude of the person.
    • z: The z coordinate of the person.
    • v: The speed of the person.
    • model: The model of the person.
  • {output_name}_s_traffic_light: The traffic lights of the simulation.
    • step: The step of the simulation.
    • id: The id of the traffic light.
    • state: The state of the traffic light.
    • lng: The longitude of the traffic light.
    • lat: The latitude of the traffic light.
  • {output_name}_s_road: The road status of the simulation.
    • step: The step of the simulation.
    • id: The id of the road.
    • level: The level of the road.
    • v: The speed of the road.
    • in_vehicle_cnt: The in vehicle count of the road.
    • out_vehicle_cnt: The out vehicle count of the road.
    • cnt: The count of the road.

The indexes of the tables are as follows:

  • {output_name}_s_cars: (step, lng, lat)
  • {output_name}_s_people: (step, lng, lat)
  • {output_name}_s_traffic_light: (step, lng, lat)
  • {output_name}_s_road: (step)
DBRecorder(eng: Engine)
110    def __init__(self, eng: Engine):
111        """
112        Args:
113        - eng: The engine to be recorded.
114        """
115        self.eng = eng
116        self.data = []

Args:

  • eng: The engine to be recorded.
eng
data
def record(self):
118    def record(self):
119        """
120        Record the data of the engine.
121        """
122        self.data.append([
123            self.eng._e.get_current_step(),
124            self.eng._e.get_output_vehicles(),
125            self.eng._e.get_output_tls(),
126        ])

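Record the engine's current step together with its vehicle and traffic-light output. The snapshot is appended to data and written to the database later by save().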

def save(self, db_url: str, mongo_map: str, output_name: str, use_tqdm=False):
128    def save(self, db_url: str, mongo_map: str, output_name: str, use_tqdm=False):
129        """
130        Save the data to the Postgres Database.
131
132        Args
133        - db_url: The URL of the Postgres Database.
134        - mongo_map: The map path of the simulation in mongodb (if you use mongodb). The format is like {db}.{coll}.
135        - output_name: The name of the simulation that will be saved to the database.
136        - use_tqdm: Whether to use tqdm or not.
137        """
138        vehs = []
139        tls = []
140        xs = []
141        ys = []
142        proj = pyproj.Proj(self.eng._map.header.projection)
143        for step, (vs, vx, vy), (ts, tx, ty) in self.data:
144            if vs:
145                x, y = proj(vx, vy, True)
146                xs.extend(x)
147                ys.extend(y)
148                vehs.append([step, vs, x, y])
149            if ts:
150                x, y = proj(tx, ty, True)
151                xs.extend(x)
152                ys.extend(y)
153                tls.append([step, ts, x, y])
154        if xs:
155            min_lon, max_lon, min_lat, max_lat = min(xs), max(xs), min(ys), max(ys)
156        else:
157            x1, y1, x2, y2 = self.eng._map_bbox
158            min_lon,  min_lat = proj(x1, y1, True)
159            max_lon,  max_lat = proj(x2, y2, True)
160        with psycopg2.connect(db_url) as conn:
161            with conn.cursor() as cur:
162                # create table meta_simple
163                cur.execute("""
164                CREATE TABLE IF NOT EXISTS public.meta_simple (
165                    "name" text NOT NULL,
166                    "start" int4 NOT NULL,
167                    steps int4 NOT NULL,
168                    "time" float8 NOT NULL,
169                    total_agents int4 NOT NULL,
170                    "map" text NOT NULL,
171                    min_lng float8 NOT NULL,
172                    min_lat float8 NOT NULL,
173                    max_lng float8 NOT NULL,
174                    max_lat float8 NOT NULL,
175                    road_status_v_min float8 NULL,
176                    road_status_interval int4 NULL,
177                    CONSTRAINT meta_simple_pkey PRIMARY KEY (name)
178                );
179                """)
180                conn.commit()
181
182                # 删除指定记录
183                # delete from public.meta_simple where name='output_name';
184                cur.execute(f"DELETE FROM public.meta_simple WHERE name='{output_name}';")
185                conn.commit()
186
187                # 插入新记录
188                # insert into public.meta_simple values ('output_name', 0, 1000, 1, 1, 'map', 0, 0, 1, 1, 0, 300);
189                cur.execute(
190                    f"INSERT INTO public.meta_simple VALUES ('{output_name}', {self.eng.start_step}, {len(self.data)}, 1, 1, '{mongo_map}', {min_lon}, {min_lat}, {max_lon}, {max_lat}, 0, 300);")
191                conn.commit()
192
193                # 删除表格
194                # drop table if exists public.output_name_s_cars;
195                cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_cars;")
196                cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_people;")
197                cur.execute(
198                    f"DROP TABLE IF EXISTS public.{output_name}_s_traffic_light;"
199                )
200                cur.execute(f"DROP TABLE IF EXISTS public.{output_name}_s_road;")
201                conn.commit()
202
203                # 创建表格
204                # create table public.output_name_s_cars
205                cur.execute(
206                    f"""
207                CREATE TABLE public.{output_name}_s_cars (
208                    step int4 NOT NULL,
209                    id int4 NOT NULL,
210                    parent_id int4 NOT NULL,
211                    direction float8 NOT NULL,
212                    lng float8 NOT NULL,
213                    lat float8 NOT NULL,
214                    model text NOT NULL,
215                    z float8 NOT NULL,
216                    pitch float8 NOT NULL,
217                    v float8 NOT NULL
218                );
219                """
220                )
221                cur.execute(
222                    f"CREATE INDEX {output_name}_s_cars_step_lng_lat_idx ON public.{output_name}_s_cars USING btree (step, lng, lat);"
223                )
224                conn.commit()
225
226                # 创建表格
227                # create table public.output_name_s_people
228                cur.execute(
229                    f"""
230                CREATE TABLE public.{output_name}_s_people (
231                    step int4 NOT NULL,
232                    id int4 NOT NULL,
233                    parent_id int4 NOT NULL,
234                    direction float8 NOT NULL,
235                    lng float8 NOT NULL,
236                    lat float8 NOT NULL,
237                    z float8 NOT NULL,
238                    v float8 NOT NULL,
239                    model text NOT NULL
240                );
241                """
242                )
243                cur.execute(
244                    f"CREATE INDEX {output_name}_s_people_step_lng_lat_idx ON public.{output_name}_s_people USING btree (step, lng, lat);"
245                )
246                conn.commit()
247
248                # 创建表格
249                # create table public.output_name_s_traffic_light
250                cur.execute(
251                    f"""
252                CREATE TABLE public.{output_name}_s_traffic_light (
253                    step int4 NOT NULL,
254                    id int4 NOT NULL,
255                    state int4 NOT NULL,
256                    lng float8 NOT NULL,
257                    lat float8 NOT NULL
258                );
259                """
260                )
261                cur.execute(
262                    f"CREATE INDEX {output_name}_s_traffic_light_step_lng_lat_idx ON public.{output_name}_s_traffic_light USING btree (step, lng, lat);"
263                )
264                conn.commit()
265
266                # 创建表格
267                # create table public.output_name_s_road
268                cur.execute(
269                    f"""
270                CREATE TABLE public.{output_name}_s_road (
271                    step int4 NOT NULL,
272                    id int4 NOT NULL,
273                    "level" int4 NOT NULL,
274                    v float8 NOT NULL,
275                    in_vehicle_cnt int4 NOT NULL,
276                    out_vehicle_cnt int4 NOT NULL,
277                    cnt int4 NOT NULL
278                );
279                """
280                )
281                cur.execute(
282                    f"CREATE INDEX {output_name}_s_road_step_idx ON public.{output_name}_s_road USING btree (step);"
283                )
284                conn.commit()
285
286                cur.copy_from(
287                    StringIteratorIO(
288                        f"{step},{p},{l},{round(d,3)},{x},{y},'',0,0,{round(v,3)}\n"
289                        for step, vs, x, y in tqdm(vehs, ncols=90, disable=not use_tqdm)
290                        for (p, l, d, v), x, y in zip(vs, x, y)
291                    ),
292                    f"{output_name}_s_cars",
293                    sep=",",
294                )
295                cur.copy_from(
296                    StringIteratorIO(
297                        f"{step},{p},{s},{x},{y}\n"
298                        for step, ts, x, y in tqdm(tls, ncols=90, disable=not use_tqdm)
299                        for (p, s), x, y in zip(ts, x, y)
300                    ),
301                    f"{output_name}_s_traffic_light",
302                    sep=",",
303                )
304                conn.commit()

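Save the recorded data to the Postgres database. Only the vehicle and traffic-light tables are populated; the people and road tables are created (replacing any existing ones) but left empty.

Args: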

  • db_url: The URL of the Postgres Database.
  • mongo_map: The location of the simulation's map in MongoDB (if you use MongoDB), in the form {db}.{coll}.
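  • output_name: The name under which the simulation is saved to the database. It is used as a table-name prefix and is inserted directly into SQL statements, so it must be a trusted, valid SQL identifier.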
  • use_tqdm: Whether to show a tqdm progress bar while writing rows.