Skip to content

Commit c2b8cea

Browse files
Chore/clean files (#20)
* remove and delete some unused files Signed-off-by: Daniel <[email protected]> * remove more unused files Signed-off-by: Daniel <[email protected]> * delete more files Signed-off-by: Daniel <[email protected]> * delete more awml pred files Signed-off-by: Daniel <[email protected]> --------- Signed-off-by: Daniel <[email protected]>
1 parent aa4a075 commit c2b8cea

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+70
-7210
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Overview
44

55
The `autoware_mtr_python` package is part of the Autoware Universe project, specifically within the perception module. This package provides functionalities for the MTR (Motion Transformer) node implemented in `mtr_node.py`. For more information about the MTR model, please refer to the original paper publication: [Motion Transformer with Global Intention Localization and Local Movement Refinement]{https://arxiv.org/abs/2209.13508} by Shi et. al.
6+
Furthermore, this project takes many of its tools from the [AWMLprediction repository]{https://github.com/tier4/AWMLprediction}.
67

78
## Main Functionality
89

autoware_mtr/dataclass/polyline.py

+55-70
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,70 @@
11
from __future__ import annotations
22

3-
from dataclasses import dataclass
4-
from typing import ClassVar
5-
from typing import Final
3+
from typing import TYPE_CHECKING, ClassVar
64

7-
from autoware_mtr.datatype import PolylineLabel
85
import numpy as np
9-
from numpy.typing import NDArray
6+
# from dataclasses import field
7+
from attr import define, field
8+
from typing_extensions import Self
109

11-
__all__ = ("Polyline",)
10+
from awml_pred.datatype import MapType
1211

12+
from awml_pred.dataclass.utils import to_np_f32
1313

14-
# TODO(ktro2828): Type definition
14+
if TYPE_CHECKING:
15+
from awml_pred.typing import NDArray, NDArrayF32
1516

17+
__all__ = ["Polyline"]
1618

17-
@dataclass
19+
20+
@define
1821
class Polyline:
19-
"""
20-
A dataclass of Polyline.
22+
"""A dataclass of Polyline.
2123
2224
Attributes
2325
----------
24-
polyline_type (PolylineType): `PolylineType` instance.
25-
waypoints (NDArray): Waypoints of polyline.
26+
polyline_type (MapType): Type of polyline.
27+
waypoints (NDArrayF32): Waypoints of polyline.
2628
2729
"""
2830

29-
polyline_type: PolylineLabel
30-
waypoints: NDArray
31+
polyline_type: MapType = field()
32+
waypoints: NDArrayF32 = field(converter=to_np_f32)
3133

3234
# NOTE: For the 1DArray indices must be a list.
3335
XYZ_IDX: ClassVar[list[int]] = [0, 1, 2]
3436
XY_IDX: ClassVar[list[int]] = [0, 1]
3537
FULL_DIM3D: ClassVar[int] = 7
3638
FULL_DIM2D: ClassVar[int] = 5
3739

38-
def __post_init__(self) -> None:
39-
if not isinstance(self.waypoints, np.ndarray):
40-
self.waypoints = np.array(self.waypoints, dtype=np.float32)
41-
assert isinstance(self.waypoints, np.ndarray)
42-
min_ndim: Final[int] = 1
43-
point_dim: Final[int] = 3
44-
assert self.waypoints.ndim > min_ndim and self.waypoints.shape[1] == point_dim
45-
assert isinstance(self.polyline_type, PolylineLabel)
40+
@polyline_type.validator
41+
def _check_type(self, attr, value) -> None:
42+
if not isinstance(value, MapType):
43+
raise TypeError(f"Unexpected type of {attr.name}: {type(value)}")
44+
45+
@waypoints.validator
46+
def _check_dim(self, attribute, value) -> None:
47+
if value.ndim < 1 or value.shape[1] != 3:
48+
raise ValueError(f"Unexpected {attribute.name} dimensions.")
49+
50+
@classmethod
51+
def from_dict(cls, data: dict) -> Self:
52+
"""Construct an instance from dict data.
53+
54+
Args:
55+
----
56+
data (dict): Dict data of `Polyline`.
57+
58+
Returns:
59+
-------
60+
Polyline: Constructed instance.
61+
62+
"""
63+
return cls(**data)
4664

4765
@property
4866
def xyz(self) -> NDArray:
49-
"""
50-
Return 3D positions.
67+
"""Return 3D positions.
5168
5269
Returns
5370
-------
@@ -62,8 +79,7 @@ def xyz(self, xyz: NDArray) -> None:
6279

6380
@property
6481
def xy(self) -> NDArray:
65-
"""
66-
Return 2D positions.
82+
"""Return 2D positions.
6783
6884
Returns
6985
-------
@@ -78,8 +94,7 @@ def xy(self, xy: NDArray) -> None:
7894

7995
@property
8096
def dxyz(self) -> NDArray:
81-
"""
82-
Return 3D normalized directions. The first element always becomes (0, 0, 0).
97+
"""Return 3D normalized directions. The first element always becomes (0, 0, 0).
8398
8499
Returns
85100
-------
@@ -89,14 +104,12 @@ def dxyz(self) -> NDArray:
89104
if self.is_empty():
90105
return np.empty((0, 3), dtype=np.float32)
91106
diff = np.diff(self.xyz, axis=0, prepend=self.xyz[0].reshape(-1, 3))
92-
norm = np.linalg.norm(diff, axis=-1, keepdims=True)
93-
zero: Final[float] = 0.0
94-
return np.divide(diff, norm, where=(diff != zero) & (norm != zero))
107+
norm = np.clip(np.linalg.norm(diff, axis=-1, keepdims=True), a_min=1e-6, a_max=1e9)
108+
return np.divide(diff, norm)
95109

96110
@property
97111
def dxy(self) -> NDArray:
98-
"""
99-
Return 2D normalized directions. The first element always becomes (0, 0).
112+
"""Return 2D normalized directions. The first element always becomes (0, 0).
100113
101114
Returns
102115
-------
@@ -106,40 +119,14 @@ def dxy(self) -> NDArray:
106119
if self.is_empty():
107120
return np.empty((0, 2), dtype=np.float32)
108121
diff = np.diff(self.xy, axis=0, prepend=self.xy[0].reshape(-1, 2))
109-
norm = np.linalg.norm(diff, axis=-1, keepdims=True)
110-
zero: Final[float] = 0.0
111-
return np.divide(diff, norm, where=(diff != zero) & (norm != zero))
112-
113-
@property
114-
def type_id(self) -> int:
115-
"""
116-
Return the type ID in `int`.
117-
118-
Returns
119-
-------
120-
int: Type ID.
121-
122-
"""
123-
return self.polyline_type.value
124-
125-
@property
126-
def type_str(self) -> str:
127-
"""
128-
Return the type in `str`.
129-
130-
Returns
131-
-------
132-
str: Type in `str`.
133-
134-
"""
135-
return self.polyline_type.as_str()
122+
norm = np.clip(np.linalg.norm(diff, axis=-1, keepdims=True), a_min=1e-6, a_max=1e9)
123+
return np.divide(diff, norm)
136124

137125
def __len__(self) -> int:
138126
return len(self.waypoints)
139127

140128
def is_empty(self) -> bool:
141-
"""
142-
Indicate whether waypoints is empty array.
129+
"""Indicate whether waypoints is empty array.
143130
144131
Returns
145132
-------
@@ -148,21 +135,19 @@ def is_empty(self) -> bool:
148135
"""
149136
return len(self.waypoints) == 0
150137

151-
def as_array(self, *, full: bool = False, as_3d: bool = True) -> NDArray:
152-
"""
153-
Return the polyline as `NDArray`.
138+
def as_array(self, *, full: bool = False, as_3d: bool = True) -> NDArrayF32:
139+
"""Return the polyline as `NDArray`.
154140
155141
Args:
156142
----
157-
full (bool, optional): Indicates whether to return
158-
`(x, y, z, dx, dy, dz, type_id)`. If `False`, returns `(x, y, z)`.
159-
Defaults to False.
143+
full (bool, optional): Indicates whether to return `(x, y, z, dx, dy, dz, type_id)`.
144+
If `False`, returns `(x, y, z)`. Defaults to False.
160145
as_3d (bool, optional): If `True` returns array containing 3D coordinates.
161146
Otherwise, 2D coordinates. Defaults to True.
162147
163148
Returns:
164149
-------
165-
NDArray: Polyline array.
150+
NDArrayF32: Polyline array.
166151
167152
"""
168153
if full:
@@ -174,7 +159,7 @@ def as_array(self, *, full: bool = False, as_3d: bool = True) -> NDArray:
174159
)
175160

176161
shape = self.waypoints.shape[:-1]
177-
type_id = np.full((*shape, 1), self.type_id)
162+
type_id = np.full((*shape, 1), self.polyline_type.value)
178163
return (
179164
np.concatenate([self.xyz, self.dxyz, type_id], axis=1, dtype=np.float32)
180165
if as_3d
File renamed without changes.

awml_pred/dataclass/__init__.py

-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1 @@
1-
from .agent import * # noqa
2-
from .polyline import * # noqa
3-
from .prediction import * # noqa
4-
from .scenario import * # noqa
5-
from .static_map import * # noqa
61
from .utils import * # noqa

0 commit comments

Comments
 (0)