Skip to content

Commit cfabf5a

Browse files
read polylines from lanelet map
Signed-off-by: Daniel Sanchez <[email protected]>
1 parent 0b32330 commit cfabf5a

File tree

6 files changed

+582
-13
lines changed

6 files changed

+582
-13
lines changed

autoware_mtr/dataclass/history.py

-3
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def update_state(self, state: AgentState, info: OriginalInfo | None = None) -> N
3737
Args:
3838
state (AgentState): Agent state.
3939
"""
40-
print(f"State being added: {state.xyz}")
4140

4241
uuid = state.uuid
4342

@@ -93,8 +92,6 @@ def is_ancient(latest_timestamp: float, current_timestamp: float, threshold: flo
9392
which means ancient.
9493
"""
9594
timestamp_diff = abs(current_timestamp - latest_timestamp)
96-
print("timestamp_diff")
97-
print("threshold ", threshold)
9895
return timestamp_diff > threshold
9996

10097
def as_trajectory(self, *, latest: bool = False) -> tuple[AgentTrajectory, list[str]]:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
setup(
1010
name=package_name,
1111
version="0.0.0",
12-
packages=find_packages(include=["autoware_mtr*", "autoware_mtr_python*","awml_pred*","projects*"],exclude=["test"]),
12+
packages=find_packages(include=["autoware_mtr*", "autoware_mtr_python*","awml_pred*","projects*","utils"],exclude=["test"]),
1313
data_files=[
1414
(osp.join("share", package_name), ["package.xml"]),
1515
(osp.join("share", package_name, "config"), glob("config/*")),

src/mtr_node.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@
1919

2020
from numpy.typing import NDArray
2121
from rcl_interfaces.msg import ParameterDescriptor
22+
from utils.polyline import TargetCentricPolyline
2223

2324
from autoware_perception_msgs.msg import PredictedObjects
24-
from awml_pred.dataclass import Trajectory
25+
from awml_pred.dataclass import AWMLStaticMap, AWMLAgentScenario
2526
from awml_pred.common import Config, create_logger, get_num_devices, init_dist_pytorch, init_dist_slurm, load_checkpoint
2627
from awml_pred.models import build_model
2728
from awml_pred.deploy.apis.torch2onnx import _load_inputs
29+
30+
from utils.lanelet_converter import convert_lanelet
2831
from autoware_mtr.conversion.ego import from_odometry
2932
from autoware_mtr.conversion.misc import timestamp2ms
30-
from autoware_mtr.conversion.lanelet import convert_lanelet
3133
from autoware_mtr.conversion.trajectory import get_relative_history
3234
from autoware_mtr.datatype import AgentLabel
3335
from autoware_mtr.geometry import rotate_along_z
@@ -146,8 +148,19 @@ def __init__(self) -> None:
146148
)
147149

148150
self._history = AgentHistory(max_length=num_timestamp)
149-
150-
self._lane_segments: list[LaneSegment] = convert_lanelet(lanelet_file)
151+
self._awml_static_map: AWMLStaticMap = convert_lanelet(lanelet_file)
152+
153+
num_polylines: int = 768
154+
num_points: int = 20
155+
break_distance: float = 1.0
156+
center_offset: tuple[float, float] = (0.0, 0.0)
157+
158+
self._preprocess_polyline = TargetCentricPolyline(
159+
num_polylines=num_polylines,
160+
num_points=num_points,
161+
break_distance=break_distance,
162+
center_offset=center_offset,
163+
)
151164

152165
self._label_ids = [AgentLabel.from_str(label).value for label in labels]
153166

@@ -187,14 +200,21 @@ def _callback(self, msg: Odometry) -> None:
187200
)
188201
# print("current_ego.xyz",current_ego.xyz)
189202
self._history.update_state(current_ego, info)
190-
191203
dummy_input = _load_inputs(self.deploy_cfg.input_shapes)
192204

193205
# pre-process
194-
past_embed = self._preprocess(self._history, current_ego, self._lane_segments)
206+
past_embed, polyline_info = self._preprocess(self._history, current_ego, self._awml_static_map)
207+
208+
print("polyline info ", polyline_info["polylines"].shape)
209+
print("polyline mask info ", polyline_info["polylines_mask"].shape)
210+
195211
if self.count > 11:
196212
dummy_input["obj_trajs"] = torch.Tensor(past_embed).cuda()
197-
print(" dummy_input[obj_trajs]", dummy_input["obj_trajs"].shape)
213+
print("Before dummy_input[obj_trajs_last_pos] ", dummy_input["obj_trajs_last_pos"].shape )
214+
dummy_input["obj_trajs_last_pos"] = torch.Tensor(current_ego.xyz.reshape((1,1,3))).cuda()
215+
print(" dummy_input[obj_trajs_last_pos]", dummy_input["obj_trajs_last_pos"].shape)
216+
dummy_input["map_polylines"] = torch.Tensor(polyline_info["polylines"]).cuda()
217+
dummy_input["map_polylines_mask"] = torch.Tensor(polyline_info["polylines_mask"]).cuda()
198218

199219
if self.count <= 11:
200220
self.count = self.count + 1
@@ -351,7 +371,7 @@ def _preprocess(
351371
self,
352372
history: AgentHistory,
353373
current_ego: AgentState,
354-
lane_segments: list[LaneSegment],
374+
awml_static_map: AWMLStaticMap,
355375
) -> ModelInput:
356376
"""Run preprocess.
357377
@@ -371,14 +391,15 @@ def _preprocess(
371391
# return agent_current_xyz
372392

373393
# ego_input = get_current_ego_input(current_ego)
394+
polyline_info = self._preprocess_polyline(static_map=self._awml_static_map,target_state=current_ego,num_target=1)
374395
relative_history = get_relative_history(current_ego,self._history.histories[self._ego_uuid])
375396
past_embed = self.get_ego_past(relative_history)
376397

377398
# print("past_embed", past_embed)
378399
# print("ego_input", ego_input)
379400
# print("past_embed shape", past_embed.shape)
380401
# print("ego_input shape ", ego_input.shape)
381-
return past_embed
402+
return past_embed, polyline_info
382403

383404

384405
def main(args=None) -> None:

utils/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)