Skip to content

Commit aa61a6e

Browse files
SchwartzCodeCopilotAtsushiSakai
authored
Safe Interval Path Planner (#1184)
* it works and is WAY faster than a* * some bug fixes from testing different scenarios * add some docs & address todos * add sipp test * spiff up comments revert changes in speed-up * explain what the removal is doing * linting * fix docs build * docs formatting * revert change to file (maybe linter did it?) * point at gifs in gifs repo * use raw githubusercontent gif links * change formatting on planner results * format output differently * proper formatting final * missing underline * revert unintended change * grammar + add descriptions for gifs * missing :: * add title to gifs section * dont use sections for sub-sections * constent a* spelling * Update PathPlanning/TimeBasedPathPlanning/GridWithDynamicObstacles.py Co-authored-by: Copilot <[email protected]> * Update tests/test_safe_interval_path_planner.py Co-authored-by: Copilot <[email protected]> * Update docs/modules/5_path_planning/time_based_grid_search/time_based_grid_search_main.rst Co-authored-by: Atsushi Sakai <[email protected]> * Update PathPlanning/TimeBasedPathPlanning/SafeInterval.py Co-authored-by: Copilot <[email protected]> * addressing comments * revert np.full change --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: Atsushi Sakai <[email protected]>
1 parent 41187d6 commit aa61a6e

File tree

4 files changed

+434
-2
lines changed

4 files changed

+434
-2
lines changed

PathPlanning/TimeBasedPathPlanning/GridWithDynamicObstacles.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,25 @@ def __sub__(self, other):
3333
def __hash__(self):
3434
return hash((self.x, self.y))
3535

36+
@dataclass
37+
class Interval:
38+
start_time: int
39+
end_time: int
3640

3741
class ObstacleArrangement(Enum):
3842
# Random obstacle positions and movements
3943
RANDOM = 0
4044
# Obstacles start in a line in y at center of grid and move side-to-side in x
4145
ARRANGEMENT1 = 1
4246

47+
"""
48+
Generates a 2d numpy array with lists for elements.
49+
"""
50+
def empty_2d_array_of_lists(x: int, y: int) -> np.ndarray:
51+
arr = np.empty((x, y), dtype=object)
52+
# assign each element individually - np.full creates references to the same list
53+
arr[:] = [[[] for _ in range(y)] for _ in range(x)]
54+
return arr
4355

4456
class Grid:
4557
# Set in constructor
@@ -89,7 +101,7 @@ def __init__(
89101
"""
90102
def generate_dynamic_obstacles(self, obs_count: int) -> list[list[Position]]:
91103
obstacle_paths = []
92-
for _ in (0, obs_count):
104+
for _ in range(0, obs_count):
93105
# Sample until a free starting space is found
94106
initial_position = self.sample_random_position()
95107
while not self.valid_obstacle_position(initial_position, 0):
@@ -234,6 +246,49 @@ def get_obstacle_positions_at_time(self, t: int) -> tuple[list[int], list[int]]:
234246
y_positions.append(obs_path[t].y)
235247
return (x_positions, y_positions)
236248

249+
"""
250+
Returns safe intervals for each cell.
251+
"""
252+
def get_safe_intervals(self) -> np.ndarray:
253+
intervals = empty_2d_array_of_lists(self.grid_size[0], self.grid_size[1])
254+
for x in range(intervals.shape[0]):
255+
for y in range(intervals.shape[1]):
256+
intervals[x, y] = self.get_safe_intervals_at_cell(Position(x, y))
257+
258+
return intervals
259+
260+
"""
261+
Generate the safe intervals for a given cell. The intervals will be in order of start time.
262+
ex: Interval (2, 3) will be before Interval (4, 5)
263+
"""
264+
def get_safe_intervals_at_cell(self, cell: Position) -> list[Interval]:
265+
vals = self.reservation_matrix[cell.x, cell.y, :]
266+
# Find where the array is zero
267+
zero_mask = (vals == 0)
268+
269+
# Identify transitions between zero and nonzero elements
270+
diff = np.diff(zero_mask.astype(int))
271+
272+
# Start indices: where zeros begin (1 after a nonzero)
273+
start_indices = np.where(diff == 1)[0] + 1
274+
275+
# End indices: where zeros stop (just before a nonzero)
276+
end_indices = np.where(diff == -1)[0]
277+
278+
# Handle edge cases if the array starts or ends with zeros
279+
if zero_mask[0]: # If the first element is zero, add index 0 to start_indices
280+
start_indices = np.insert(start_indices, 0, 0)
281+
if zero_mask[-1]: # If the last element is zero, add the last index to end_indices
282+
end_indices = np.append(end_indices, len(vals) - 1)
283+
284+
# Create pairs of (first zero, last zero)
285+
intervals = [Interval(int(start), int(end)) for start, end in zip(start_indices, end_indices)]
286+
287+
# Remove intervals where a cell is only free for one time step. Those intervals not provide enough time to
288+
# move into and out of the cell each take 1 time step, and the cell is considered occupied during
289+
# both the time step when it is entering the cell, and the time step when it is leaving the cell.
290+
intervals = [interval for interval in intervals if interval.start_time != interval.end_time]
291+
return intervals
237292

238293
show_animation = True
239294

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
"""
2+
Safe interval path planner
3+
This script implements a safe-interval path planner for a 2d grid with dynamic obstacles. It is faster than
4+
SpaceTime A* because it reduces the number of redundant node expansions by pre-computing regions of adjacent
5+
time steps that are safe ("safe intervals") at each position. This allows the algorithm to skip expanding nodes
6+
that are in intervals that have already been visited earlier.
7+
8+
Reference: https://www.cs.cmu.edu/~maxim/files/sipp_icra11.pdf
9+
"""
10+
11+
import numpy as np
12+
import matplotlib.pyplot as plt
13+
from PathPlanning.TimeBasedPathPlanning.GridWithDynamicObstacles import (
14+
Grid,
15+
Interval,
16+
ObstacleArrangement,
17+
Position,
18+
empty_2d_array_of_lists,
19+
)
20+
import heapq
21+
import random
22+
from dataclasses import dataclass
23+
from functools import total_ordering
24+
import time
25+
26+
@dataclass()
27+
# Note: Total_ordering is used instead of adding `order=True` to the @dataclass decorator because
28+
# this class needs to override the __lt__ and __eq__ methods to ignore parent_index. The Parent
29+
# index and interval member variables are just used to track the path found by the algorithm,
30+
# and has no effect on the quality of a node.
31+
@total_ordering
32+
class Node:
33+
position: Position
34+
time: int
35+
heuristic: int
36+
parent_index: int
37+
interval: Interval
38+
39+
"""
40+
This is what is used to drive node expansion. The node with the lowest value is expanded next.
41+
This comparison prioritizes the node with the lowest cost-to-come (self.time) + cost-to-go (self.heuristic)
42+
"""
43+
def __lt__(self, other: object):
44+
if not isinstance(other, Node):
45+
return NotImplementedError(f"Cannot compare Node with object of type: {type(other)}")
46+
return (self.time + self.heuristic) < (other.time + other.heuristic)
47+
48+
"""
49+
Equality only cares about position and time. Heuristic and interval will always be the same for a given
50+
(position, time) pairing, so they are not considered in equality.
51+
"""
52+
def __eq__(self, other: object):
53+
if not isinstance(other, Node):
54+
return NotImplemented
55+
return self.position == other.position and self.time == other.time
56+
57+
@dataclass
58+
class EntryTimeAndInterval:
59+
entry_time: int
60+
interval: Interval
61+
62+
class NodePath:
63+
path: list[Node]
64+
positions_at_time: dict[int, Position] = {}
65+
66+
def __init__(self, path: list[Node]):
67+
self.path = path
68+
for (i, node) in enumerate(path):
69+
if i > 0:
70+
# account for waiting in interval at previous node
71+
prev_node = path[i-1]
72+
for t in range(prev_node.time, node.time):
73+
self.positions_at_time[t] = prev_node.position
74+
75+
self.positions_at_time[node.time] = node.position
76+
77+
"""
78+
Get the position of the path at a given time
79+
"""
80+
def get_position(self, time: int) -> Position | None:
81+
return self.positions_at_time.get(time)
82+
83+
"""
84+
Time stamp of the last node in the path
85+
"""
86+
def goal_reached_time(self) -> int:
87+
return self.path[-1].time
88+
89+
def __repr__(self):
90+
repr_string = ""
91+
for i, node in enumerate(self.path):
92+
repr_string += f"{i}: {node}\n"
93+
return repr_string
94+
95+
96+
class SafeIntervalPathPlanner:
97+
grid: Grid
98+
start: Position
99+
goal: Position
100+
101+
def __init__(self, grid: Grid, start: Position, goal: Position):
102+
self.grid = grid
103+
self.start = start
104+
self.goal = goal
105+
106+
# Seed randomness for reproducibility
107+
RANDOM_SEED = 50
108+
random.seed(RANDOM_SEED)
109+
np.random.seed(RANDOM_SEED)
110+
111+
"""
112+
Generate a plan given the loaded problem statement. Raises an exception if it fails to find a path.
113+
Arguments:
114+
verbose (bool): set to True to print debug information
115+
"""
116+
def plan(self, verbose: bool = False) -> NodePath:
117+
118+
safe_intervals = self.grid.get_safe_intervals()
119+
120+
open_set: list[Node] = []
121+
first_node_interval = safe_intervals[self.start.x, self.start.y][0]
122+
heapq.heappush(
123+
open_set, Node(self.start, 0, self.calculate_heuristic(self.start), -1, first_node_interval)
124+
)
125+
126+
expanded_list: list[Node] = []
127+
visited_intervals = empty_2d_array_of_lists(self.grid.grid_size[0], self.grid.grid_size[1])
128+
while open_set:
129+
expanded_node: Node = heapq.heappop(open_set)
130+
if verbose:
131+
print("Expanded node:", expanded_node)
132+
133+
if expanded_node.time + 1 >= self.grid.time_limit:
134+
if verbose:
135+
print(f"\tSkipping node that is past time limit: {expanded_node}")
136+
continue
137+
138+
if expanded_node.position == self.goal:
139+
print(f"Found path to goal after {len(expanded_list)} expansions")
140+
path = []
141+
path_walker: Node = expanded_node
142+
while True:
143+
path.append(path_walker)
144+
if path_walker.parent_index == -1:
145+
break
146+
path_walker = expanded_list[path_walker.parent_index]
147+
148+
# reverse path so it goes start -> goal
149+
path.reverse()
150+
return NodePath(path)
151+
152+
expanded_idx = len(expanded_list)
153+
expanded_list.append(expanded_node)
154+
entry_time_and_node = EntryTimeAndInterval(expanded_node.time, expanded_node.interval)
155+
add_entry_to_visited_intervals_array(entry_time_and_node, visited_intervals, expanded_node)
156+
157+
for child in self.generate_successors(expanded_node, expanded_idx, safe_intervals, visited_intervals):
158+
heapq.heappush(open_set, child)
159+
160+
raise Exception("No path found")
161+
162+
"""
163+
Generate list of possible successors of the provided `parent_node` that are worth expanding
164+
"""
165+
def generate_successors(
166+
self, parent_node: Node, parent_node_idx: int, intervals: np.ndarray, visited_intervals: np.ndarray
167+
) -> list[Node]:
168+
new_nodes = []
169+
diffs = [
170+
Position(0, 0),
171+
Position(1, 0),
172+
Position(-1, 0),
173+
Position(0, 1),
174+
Position(0, -1),
175+
]
176+
for diff in diffs:
177+
new_pos = parent_node.position + diff
178+
if not self.grid.inside_grid_bounds(new_pos):
179+
continue
180+
181+
current_interval = parent_node.interval
182+
183+
new_cell_intervals: list[Interval] = intervals[new_pos.x, new_pos.y]
184+
for interval in new_cell_intervals:
185+
# if interval starts after current ends, break
186+
# assumption: intervals are sorted by start time, so all future intervals will hit this condition as well
187+
if interval.start_time > current_interval.end_time:
188+
break
189+
190+
# if interval ends before current starts, skip
191+
if interval.end_time < current_interval.start_time:
192+
continue
193+
194+
# if we have already expanded a node in this interval with a <= starting time, skip
195+
better_node_expanded = False
196+
for visited in visited_intervals[new_pos.x, new_pos.y]:
197+
if interval == visited.interval and visited.entry_time <= parent_node.time + 1:
198+
better_node_expanded = True
199+
break
200+
if better_node_expanded:
201+
continue
202+
203+
# We know there is a node worth expanding. Generate successor at the earliest possible time the
204+
# new interval can be entered
205+
for possible_t in range(max(parent_node.time + 1, interval.start_time), min(current_interval.end_time, interval.end_time)):
206+
if self.grid.valid_position(new_pos, possible_t):
207+
new_nodes.append(Node(
208+
new_pos,
209+
# entry is max of interval start and parent node time + 1 (get there as soon as possible)
210+
max(interval.start_time, parent_node.time + 1),
211+
self.calculate_heuristic(new_pos),
212+
parent_node_idx,
213+
interval,
214+
))
215+
# break because all t's after this will make nodes with a higher cost, the same heuristic, and are in the same interval
216+
break
217+
218+
return new_nodes
219+
220+
"""
221+
Calculate the heuristic for a given position - Manhattan distance to the goal
222+
"""
223+
def calculate_heuristic(self, position) -> int:
224+
diff = self.goal - position
225+
return abs(diff.x) + abs(diff.y)
226+
227+
228+
"""
229+
Adds a new entry to the visited intervals array. If the entry is already present, the entry time is updated if the new
230+
entry time is better. Otherwise, the entry is added to `visited_intervals` at the position of `expanded_node`.
231+
"""
232+
def add_entry_to_visited_intervals_array(entry_time_and_interval: EntryTimeAndInterval, visited_intervals: np.ndarray, expanded_node: Node):
233+
# if entry is present, update entry time if better
234+
for existing_entry_and_interval in visited_intervals[expanded_node.position.x, expanded_node.position.y]:
235+
if existing_entry_and_interval.interval == entry_time_and_interval.interval:
236+
existing_entry_and_interval.entry_time = min(existing_entry_and_interval.entry_time, entry_time_and_interval.entry_time)
237+
238+
# Otherwise, append
239+
visited_intervals[expanded_node.position.x, expanded_node.position.y].append(entry_time_and_interval)
240+
241+
242+
show_animation = True
243+
verbose = False
244+
245+
def main():
246+
start = Position(1, 18)
247+
goal = Position(19, 19)
248+
grid_side_length = 21
249+
250+
start_time = time.time()
251+
252+
grid = Grid(
253+
np.array([grid_side_length, grid_side_length]),
254+
num_obstacles=250,
255+
obstacle_avoid_points=[start, goal],
256+
obstacle_arrangement=ObstacleArrangement.ARRANGEMENT1,
257+
# obstacle_arrangement=ObstacleArrangement.RANDOM,
258+
)
259+
260+
planner = SafeIntervalPathPlanner(grid, start, goal)
261+
path = planner.plan(verbose)
262+
runtime = time.time() - start_time
263+
print(f"Planning took: {runtime:.5f} seconds")
264+
265+
if verbose:
266+
print(f"Path: {path}")
267+
268+
if not show_animation:
269+
return
270+
271+
fig = plt.figure(figsize=(10, 7))
272+
ax = fig.add_subplot(
273+
autoscale_on=False,
274+
xlim=(0, grid.grid_size[0] - 1),
275+
ylim=(0, grid.grid_size[1] - 1),
276+
)
277+
ax.set_aspect("equal")
278+
ax.grid()
279+
ax.set_xticks(np.arange(0, grid_side_length, 1))
280+
ax.set_yticks(np.arange(0, grid_side_length, 1))
281+
282+
(start_and_goal,) = ax.plot([], [], "mD", ms=15, label="Start and Goal")
283+
start_and_goal.set_data([start.x, goal.x], [start.y, goal.y])
284+
(obs_points,) = ax.plot([], [], "ro", ms=15, label="Obstacles")
285+
(path_points,) = ax.plot([], [], "bo", ms=10, label="Path Found")
286+
ax.legend(bbox_to_anchor=(1.05, 1))
287+
288+
# for stopping simulation with the esc key.
289+
plt.gcf().canvas.mpl_connect(
290+
"key_release_event", lambda event: [exit(0) if event.key == "escape" else None]
291+
)
292+
293+
for i in range(0, path.goal_reached_time() + 1):
294+
obs_positions = grid.get_obstacle_positions_at_time(i)
295+
obs_points.set_data(obs_positions[0], obs_positions[1])
296+
path_position = path.get_position(i)
297+
path_points.set_data([path_position.x], [path_position.y])
298+
plt.pause(0.2)
299+
plt.show()
300+
301+
302+
if __name__ == "__main__":
303+
main()

0 commit comments

Comments
 (0)