Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Translate transform for pointclouds #255

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions kaolin/transforms/pointcloudfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,52 @@
EPS = 1e-6


def shift(cloud: Union[torch.Tensor, PointCloud],
shf: Union[float, int, torch.Tensor],
inplace: Optional[bool] = True):
"""Shift the input pointcloud by a shift factor.

Args:
cloud (torch.Tensor or kaolin.rep.PointCloud): pointcloud (ndims >= 2).
shf (float, int, torch.Tensor): shift factor (scaler, or tensor).
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved
inplace (bool, optional): Bool to make the transform in-place

Returns:
(torch.Tensor): shifted pointcloud pf the same shape as input.
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved

Shape:
- cloud: :math:`(B x N x D)` (or) :math:`(N x D)`, where :math:`(B)`
is the batchsize, :math:`(N)` is the number of points per cloud,
and :math:`(D)` is the dimensionality of each cloud.
- shf: :math:`(1)` or :math:`(B)`.

Example:
>>> points = torch.rand(1000,3)
>>> points2 = shift(points, torch.FloatTensor([3]))
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved
"""

if isinstance(cloud, np.ndarray):
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved
cloud = torch.from_numpy(cloud)

if isinstance(shf, np.ndarray):
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved
shf = torch.from_numpy(shf)

if isinstance(cloud, PointCloud):
cloud = cloud.points

if isinstance(shf, int) or isinstance(shf, float):
shf = torch.Tensor([shf]).to(cloud.device)
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved

helpers._assert_tensor(cloud)
helpers._assert_tensor(shf)
helpers._assert_dim_ge(cloud, 2)
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved

if not inplace:
cloud = cloud.clone()
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved

return shf + cloud


def scale(cloud: Union[torch.Tensor, PointCloud],
scf: Union[float, int, torch.Tensor],
inplace: Optional[bool] = True):
Expand Down Expand Up @@ -74,6 +120,56 @@ def scale(cloud: Union[torch.Tensor, PointCloud],
return scf * cloud


def translate(cloud: Union[torch.Tensor, PointCloud], tranmat: torch.Tensor,
inplace: Optional[bool] = True):
"""Translate the input pointcloud by a translation matrix.

Args:
cloud (Tensor or np.array): pointcloud (ndims = 2 or 3)
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved
tranmat (Tensor or np.array): translation matrix (1 x 3, 1 per cloud).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not consistent with the type hints (torch.Tensor).
Also torch.Tensor should be preferred over Tensor.

Copy link
Collaborator

@Caenorst Caenorst May 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unresolving: the documentation still indicate that tranmat can be torch.Tensor or np.array while the type hints is indicating torch.Tensor only

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh missed it. I'll update in the next commit.


Returns:
cloud_tran (Tensor): Translated pointcloud of the same shape as input.
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved

Shape:
- cloud: :math:`(B x N x 3)` (or) :math:`(N x 3)`, where :math:`(B)`
is the batchsize, :math:`(N)` is the number of points per cloud,
and :math:`(3)` is the dimensionality of each cloud.
- tranmat: :math:`(1, 3)` or :math:`(B, 1, 3)`.

Example:
>>> points = torch.rand(1000,3)
>>> t_mat = torch.rand(1,3)
>>> points2 = translate(points, t_mat)
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved

"""
if isinstance(cloud, np.ndarray):
cloud = torch.from_numpy(cloud)
if isinstance(cloud, PointCloud):
Caenorst marked this conversation as resolved.
Show resolved Hide resolved
cloud = cloud.points
if isinstance(tranmat, np.ndarray):
trainmat = torch.from_numpy(tranmat)

helpers._assert_tensor(cloud)
helpers._assert_tensor(tranmat)
helpers._assert_dim_ge(cloud, 2)
helpers._assert_dim_ge(tranmat, 2)
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved
# Rotation matrix must have last two dimensions of shape 3.
helpers._assert_shape_eq(tranmat, (1, 3), dim=-1)
helpers._assert_shape_eq(tranmat, (1, 3), dim=-2)

if not inplace:
Caenorst marked this conversation as resolved.
Show resolved Hide resolved
cloud = cloud.clone()

if tranmat.dim() == 2 and cloud.dim() == 2:
cloud = torch.add(tranmat, cloud)
else:
if tranmat.dim() == 2:
Caenorst marked this conversation as resolved.
Show resolved Hide resolved
tranmat = tranmat.expand(cloud.shape[0], 1, 3)
cloud = torch.add(tranmat, cloud)

return cloud

def rotate(cloud: Union[torch.Tensor, PointCloud], rotmat: torch.Tensor,
inplace: Optional[bool] = True):
"""Rotates the the input pointcloud by a rotation matrix.
Expand Down
58 changes: 58 additions & 0 deletions kaolin/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,35 @@ def __call__(self, arr: np.ndarray):
return torch.from_numpy(arr)


class ShiftPointCloud(Transform):
r"""Shift a pointcloud with respect a fixed shift factor.
Given a shift factor `shf`, this transform will shift each point in the
pointcloud, i.e.,
``cloud = shf + cloud``

Args:
shf (int or float or torch.Tensor): Shift pofactorint by which input
clouds are to be shifted.
inplace (bool, optional): Whether or not the transformation should be
in-place (default: True).
"""

def __init__(self, shf: Union[int, float, torch.Tensor],
inplace: Optional[bool] = True):
self.shf = shf
self.inplace = inplace

def __call__(self, cloud: Union[torch.Tensor, PointCloud]):
"""
Args:
cloud (torch.Tensor or PointCloud): Pointcloud to be shifted.

Returns:
(torch.Tensor or PointCloud): Shifted pointcloud.
"""
return pcfunc.shift(cloud, shf=self.shf, inplace=self.inplace)


class ScalePointCloud(Transform):
"""Scale a pointcloud with a fixed scaling factor.
Given a scale factor `scf`, this transform will scale each point in the
Expand Down Expand Up @@ -231,6 +260,35 @@ def __call__(self, cloud: Union[torch.Tensor, PointCloud]):
return pcfunc.scale(cloud, scf=self.scf, inplace=self.inplace)


class TranslatePointCloud(Transform):
r"""Translate a pointcloud with a given translation matrix.
Given a :math:`1 \times 3` translation matrix, this transform will
translate each point in the cloud by the translation matrix specified.

Args:
tranmat (torch.Tensor): Translation matrix that specifies the translation
to be applied to the pointcloud (shape: :math:`1 \times 3`).
inplace (bool, optional): Bool to make this operation in-place.

TODO: Example.
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved

"""

def __init__(self, tranmat: torch.Tensor, inplace: Optional[bool] = True):
self.tranmat = tranmat
self.inplace = inplace

def __call__(self, cloud: Union[torch.Tensor, PointCloud]):
"""
Args:
cloud (torch.Tensor or PointCloud): Input pointcloud to be translated.

Returns:
(torch.Tensor or PointCloud): Translated pointcloud.
"""
return pcfunc.translate(cloud, tranmat=self.tranmat, inplace=self.inplace)


class RotatePointCloud(Transform):
r"""Rotate a pointcloud with a given rotation matrix.
Given a :math:`3 \times 3` rotation matrix, this transform will rotate each
Expand Down
15 changes: 15 additions & 0 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ def test_numpy_to_tensor(device='cpu'):
assert torch.is_tensor(ten)


def test_shift_pointcloud(device='cpu'):
unit_shift = kal.transforms.ShiftPointCloud(-1)
pc = torch.ones(4, 3)
pc_ = unit_shift(pc)
assert_allclose(pc_, torch.zeros(4, 3))


def test_scale_pointcloud(device='cpu'):
twice = kal.transforms.ScalePointCloud(2)
halve = kal.transforms.ScalePointCloud(0.5)
Expand All @@ -38,6 +45,14 @@ def test_scale_pointcloud(device='cpu'):
assert_allclose(pc_, torch.ones(4, 3))


def test_translate_pointcloud(device='cpu'):
pushkalkatara marked this conversation as resolved.
Show resolved Hide resolved
pc = torch.ones(4, 3)
tmat = torch.tensor([[-1.0,-1.0,-1.0]])
translate = kal.transforms.TranslatePointCloud(tmat)
pc_ = translate(pc)
assert_allclose(pc_, torch.zeros(4, 3))


def test_rotate_pointcloud(device='cpu'):
pc = torch.ones(4, 3)
rmat = 2 * torch.eye(3)
Expand Down