|
| 1 | +""" |
| 2 | +Safe interval path planner |
| 3 | + This script implements a safe-interval path planner for a 2d grid with dynamic obstacles. It is faster than |
| 4 | + SpaceTime A* because it reduces the number of redundant node expansions by pre-computing regions of adjacent |
| 5 | + time steps that are safe ("safe intervals") at each position. This allows the algorithm to skip expanding nodes |
| 6 | + that are in intervals that have already been visited earlier. |
| 7 | +
|
| 8 | + Reference: https://www.cs.cmu.edu/~maxim/files/sipp_icra11.pdf |
| 9 | +""" |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import matplotlib.pyplot as plt |
| 13 | +from PathPlanning.TimeBasedPathPlanning.GridWithDynamicObstacles import ( |
| 14 | + Grid, |
| 15 | + Interval, |
| 16 | + ObstacleArrangement, |
| 17 | + Position, |
| 18 | + empty_2d_array_of_lists, |
| 19 | +) |
| 20 | +import heapq |
| 21 | +import random |
| 22 | +from dataclasses import dataclass |
| 23 | +from functools import total_ordering |
| 24 | +import time |
| 25 | + |
| 26 | +@dataclass() |
| 27 | +# Note: Total_ordering is used instead of adding `order=True` to the @dataclass decorator because |
| 28 | +# this class needs to override the __lt__ and __eq__ methods to ignore parent_index. The Parent |
| 29 | +# index and interval member variables are just used to track the path found by the algorithm, |
| 30 | +# and has no effect on the quality of a node. |
| 31 | +@total_ordering |
| 32 | +class Node: |
| 33 | + position: Position |
| 34 | + time: int |
| 35 | + heuristic: int |
| 36 | + parent_index: int |
| 37 | + interval: Interval |
| 38 | + |
| 39 | + """ |
| 40 | + This is what is used to drive node expansion. The node with the lowest value is expanded next. |
| 41 | + This comparison prioritizes the node with the lowest cost-to-come (self.time) + cost-to-go (self.heuristic) |
| 42 | + """ |
| 43 | + def __lt__(self, other: object): |
| 44 | + if not isinstance(other, Node): |
| 45 | + return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}") |
| 46 | + return (self.time + self.heuristic) < (other.time + other.heuristic) |
| 47 | + |
| 48 | + """ |
| 49 | + Equality only cares about position and time. Heuristic and interval will always be the same for a given |
| 50 | + (position, time) pairing, so they are not considered in equality. |
| 51 | + """ |
| 52 | + def __eq__(self, other: object): |
| 53 | + if not isinstance(other, Node): |
| 54 | + return NotImplemented |
| 55 | + return self.position == other.position and self.time == other.time |
| 56 | + |
| 57 | +@dataclass |
| 58 | +class EntryTimeAndInterval: |
| 59 | + entry_time: int |
| 60 | + interval: Interval |
| 61 | + |
| 62 | +class NodePath: |
| 63 | + path: list[Node] |
| 64 | + positions_at_time: dict[int, Position] = {} |
| 65 | + |
| 66 | + def __init__(self, path: list[Node]): |
| 67 | + self.path = path |
| 68 | + for (i, node) in enumerate(path): |
| 69 | + if i > 0: |
| 70 | + # account for waiting in interval at previous node |
| 71 | + prev_node = path[i-1] |
| 72 | + for t in range(prev_node.time, node.time): |
| 73 | + self.positions_at_time[t] = prev_node.position |
| 74 | + |
| 75 | + self.positions_at_time[node.time] = node.position |
| 76 | + |
| 77 | + """ |
| 78 | + Get the position of the path at a given time |
| 79 | + """ |
| 80 | + def get_position(self, time: int) -> Position | None: |
| 81 | + return self.positions_at_time.get(time) |
| 82 | + |
| 83 | + """ |
| 84 | + Time stamp of the last node in the path |
| 85 | + """ |
| 86 | + def goal_reached_time(self) -> int: |
| 87 | + return self.path[-1].time |
| 88 | + |
| 89 | + def __repr__(self): |
| 90 | + repr_string = "" |
| 91 | + for i, node in enumerate(self.path): |
| 92 | + repr_string += f"{i}: {node}\n" |
| 93 | + return repr_string |
| 94 | + |
| 95 | + |
| 96 | +class SafeIntervalPathPlanner: |
| 97 | + grid: Grid |
| 98 | + start: Position |
| 99 | + goal: Position |
| 100 | + |
| 101 | + def __init__(self, grid: Grid, start: Position, goal: Position): |
| 102 | + self.grid = grid |
| 103 | + self.start = start |
| 104 | + self.goal = goal |
| 105 | + |
| 106 | + # Seed randomness for reproducibility |
| 107 | + RANDOM_SEED = 50 |
| 108 | + random.seed(RANDOM_SEED) |
| 109 | + np.random.seed(RANDOM_SEED) |
| 110 | + |
| 111 | + """ |
| 112 | + Generate a plan given the loaded problem statement. Raises an exception if it fails to find a path. |
| 113 | + Arguments: |
| 114 | + verbose (bool): set to True to print debug information |
| 115 | + """ |
| 116 | + def plan(self, verbose: bool = False) -> NodePath: |
| 117 | + |
| 118 | + safe_intervals = self.grid.get_safe_intervals() |
| 119 | + |
| 120 | + open_set: list[Node] = [] |
| 121 | + first_node_interval = safe_intervals[self.start.x, self.start.y][0] |
| 122 | + heapq.heappush( |
| 123 | + open_set, Node(self.start, 0, self.calculate_heuristic(self.start), -1, first_node_interval) |
| 124 | + ) |
| 125 | + |
| 126 | + expanded_list: list[Node] = [] |
| 127 | + visited_intervals = empty_2d_array_of_lists(self.grid.grid_size[0], self.grid.grid_size[1]) |
| 128 | + while open_set: |
| 129 | + expanded_node: Node = heapq.heappop(open_set) |
| 130 | + if verbose: |
| 131 | + print("Expanded node:", expanded_node) |
| 132 | + |
| 133 | + if expanded_node.time + 1 >= self.grid.time_limit: |
| 134 | + if verbose: |
| 135 | + print(f"\tSkipping node that is past time limit: {expanded_node}") |
| 136 | + continue |
| 137 | + |
| 138 | + if expanded_node.position == self.goal: |
| 139 | + print(f"Found path to goal after {len(expanded_list)} expansions") |
| 140 | + path = [] |
| 141 | + path_walker: Node = expanded_node |
| 142 | + while True: |
| 143 | + path.append(path_walker) |
| 144 | + if path_walker.parent_index == -1: |
| 145 | + break |
| 146 | + path_walker = expanded_list[path_walker.parent_index] |
| 147 | + |
| 148 | + # reverse path so it goes start -> goal |
| 149 | + path.reverse() |
| 150 | + return NodePath(path) |
| 151 | + |
| 152 | + expanded_idx = len(expanded_list) |
| 153 | + expanded_list.append(expanded_node) |
| 154 | + entry_time_and_node = EntryTimeAndInterval(expanded_node.time, expanded_node.interval) |
| 155 | + add_entry_to_visited_intervals_array(entry_time_and_node, visited_intervals, expanded_node) |
| 156 | + |
| 157 | + for child in self.generate_successors(expanded_node, expanded_idx, safe_intervals, visited_intervals): |
| 158 | + heapq.heappush(open_set, child) |
| 159 | + |
| 160 | + raise Exception("No path found") |
| 161 | + |
| 162 | + """ |
| 163 | + Generate list of possible successors of the provided `parent_node` that are worth expanding |
| 164 | + """ |
| 165 | + def generate_successors( |
| 166 | + self, parent_node: Node, parent_node_idx: int, intervals: np.ndarray, visited_intervals: np.ndarray |
| 167 | + ) -> list[Node]: |
| 168 | + new_nodes = [] |
| 169 | + diffs = [ |
| 170 | + Position(0, 0), |
| 171 | + Position(1, 0), |
| 172 | + Position(-1, 0), |
| 173 | + Position(0, 1), |
| 174 | + Position(0, -1), |
| 175 | + ] |
| 176 | + for diff in diffs: |
| 177 | + new_pos = parent_node.position + diff |
| 178 | + if not self.grid.inside_grid_bounds(new_pos): |
| 179 | + continue |
| 180 | + |
| 181 | + current_interval = parent_node.interval |
| 182 | + |
| 183 | + new_cell_intervals: list[Interval] = intervals[new_pos.x, new_pos.y] |
| 184 | + for interval in new_cell_intervals: |
| 185 | + # if interval starts after current ends, break |
| 186 | + # assumption: intervals are sorted by start time, so all future intervals will hit this condition as well |
| 187 | + if interval.start_time > current_interval.end_time: |
| 188 | + break |
| 189 | + |
| 190 | + # if interval ends before current starts, skip |
| 191 | + if interval.end_time < current_interval.start_time: |
| 192 | + continue |
| 193 | + |
| 194 | + # if we have already expanded a node in this interval with a <= starting time, skip |
| 195 | + better_node_expanded = False |
| 196 | + for visited in visited_intervals[new_pos.x, new_pos.y]: |
| 197 | + if interval == visited.interval and visited.entry_time <= parent_node.time + 1: |
| 198 | + better_node_expanded = True |
| 199 | + break |
| 200 | + if better_node_expanded: |
| 201 | + continue |
| 202 | + |
| 203 | + # We know there is a node worth expanding. Generate successor at the earliest possible time the |
| 204 | + # new interval can be entered |
| 205 | + for possible_t in range(max(parent_node.time + 1, interval.start_time), min(current_interval.end_time, interval.end_time)): |
| 206 | + if self.grid.valid_position(new_pos, possible_t): |
| 207 | + new_nodes.append(Node( |
| 208 | + new_pos, |
| 209 | + # entry is max of interval start and parent node time + 1 (get there as soon as possible) |
| 210 | + max(interval.start_time, parent_node.time + 1), |
| 211 | + self.calculate_heuristic(new_pos), |
| 212 | + parent_node_idx, |
| 213 | + interval, |
| 214 | + )) |
| 215 | + # break because all t's after this will make nodes with a higher cost, the same heuristic, and are in the same interval |
| 216 | + break |
| 217 | + |
| 218 | + return new_nodes |
| 219 | + |
| 220 | + """ |
| 221 | + Calculate the heuristic for a given position - Manhattan distance to the goal |
| 222 | + """ |
| 223 | + def calculate_heuristic(self, position) -> int: |
| 224 | + diff = self.goal - position |
| 225 | + return abs(diff.x) + abs(diff.y) |
| 226 | + |
| 227 | + |
| 228 | +""" |
| 229 | +Adds a new entry to the visited intervals array. If the entry is already present, the entry time is updated if the new |
| 230 | +entry time is better. Otherwise, the entry is added to `visited_intervals` at the position of `expanded_node`. |
| 231 | +""" |
| 232 | +def add_entry_to_visited_intervals_array(entry_time_and_interval: EntryTimeAndInterval, visited_intervals: np.ndarray, expanded_node: Node): |
| 233 | + # if entry is present, update entry time if better |
| 234 | + for existing_entry_and_interval in visited_intervals[expanded_node.position.x, expanded_node.position.y]: |
| 235 | + if existing_entry_and_interval.interval == entry_time_and_interval.interval: |
| 236 | + existing_entry_and_interval.entry_time = min(existing_entry_and_interval.entry_time, entry_time_and_interval.entry_time) |
| 237 | + |
| 238 | + # Otherwise, append |
| 239 | + visited_intervals[expanded_node.position.x, expanded_node.position.y].append(entry_time_and_interval) |
| 240 | + |
| 241 | + |
| 242 | +show_animation = True |
| 243 | +verbose = False |
| 244 | + |
| 245 | +def main(): |
| 246 | + start = Position(1, 18) |
| 247 | + goal = Position(19, 19) |
| 248 | + grid_side_length = 21 |
| 249 | + |
| 250 | + start_time = time.time() |
| 251 | + |
| 252 | + grid = Grid( |
| 253 | + np.array([grid_side_length, grid_side_length]), |
| 254 | + num_obstacles=250, |
| 255 | + obstacle_avoid_points=[start, goal], |
| 256 | + obstacle_arrangement=ObstacleArrangement.ARRANGEMENT1, |
| 257 | + # obstacle_arrangement=ObstacleArrangement.RANDOM, |
| 258 | + ) |
| 259 | + |
| 260 | + planner = SafeIntervalPathPlanner(grid, start, goal) |
| 261 | + path = planner.plan(verbose) |
| 262 | + runtime = time.time() - start_time |
| 263 | + print(f"Planning took: {runtime:.5f} seconds") |
| 264 | + |
| 265 | + if verbose: |
| 266 | + print(f"Path: {path}") |
| 267 | + |
| 268 | + if not show_animation: |
| 269 | + return |
| 270 | + |
| 271 | + fig = plt.figure(figsize=(10, 7)) |
| 272 | + ax = fig.add_subplot( |
| 273 | + autoscale_on=False, |
| 274 | + xlim=(0, grid.grid_size[0] - 1), |
| 275 | + ylim=(0, grid.grid_size[1] - 1), |
| 276 | + ) |
| 277 | + ax.set_aspect("equal") |
| 278 | + ax.grid() |
| 279 | + ax.set_xticks(np.arange(0, grid_side_length, 1)) |
| 280 | + ax.set_yticks(np.arange(0, grid_side_length, 1)) |
| 281 | + |
| 282 | + (start_and_goal,) = ax.plot([], [], "mD", ms=15, label="Start and Goal") |
| 283 | + start_and_goal.set_data([start.x, goal.x], [start.y, goal.y]) |
| 284 | + (obs_points,) = ax.plot([], [], "ro", ms=15, label="Obstacles") |
| 285 | + (path_points,) = ax.plot([], [], "bo", ms=10, label="Path Found") |
| 286 | + ax.legend(bbox_to_anchor=(1.05, 1)) |
| 287 | + |
| 288 | + # for stopping simulation with the esc key. |
| 289 | + plt.gcf().canvas.mpl_connect( |
| 290 | + "key_release_event", lambda event: [exit(0) if event.key == "escape" else None] |
| 291 | + ) |
| 292 | + |
| 293 | + for i in range(0, path.goal_reached_time() + 1): |
| 294 | + obs_positions = grid.get_obstacle_positions_at_time(i) |
| 295 | + obs_points.set_data(obs_positions[0], obs_positions[1]) |
| 296 | + path_position = path.get_position(i) |
| 297 | + path_points.set_data([path_position.x], [path_position.y]) |
| 298 | + plt.pause(0.2) |
| 299 | + plt.show() |
| 300 | + |
| 301 | + |
| 302 | +if __name__ == "__main__": |
| 303 | + main() |
0 commit comments