19
19
20
20
from numpy .typing import NDArray
21
21
from rcl_interfaces .msg import ParameterDescriptor
22
+ from utils .polyline import TargetCentricPolyline
22
23
23
24
from autoware_perception_msgs .msg import PredictedObjects
24
- from awml_pred .dataclass import Trajectory
25
+ from awml_pred .dataclass import AWMLStaticMap , AWMLAgentScenario
25
26
from awml_pred .common import Config , create_logger , get_num_devices , init_dist_pytorch , init_dist_slurm , load_checkpoint
26
27
from awml_pred .models import build_model
27
28
from awml_pred .deploy .apis .torch2onnx import _load_inputs
29
+
30
+ from utils .lanelet_converter import convert_lanelet
28
31
from autoware_mtr .conversion .ego import from_odometry
29
32
from autoware_mtr .conversion .misc import timestamp2ms
30
- from autoware_mtr .conversion .lanelet import convert_lanelet
31
33
from autoware_mtr .conversion .trajectory import get_relative_history
32
34
from autoware_mtr .datatype import AgentLabel
33
35
from autoware_mtr .geometry import rotate_along_z
@@ -146,8 +148,19 @@ def __init__(self) -> None:
146
148
)
147
149
148
150
self ._history = AgentHistory (max_length = num_timestamp )
149
-
150
- self ._lane_segments : list [LaneSegment ] = convert_lanelet (lanelet_file )
151
+ self ._awml_static_map : AWMLStaticMap = convert_lanelet (lanelet_file )
152
+
153
+ num_polylines : int = 768
154
+ num_points : int = 20
155
+ break_distance : float = 1.0
156
+ center_offset : tuple [float , float ] = (0.0 , 0.0 )
157
+
158
+ self ._preprocess_polyline = TargetCentricPolyline (
159
+ num_polylines = num_polylines ,
160
+ num_points = num_points ,
161
+ break_distance = break_distance ,
162
+ center_offset = center_offset ,
163
+ )
151
164
152
165
self ._label_ids = [AgentLabel .from_str (label ).value for label in labels ]
153
166
@@ -187,14 +200,21 @@ def _callback(self, msg: Odometry) -> None:
187
200
)
188
201
# print("current_ego.xyz",current_ego.xyz)
189
202
self ._history .update_state (current_ego , info )
190
-
191
203
dummy_input = _load_inputs (self .deploy_cfg .input_shapes )
192
204
193
205
# pre-process
194
- past_embed = self ._preprocess (self ._history , current_ego , self ._lane_segments )
206
+ past_embed , polyline_info = self ._preprocess (self ._history , current_ego , self ._awml_static_map )
207
+
208
+ print ("polyline info " , polyline_info ["polylines" ].shape )
209
+ print ("polyline mask info " , polyline_info ["polylines_mask" ].shape )
210
+
195
211
if self .count > 11 :
196
212
dummy_input ["obj_trajs" ] = torch .Tensor (past_embed ).cuda ()
197
- print (" dummy_input[obj_trajs]" , dummy_input ["obj_trajs" ].shape )
213
+ print ("Before dummy_input[obj_trajs_last_pos] " , dummy_input ["obj_trajs_last_pos" ].shape )
214
+ dummy_input ["obj_trajs_last_pos" ] = torch .Tensor (current_ego .xyz .reshape ((1 ,1 ,3 ))).cuda ()
215
+ print (" dummy_input[obj_trajs_last_pos]" , dummy_input ["obj_trajs_last_pos" ].shape )
216
+ dummy_input ["map_polylines" ] = torch .Tensor (polyline_info ["polylines" ]).cuda ()
217
+ dummy_input ["map_polylines_mask" ] = torch .Tensor (polyline_info ["polylines_mask" ]).cuda ()
198
218
199
219
if self .count <= 11 :
200
220
self .count = self .count + 1
@@ -351,7 +371,7 @@ def _preprocess(
351
371
self ,
352
372
history : AgentHistory ,
353
373
current_ego : AgentState ,
354
- lane_segments : list [ LaneSegment ] ,
374
+ awml_static_map : AWMLStaticMap ,
355
375
) -> ModelInput :
356
376
"""Run preprocess.
357
377
@@ -371,14 +391,15 @@ def _preprocess(
371
391
# return agent_current_xyz
372
392
373
393
# ego_input = get_current_ego_input(current_ego)
394
+ polyline_info = self ._preprocess_polyline (static_map = self ._awml_static_map ,target_state = current_ego ,num_target = 1 )
374
395
relative_history = get_relative_history (current_ego ,self ._history .histories [self ._ego_uuid ])
375
396
past_embed = self .get_ego_past (relative_history )
376
397
377
398
# print("past_embed", past_embed)
378
399
# print("ego_input", ego_input)
379
400
# print("past_embed shape", past_embed.shape)
380
401
# print("ego_input shape ", ego_input.shape)
381
- return past_embed
402
+ return past_embed , polyline_info
382
403
383
404
384
405
def main (args = None ) -> None :
0 commit comments