Skip to content

Commit e54efb2

Browse files
committed
adding translate transform for point clouds
1 parent c3ec548 commit e54efb2

File tree

2 files changed

+155
-0
lines changed

2 files changed

+155
-0
lines changed

kaolin/transforms/pointcloudfunc.py

+97
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,53 @@
2525
EPS = 1e-6
2626

2727

28+
def shift(cloud: Union[torch.Tensor, PointCloud],
29+
shf: Union[float, int, torch.Tensor],
30+
inplace: Optional[bool] = True):
31+
"""Shift the input pointcloud by a shift factor.
32+
33+
Args:
34+
cloud (torch.Tensor or kaolin.rep.PointCloud): pointcloud (ndims >= 2).
35+
shf (float, int, torch.Tensor): shift factor (scaler, or tensor).
36+
inplace (bool, optional): Bool to make the transform in-place
37+
38+
Returns:
39+
(torch.Tensor): shifted pointcloud pf the same shape as input.
40+
41+
Shape:
42+
- cloud: :math:`(B x N x D)` (or) :math:`(N x D)`, where :math:`(B)`
43+
is the batchsize, :math:`(N)` is the number of points per cloud,
44+
and :math:`(D)` is the dimensionality of each cloud.
45+
- shf: :math:`(1)` or :math:`(B)`.
46+
47+
Example:
48+
>>> points = torch.rand(1000,3)
49+
>>> points2 = shift(points, torch.FloatTensor([3]))
50+
"""
51+
52+
if isinstance(cloud, np.ndarray):
53+
cloud = torch.from_numpy(cloud)
54+
55+
if isinstance(shf, np.ndarray):
56+
shf = torch.from_numpy(shf)
57+
58+
if isinstance(cloud, PointCloud):
59+
cloud = cloud.points
60+
61+
if isinstance(shf, int) or isinstance(shf, float):
62+
shf = torch.Tensor([shf]).to(cloud.device)
63+
64+
helpers._assert_tensor(cloud)
65+
helpers._assert_tensor(shf)
66+
helpers._assert_dim_ge(cloud, 2)
67+
helpers._assert_gt(shf, 0.)
68+
69+
if not inplace:
70+
cloud = cloud.clone()
71+
72+
return shf + cloud
73+
74+
2875
def scale(cloud: Union[torch.Tensor, PointCloud],
2976
scf: Union[float, int, torch.Tensor],
3077
inplace: Optional[bool] = True):
@@ -74,6 +121,56 @@ def scale(cloud: Union[torch.Tensor, PointCloud],
74121
return scf * cloud
75122

76123

124+
def translate(cloud: Union[torch.Tensor, PointCloud], tranmat: torch.Tensor,
125+
inplace: Optional[bool] = True):
126+
"""Translate the input pointcloud by a translation matrix.
127+
128+
Args:
129+
cloud (Tensor or np.array): pointcloud (ndims = 2 or 3)
130+
tranmat (Tensor or np.array): translation matrix (1 x 3, 1 per cloud).
131+
132+
Returns:
133+
cloud_tran (Tensor): Translated pointcloud of the same shape as input.
134+
135+
Shape:
136+
- cloud: :math:`(B x N x 3)` (or) :math:`(N x 3)`, where :math:`(B)`
137+
is the batchsize, :math:`(N)` is the number of points per cloud,
138+
and :math:`(3)` is the dimensionality of each cloud.
139+
- tranmat: :math:`(1, 3)` or :math:`(B, 1, 3)`.
140+
141+
Example:
142+
>>> points = torch.rand(1000,3)
143+
>>> t_mat = torch.rand(1,3)
144+
>>> points2 = translate(points, t_mat)
145+
146+
"""
147+
if isinstance(cloud, np.ndarray):
148+
cloud = torch.from_numpy(cloud)
149+
if isinstance(cloud, PointCloud):
150+
cloud = cloud.points
151+
if isinstance(tranmat, np.ndarray):
152+
trainmat = torch.from_numpy(tranmat)
153+
154+
helpers._assert_tensor(cloud)
155+
helpers._assert_tensor(tranmat)
156+
helpers._assert_dim_ge(cloud, 2)
157+
helpers._assert_dim_ge(tranmat, 2)
158+
# Rotation matrix must have last two dimensions of shape 3.
159+
helpers._assert_shape_eq(tranmat, (1, 3), dim=-1)
160+
helpers._assert_shape_eq(tranmat, (1, 3), dim=-2)
161+
162+
if not inplace:
163+
cloud = cloud.clone()
164+
165+
if tranmat.dim() == 2 and cloud.dim() == 2:
166+
cloud = torch.add(tranmat, cloud)
167+
else:
168+
if tranmat.dim() == 2:
169+
tranmat = tranmat.expand(cloud.shape[0], 1, 3)
170+
cloud = torch.add(tranmat, cloud)
171+
172+
return cloud
173+
77174
def rotate(cloud: Union[torch.Tensor, PointCloud], rotmat: torch.Tensor,
78175
inplace: Optional[bool] = True):
79176
"""Rotates the the input pointcloud by a rotation matrix.

kaolin/transforms/transforms.py

+58
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,35 @@ def __call__(self, arr: np.ndarray):
198198
return torch.from_numpy(arr)
199199

200200

201+
class ShiftPointCloud(Transform):
202+
r"""Shift a pointcloud with respect a fixed shift factor.
203+
Given a shift factor `shf`, this transform will shift each point in the
204+
pointcloud, i.e.,
205+
``cloud = shf + cloud``
206+
207+
Args:
208+
shf (int or float or torch.Tensor): Shift pofactorint by which input
209+
clouds are to be shifted.
210+
inplace (bool, optional): Whether or not the transformation should be
211+
in-place (default: True).
212+
"""
213+
214+
def __init__(self, shf: Union[int, float, torch.Tensor],
215+
inplace: Optional[bool] = True):
216+
self.shf = shf
217+
self.inplace = inplace
218+
219+
def __call__(self, cloud: Union[torch.Tensor, PointCloud]):
220+
"""
221+
Args:
222+
cloud (torch.Tensor or PointCloud): Pointcloud to be shifted.
223+
224+
Returns:
225+
(torch.Tensor or PointCloud): Shifted pointcloud.
226+
"""
227+
return pcfunc.shift(cloud, shf=self.shf, inplace=self.inplace)
228+
229+
201230
class ScalePointCloud(Transform):
202231
"""Scale a pointcloud with a fixed scaling factor.
203232
Given a scale factor `scf`, this transform will scale each point in the
@@ -231,6 +260,35 @@ def __call__(self, cloud: Union[torch.Tensor, PointCloud]):
231260
return pcfunc.scale(cloud, scf=self.scf, inplace=self.inplace)
232261

233262

263+
class TranslatePointCloud(Transform):
264+
r"""Translate a pointcloud with a given translation matrix.
265+
Given a :math:`1 \times 3` translation matrix, this transform will
266+
translate each point in the cloud by the translation matrix specified.
267+
268+
Args:
269+
tranmat (torch.Tensor): Translation matrix that specifies the translation
270+
to be applied to the pointcloud (shape: :math:`1 \times 3`).
271+
inplace (bool, optional): Bool to make this operation in-place.
272+
273+
TODO: Example.
274+
275+
"""
276+
277+
def __init__(self, tranmat: torch.Tensor, inplace: Optional[bool] = True):
278+
self.tranmat = tranmat
279+
self.inplace = inplace
280+
281+
def __call__(self, cloud: Union[torch.Tensor, PointCloud]):
282+
"""
283+
Args:
284+
cloud (torch.Tensor or PointCloud): Input pointcloud to be translated.
285+
286+
Returns:
287+
(torch.Tensor or PointCloud): Translated pointcloud.
288+
"""
289+
return pcfunc.translate(cloud, tranmat=self.tranmat, inplace=self.inplace)
290+
291+
234292
class RotatePointCloud(Transform):
235293
r"""Rotate a pointcloud with a given rotation matrix.
236294
Given a :math:`3 \times 3` rotation matrix, this transform will rotate each

0 commit comments

Comments
 (0)