From 779f9072e7bb5eee788b3c73b2a992cc8ac2f9a2 Mon Sep 17 00:00:00 2001 From: David Koski Date: Mon, 10 Feb 2025 15:18:55 -0800 Subject: [PATCH] add missing roll function --- .../Documentation.docc/Organization/shapes.md | 2 + Source/MLX/Ops.swift | 48 +++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/Source/MLX/Documentation.docc/Organization/shapes.md b/Source/MLX/Documentation.docc/Organization/shapes.md index 19e3494c..d1eff5d5 100644 --- a/Source/MLX/Documentation.docc/Organization/shapes.md +++ b/Source/MLX/Documentation.docc/Organization/shapes.md @@ -41,6 +41,8 @@ and ``MLXArray/shape`` of the dimensions without changing the number of elements - ``flattened(_:start:end:stream:)`` - ``reshaped(_:_:stream:)-5x3y0`` - ``squeezed(_:axes:stream:)`` +- ``roll(_:shift:axis:stream:)`` +- ``roll(_:shift:axes:stream:)`` ### MLXArray Shape Methods (Change Size) diff --git a/Source/MLX/Ops.swift b/Source/MLX/Ops.swift index 646e0584..d3e44517 100644 --- a/Source/MLX/Ops.swift +++ b/Source/MLX/Ops.swift @@ -2201,6 +2201,54 @@ public func remainder( return MLXArray(result) } +/// Roll array elements along a given axis. +/// +/// Elements that are rolled beyond the end of the array are introduced at the beggining and vice-versa. +/// +/// - Parameters: +/// - a: input array +/// - shift: The number of places by which elements +/// are shifted. If positive the array is rolled to the right, if +/// negative it is rolled to the left. +/// - axis: the axis along which to roll the elements +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +public func roll(_ a: MLXArray, shift: Int, axis: Int, stream: StreamOrDevice = .default) + -> MLXArray +{ + var result = mlx_array_new() + mlx_roll(&result, a.ctx, shift.int32, [axis.int32], 1, stream.ctx) + return MLXArray(result) +} + +/// Roll array elements along a given axis. +/// +/// Elements that are rolled beyond the end of the array are introduced at the beggining and vice-versa. +/// +/// - Parameters: +/// - a: input array +/// - shift: The number of places by which elements +/// are shifted. If positive the array is rolled to the right, if +/// negative it is rolled to the left. +/// - axes: the axes along which to roll the elements, or all if omitted +/// - stream: stream or device to evaluate on +/// +/// ### See Also +/// - +public func roll(_ a: MLXArray, shift: Int, axes: [Int]? = nil, stream: StreamOrDevice = .default) + -> MLXArray +{ + var result = mlx_array_new() + if let axes { + mlx_roll(&result, a.ctx, shift.int32, axes.asInt32, axes.count, stream.ctx) + } else { + mlx_roll_all(&result, a.ctx, shift.int32, stream.ctx) + } + return MLXArray(result) +} + /// Save array to a binary file in `.npy`format. /// /// - Parameters: