From 960d3c992b8ffb1e092d2eb35ec2458f15ca2657 Mon Sep 17 00:00:00 2001 From: Saanidhyavats Date: Mon, 26 Aug 2024 21:48:31 -0400 Subject: [PATCH 1/9] Solves issue #1240 --- python/mlx/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mlx/utils.py b/python/mlx/utils.py index 14b23a41ec..8081620060 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -111,7 +111,7 @@ def tree_map_with_path( return fn(path, tree, *rest) -def tree_flatten(tree, prefix="", is_leaf=None): +def tree_flatten(tree: Any, prefix: str = "", is_leaf: Callable = None) -> Any: """Flattens a Python tree to a list of key, value tuples. The keys are using the dot notation to define trees of arbitrary depth and @@ -155,7 +155,7 @@ def tree_flatten(tree, prefix="", is_leaf=None): return [(prefix[1:], tree)] -def tree_unflatten(tree): +def tree_unflatten(tree: Any) -> Any: """Recreate a Python tree from its flat representation. .. code-block:: python From 3420163a8a127d308af72089a895d49013b50ef4 Mon Sep 17 00:00:00 2001 From: Saanidhyavats Date: Fri, 25 Oct 2024 00:28:41 -0400 Subject: [PATCH 2/9] Added maxpool and pool3d class --- python/mlx/nn/layers/pooling.py | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index 93ae4d8c2e..1908d39f65 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -332,3 +332,38 @@ def __init__( padding: Optional[Union[int, Tuple[int, int]]] = 0, ): super().__init__(mx.mean, 0, kernel_size, stride, padding) + + +class _Pool3d(_Pool): + def __init__( + self, + pooling_function, + padding_value, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Optional[Union[int, Tuple[int, int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int, int]]] = 0, + ): + class_name = type(self).__name__ + msg = "[{}] '{}' must be an integer or a tuple containing 3 integers" + kernel_size = _value_or_list( + kernel_size, 3, msg.format(class_name, "kernel_size") + ) + if stride is not None: + stride = _value_or_list(stride, 3, msg.format(class_name, "stride")) + else: + stride = kernel_size + padding = _value_or_list(padding, 3, msg.format(class_name, "padding")) + padding = [(p, p) for p in padding] + + super().__init__(pooling_function, kernel_size, stride, padding, padding_value) + + +class MaxPool3d(_Pool3d): + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Optional[Union[int, Tuple[int, int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int, int]]] = 0, + ): + super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) From bd636721e00417bd365710b01cf2a82cc72458c0 Mon Sep 17 00:00:00 2001 From: cvnad1 Date: Thu, 24 Oct 2024 21:39:35 -0700 Subject: [PATCH 3/9] avgpooling3d --- python/mlx/nn/layers/pooling.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index 1908d39f65..02ad91f213 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -367,3 +367,14 @@ def __init__( padding: Optional[Union[int, Tuple[int, int, int]]] = 0, ): super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) + + +class AvgPool3d(_Pool3d): + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Optional[Union[int, Tuple[int, int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int, int]]] = 0, + ): + super().__init__(mx.mean, 0, kernel_size, stride, padding) From edec26f538f7d85b6bbc994bd623de8c407048c8 Mon Sep 17 00:00:00 2001 From: Saanidhyavats Date: Fri, 25 Oct 2024 00:28:41 -0400 Subject: [PATCH 4/9] Added maxpool and pool3d class --- python/mlx/nn/layers/pooling.py | 35 ++++++++++++++++++ python/tests/test_nn.py | 63 +++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index 93ae4d8c2e..1908d39f65 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -332,3 +332,38 @@ def __init__( padding: Optional[Union[int, Tuple[int, int]]] = 0, ): super().__init__(mx.mean, 0, kernel_size, stride, padding) + + +class _Pool3d(_Pool): + def __init__( + self, + pooling_function, + padding_value, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Optional[Union[int, Tuple[int, int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int, int]]] = 0, + ): + class_name = type(self).__name__ + msg = "[{}] '{}' must be an integer or a tuple containing 3 integers" + kernel_size = _value_or_list( + kernel_size, 3, msg.format(class_name, "kernel_size") + ) + if stride is not None: + stride = _value_or_list(stride, 3, msg.format(class_name, "stride")) + else: + stride = kernel_size + padding = _value_or_list(padding, 3, msg.format(class_name, "padding")) + padding = [(p, p) for p in padding] + + super().__init__(pooling_function, kernel_size, stride, padding, padding_value) + + +class MaxPool3d(_Pool3d): + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Optional[Union[int, Tuple[int, int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int, int]]] = 0, + ): + super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index 38659625f5..5caa26b6ca 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1569,6 +1569,69 @@ def test_pooling(self): str(nn.AvgPool2d(kernel_size=(1, 2), stride=2, padding=(1, 2))), "AvgPool2d(kernel_size=(1, 2), stride=(2, 2), padding=(1, 2))", ) + # Test 3d pooling + x = mx.array( + [ + [ + [ + [[0, 1, 2], [3, 4, 5], [6, 7, 8]], + [[9, 10, 11], [12, 13, 14], [15, 16, 17]], + [[18, 19, 20], [21, 22, 23], [24, 25, 26]], + ], + [ + [[27, 28, 29], [30, 31, 32], [33, 34, 35]], + [[36, 37, 38], [39, 40, 41], [42, 43, 44]], + [[45, 46, 47], [48, 49, 50], [51, 52, 53]], + ], + ] + ] + ) + expected_max_pool_output_no_padding_stride_1 = [ + [[[[39, 40, 41], [42, 43, 44]], [[48, 49, 50], [51, 52, 53]]]] + ] + + expected_max_pool_output_no_padding_stride_2 = [[[[[39, 40, 41]]]]] + expected_max_pool_output_padding_1 = [ + [ + [[[0, 1, 2], [6, 7, 8]], [[18, 19, 20], [24, 25, 26]]], + [[[27, 28, 29], [33, 34, 35]], [[45, 46, 47], [51, 52, 53]]], + ] + ] + expected_irregular_max_pool_output = [ + [ + [[[9, 10, 11], [12, 13, 14], [15, 16, 17]]], + [[[36, 37, 38], [39, 40, 41], [42, 43, 44]]], + ] + ] + + self.assertTrue( + np.array_equal( + nn.MaxPool3d(kernel_size=2, stride=1, padding=0)(x), + expected_max_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool3d(kernel_size=2, stride=2, padding=0)(x), + expected_max_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool3d(kernel_size=2, stride=2, padding=1)(x), + expected_max_pool_output_padding_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.MaxPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x), + expected_irregular_max_pool_output, + ) + ) + self.assertEqual( + str(nn.MaxPool3d(kernel_size=3, stride=3, padding=2)), + "MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))", + ) def test_set_dtype(self): def assert_dtype(layer, dtype): From a058f641552f24c86613b3c0954f69fd168b677a Mon Sep 17 00:00:00 2001 From: cvnad1 Date: Thu, 24 Oct 2024 21:39:35 -0700 Subject: [PATCH 5/9] avgpooling3d --- python/mlx/nn/layers/pooling.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index 1908d39f65..02ad91f213 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -367,3 +367,14 @@ def __init__( padding: Optional[Union[int, Tuple[int, int, int]]] = 0, ): super().__init__(mx.max, -float("inf"), kernel_size, stride, padding) + + +class AvgPool3d(_Pool3d): + + def __init__( + self, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Optional[Union[int, Tuple[int, int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int, int]]] = 0, + ): + super().__init__(mx.mean, 0, kernel_size, stride, padding) From 9e5a19fa67e07f16d7509ff175543e29e1edf0e3 Mon Sep 17 00:00:00 2001 From: Saanidhyavats Date: Thu, 31 Oct 2024 01:13:21 -0400 Subject: [PATCH 6/9] Documentation --- python/mlx/nn/layers/pooling.py | 84 ++++++++++++++++++++++++--------- 1 file changed, 62 insertions(+), 22 deletions(-) diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index 02ad91f213..6e9bfd6872 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -158,6 +158,30 @@ def __init__( super().__init__(pooling_function, kernel_size, stride, padding, padding_value) +class _Pool3d(_Pool): + def __init__( + self, + pooling_function, + padding_value, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Optional[Union[int, Tuple[int, int, int]]] = None, + padding: Optional[Union[int, Tuple[int, int, int]]] = 0, + ): + class_name = type(self).__name__ + msg = "[{}] '{}' must be an integer or a tuple containing 3 integers" + kernel_size = _value_or_list( + kernel_size, 3, msg.format(class_name, "kernel_size") + ) + if stride is not None: + stride = _value_or_list(stride, 3, msg.format(class_name, "stride")) + else: + stride = kernel_size + padding = _value_or_list(padding, 3, msg.format(class_name, "padding")) + padding = [(p, p) for p in padding] + + super().__init__(pooling_function, kernel_size, stride, padding, padding_value) + + class MaxPool1d(_Pool1d): r"""Applies 1-dimensional max pooling. @@ -334,31 +358,47 @@ def __init__( super().__init__(mx.mean, 0, kernel_size, stride, padding) -class _Pool3d(_Pool): - def __init__( - self, - pooling_function, - padding_value, - kernel_size: Union[int, Tuple[int, int, int]], - stride: Optional[Union[int, Tuple[int, int, int]]] = None, - padding: Optional[Union[int, Tuple[int, int, int]]] = 0, - ): - class_name = type(self).__name__ - msg = "[{}] '{}' must be an integer or a tuple containing 3 integers" - kernel_size = _value_or_list( - kernel_size, 3, msg.format(class_name, "kernel_size") - ) - if stride is not None: - stride = _value_or_list(stride, 3, msg.format(class_name, "stride")) - else: - stride = kernel_size - padding = _value_or_list(padding, 3, msg.format(class_name, "padding")) - padding = [(p, p) for p in padding] +class MaxPool3d(_Pool3d): + """ + Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is + :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out}, + H_{out}, W_{out}, C)`, given by: - super().__init__(pooling_function, kernel_size, stride, padding, padding_value) + .. math:: + \begin{aligned} + \text{out}(N_i, d, h, w, C_j) = & \max_{l=0, \ldots, k_D-1} \max_{m=0, \ldots, k_H-1} \max_{n=0, \ldots, k_W-1} \\ + & \text{input}(N_i, \text{stride[0]} \times d + l, + \text{stride[1]} \times h + m, + \text{stride[2]} \times w + n, C_j), + \end{aligned} + where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, + :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`, + :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`. -class MaxPool3d(_Pool3d): + The parameters ``kernel_size``, ``stride``, ``padding``, can either be: + + - a single ``int`` -- in which case the same value is used for the depth, + height and width axis; + - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used + for the depth axis, the second ``int`` for the height axis, and the third + ``int`` for the width axis. + + Args: + kernel_size (int or tuple(int, int, int)): The size of the pooling window. + stride (int or tuple(int, int, int), optional): The stride of the pooling + window. Default: ``kernel_size``. + padding (int or tuple(int, int, int), optional): How much negative infinity + padding to apply to the input. The padding is applied on both sides + of the depth, height and width axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4)) + >>> pool = nn.MaxPool3d(kernel_size=2, stride=2) + >>> pool(x) + """ def __init__( self, From 683b50d58e4a08d42b598efe47e6428c0972b9a5 Mon Sep 17 00:00:00 2001 From: cvnad1 Date: Wed, 30 Oct 2024 22:36:42 -0700 Subject: [PATCH 7/9] Added tests and comments for AvgPool3d --- python/mlx/nn/layers/pooling.py | 39 ++++++++++++++++++++++++ python/tests/test_nn.py | 54 +++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/python/mlx/nn/layers/pooling.py b/python/mlx/nn/layers/pooling.py index ce3e750413..dd5c676968 100644 --- a/python/mlx/nn/layers/pooling.py +++ b/python/mlx/nn/layers/pooling.py @@ -410,7 +410,46 @@ def __init__( class AvgPool3d(_Pool3d): + """ + Assuming an input of shape :math:`(N, D, H, W, C)` and ``kernel_size`` is + :math:`(k_D, k_H, k_W)`, the output is a tensor of shape :math:`(N, D_{out}, + H_{out}, W_{out}, C)`, given by: + .. math:: + \begin{aligned} + \text{out}(N_i, d, h, w, C_j) = & \frac{1}{k_D k_H k_W} \sum_{l=0, \ldots, k_D-1} \sum_{m=0, \ldots, k_H-1} \sum_{n=0, \ldots, k_W-1} \\ + & \text{input}(N_i, \text{stride[0]} \times d + l, + \text{stride[1]} \times h + m, + \text{stride[2]} \times w + n, C_j), + \end{aligned} + + where :math:`D_{out} = \left\lfloor\frac{D + 2 * \text{padding[0]} - \text{kernel\_size[0]}}{\text{stride[0]}}\right\rfloor + 1`, + :math:`H_{out} = \left\lfloor\frac{H + 2 * \text{padding[1]} - \text{kernel\_size[1]}}{\text{stride[1]}}\right\rfloor + 1`, + :math:`W_{out} = \left\lfloor\frac{W + 2 * \text{padding[2]} - \text{kernel\_size[2]}}{\text{stride[2]}}\right\rfloor + 1`. + + The parameters ``kernel_size``, ``stride``, ``padding``, can either be: + + - a single ``int`` -- in which case the same value is used for the depth, + height and width axis; + - a ``tuple`` of three ``int`` s -- in which case, the first ``int`` is used + for the depth axis, the second ``int`` for the height axis, and the third + ``int`` for the width axis. + + Args: + kernel_size (int or tuple(int, int, int)): The size of the pooling window. + stride (int or tuple(int, int, int), optional): The stride of the pooling + window. Default: ``kernel_size``. + padding (int or tuple(int, int, int), optional): How much zero + padding to apply to the input. The padding is applied on both sides + of the depth, height and width axis. Default: ``0``. + + Examples: + >>> import mlx.core as mx + >>> import mlx.nn.layers as nn + >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4)) + >>> pool = nn.AvgPool3d(kernel_size=2, stride=2) + >>> pool(x) + """ def __init__( self, kernel_size: Union[int, Tuple[int, int, int]], diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py index a3e161139e..d38ec1d36a 100644 --- a/python/tests/test_nn.py +++ b/python/tests/test_nn.py @@ -1647,6 +1647,60 @@ def test_pooling(self): "MaxPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))", ) + expected_avg_pool_output_no_padding_stride_1 = [[[[[19.5, 20.5, 21.5], + [22.5, 23.5, 24.5]], + [[28.5, 29.5, 30.5], + [31.5, 32.5, 33.5]]]] + ] + + expected_avg_pool_output_no_padding_stride_2 = [[[[[19.5, 20.5, 21.5]]]]] + expected_avg_pool_output_padding_1 = [ + [[[[0, 0.125, 0.25], + [1.125, 1.375, 1.625]], + [[3.375, 3.625, 3.875], + [9, 9.5, 10]]], + [[[3.375, 3.5, 3.625], + [7.875, 8.125, 8.375]], + [[10.125, 10.375, 10.625], + [22.5, 23, 23.5]]]] + ] + expected_irregular_avg_pool_output = [[[[[4.5, 5.5, 6.5], + [7.5, 8.5, 9.5], + [10.5, 11.5, 12.5]]], + [[[31.5, 32.5, 33.5], + [34.5, 35.5, 36.5], + [37.5, 38.5, 39.5]]]] + ] + + self.assertTrue( + np.array_equal( + nn.AvgPool3d(kernel_size=2, stride=1, padding=0)(x), + expected_avg_pool_output_no_padding_stride_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.AvgPool3d(kernel_size=2, stride=2, padding=0)(x), + expected_avg_pool_output_no_padding_stride_2, + ) + ) + self.assertTrue( + np.array_equal( + nn.AvgPool3d(kernel_size=2, stride=2, padding=1)(x), + expected_avg_pool_output_padding_1, + ) + ) + self.assertTrue( + np.array_equal( + nn.AvgPool3d(kernel_size=(1, 2, 1), stride=(1, 2, 1))(x), + expected_irregular_avg_pool_output, + ) + ) + self.assertEqual( + str(nn.AvgPool3d(kernel_size=3, stride=3, padding=2)), + "AvgPool3d(kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(2, 2, 2))", + ) + def test_set_dtype(self): def assert_dtype(layer, dtype): for k, v in tree_flatten(layer.parameters()): From 81ac999e832a17c23d9ff5b8f2e615490e54e782 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 19 Nov 2024 16:37:33 -0800 Subject: [PATCH 8/9] Remove extra empty lines --- python/mlx/utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/mlx/utils.py b/python/mlx/utils.py index cd2d0822cb..6754232a6e 100644 --- a/python/mlx/utils.py +++ b/python/mlx/utils.py @@ -111,11 +111,9 @@ def tree_map_with_path( return fn(path, tree, *rest) - def tree_flatten( tree: Any, prefix: str = "", is_leaf: Optional[Callable] = None ) -> Any: - """Flattens a Python tree to a list of key, value tuples. The keys are using the dot notation to define trees of arbitrary depth and @@ -159,9 +157,7 @@ def tree_flatten( return [(prefix[1:], tree)] - def tree_unflatten(tree: List[Tuple[str, Any]]) -> Any: - """Recreate a Python tree from its flat representation. .. code-block:: python From 56c0faebd77b2b510626219b610a47193411ee5f Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 19 Nov 2024 16:41:31 -0800 Subject: [PATCH 9/9] Expose 3d pooling to the layers --- python/mlx/nn/layers/__init__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index 3cf5e33a81..c1d89fed9f 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -70,7 +70,14 @@ LayerNorm, RMSNorm, ) -from mlx.nn.layers.pooling import AvgPool1d, AvgPool2d, MaxPool1d, MaxPool2d +from mlx.nn.layers.pooling import ( + AvgPool1d, + AvgPool2d, + AvgPool3d, + MaxPool1d, + MaxPool2d, + MaxPool3d, +) from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding from mlx.nn.layers.quantized import QuantizedEmbedding, QuantizedLinear, quantize from mlx.nn.layers.recurrent import GRU, LSTM, RNN