diff --git a/deel/torchlip/functional.py b/deel/torchlip/functional.py
index 5a362da..9143cdf 100644
--- a/deel/torchlip/functional.py
+++ b/deel/torchlip/functional.py
@@ -212,7 +212,9 @@ def max_min(input: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
     return torch.cat((F.relu(input), F.relu(-input)), dim=dim)
 
 
-def group_sort(input: torch.Tensor, group_size: Optional[int] = None, dim : int = 1) -> torch.Tensor:
+def group_sort(
+    input: torch.Tensor, group_size: Optional[int] = None, dim: int = 1
+) -> torch.Tensor:
     r"""
     Applies GroupSort activation on the given tensor.
 
@@ -220,21 +222,28 @@ def group_sort(input: torch.Tensor, group_size: Optional[int] = None, dim : int
         :py:func:`group_sort_2`
         :py:func:`full_sort`
     """
-
+
     if group_size is None or group_size > input.shape[dim]:
         group_size = input.shape[dim]
 
     if input.shape[dim] % group_size != 0:
         raise ValueError("The input size must be a multiple of the group size.")
 
-    new_shape = input.shape[:dim]+(input.shape[dim]//group_size,group_size)+input.shape[dim+1:]
+    new_shape = (
+        input.shape[:dim]
+        + (input.shape[dim] // group_size, group_size)
+        + input.shape[dim + 1 :]
+    )
     if group_size == 2:
         resh_input = input.view(new_shape)
-        a, b = torch.min(resh_input, dim+1,keepdim=True)[0], torch.max(resh_input, dim+1,keepdim=True)[0]
-        return torch.cat([a, b], dim=dim+1).view(input.shape)
+        a, b = (
+            torch.min(resh_input, dim + 1, keepdim=True)[0],
+            torch.max(resh_input, dim + 1, keepdim=True)[0],
+        )
+        return torch.cat([a, b], dim=dim + 1).view(input.shape)
 
     fv = input.reshape(new_shape)
 
-    return torch.sort(fv,dim=dim+1)[0].reshape(input.shape)
+    return torch.sort(fv, dim=dim + 1)[0].reshape(input.shape)
 
 
 def group_sort_2(input: torch.Tensor) -> torch.Tensor:
diff --git a/tests/test_activations.py b/tests/test_activations.py
index 568e053..309fe28 100644
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -131,7 +131,7 @@ def test_GroupSort(group_size, img, expected):
     xnp = np.repeat(
         np.expand_dims(np.repeat(np.expand_dims(xn, -1), 28, -1), -1), 28, -1
     )
-    xnp = uft.to_NCHW_inv(xnp) # move channel if needed (TF)
+    xnp = uft.to_NCHW_inv(xnp)  # move channel if needed (TF)
     x = uft.to_tensor(xnp)
     uft.build_layer(gs, (28, 28, 4))
     y = gs(x).numpy()
@@ -141,11 +141,12 @@ def test_GroupSort(group_size, img, expected):
     y_t = np.repeat(
         np.expand_dims(np.repeat(np.expand_dims(y_tnp, -1), 28, -1), -1), 28, -1
     )
-    y_t = uft.to_NCHW_inv(y_t) # move channel if needed (TF)
+    y_t = uft.to_NCHW_inv(y_t)  # move channel if needed (TF)
     # print("aaa",y_t.shape, y_t)
     # print("aaab",y.shape, y)
     np.testing.assert_equal(y, y_t)
 
+
 @pytest.mark.parametrize("group_size", [2, 4])
 def test_GroupSort_idempotence(group_size):
     gs = uft.get_instance_framework(GroupSort, {"group_size": group_size})
diff --git a/tests/test_unconstrained_layers.py b/tests/test_unconstrained_layers.py
index 69b829d..f5834c3 100644
--- a/tests/test_unconstrained_layers.py
+++ b/tests/test_unconstrained_layers.py
@@ -37,8 +37,8 @@ def compare(x, x_ref, index_x=[], index_x_ref=[]):
     """Compare a tensor and its padded version, based on index_x and ref."""
-    x = uft.to_numpy(uft.to_NCHW(x))
-    x_ref = uft.to_numpy(uft.to_NCHW(x_ref))
+    x = uft.to_NCHW(uft.to_numpy(x))
+    x_ref = uft.to_NCHW(uft.to_numpy(x_ref))
     x_cropped = x[:, :, index_x[0] : index_x[1], index_x[3] : index_x[4]][
         :, :, :: index_x[2], :: index_x[5]
     ]
diff --git a/tests/utils_framework.py b/tests/utils_framework.py
index cc33ebd..3107b06 100644
--- a/tests/utils_framework.py
+++ b/tests/utils_framework.py
@@ -527,9 +527,11 @@ def to_framework_channel(x):
 def to_NCHW(x):
     return x
 
+
 def to_NCHW_inv(x):
     return x
 
+
 def get_NCHW(x):
     return (x.shape[0], x.shape[1], x.shape[2], x.shape[3])
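
For reference, a minimal usage sketch of the reformatted group_sort, assuming only the signature shown in this diff (the sample tensor values below are purely illustrative):

    import torch

    from deel.torchlip.functional import group_sort

    x = torch.tensor([[4.0, 1.0, 3.0, 2.0]])

    # group_size=2 sorts each consecutive pair along dim=1 independently:
    # (4, 1) -> (1, 4) and (3, 2) -> (2, 3)
    print(group_sort(x, group_size=2))  # tensor([[1., 4., 2., 3.]])

    # group_size=None (default) falls back to sorting the whole dimension:
    print(group_sort(x))  # tensor([[1., 2., 3., 4.]])

The refactor only reflows the code to the project's formatting style; the behaviour sketched above is unchanged.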