Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend stack and broadcast to #561

Merged
merged 14 commits into from
Mar 7, 2025
8 changes: 8 additions & 0 deletions mrmustard/math/backend_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def block(self, blocks: list[list[jnp.ndarray]], axes=(-2, -1)) -> jnp.ndarray:
rows = [self.concat(row, axis=axes[1]) for row in blocks]
return self.concat(rows, axis=axes[0])

@partial(jax.jit, static_argnames=["shape"])
def broadcast_to(self, array: jnp.ndarray, shape: tuple[int, ...]) -> jnp.ndarray:
    r"""Broadcasts ``array`` to the target ``shape``.

    ``shape`` is declared static so each distinct target shape triggers its
    own JIT compilation (tuples are hashable, arrays are traced).
    """
    return jnp.broadcast_to(array, shape)

@partial(jax.jit, static_argnames=["axis"])
def prod(self, x: jnp.ndarray, axis: int | None) -> jnp.ndarray:
    r"""Product of the elements of ``x`` over ``axis``.

    When ``axis`` is ``None`` the product runs over all elements.
    ``axis`` is static, so each axis value compiles once.
    """
    return jnp.prod(x, axis=axis)
Expand Down Expand Up @@ -417,6 +421,10 @@ def solve(self, matrix: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray:
def sqrt(self, x: jnp.ndarray, dtype=None) -> jnp.ndarray:
    r"""Elementwise square root of ``x``, after casting it to ``dtype``
    via ``self.cast`` (no cast when ``dtype`` is ``None``)."""
    return jnp.sqrt(self.cast(x, dtype))

@partial(jax.jit, static_argnames=["axis"])
def stack(self, arrays: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
    r"""Stacks ``arrays`` along a new axis at position ``axis``.

    NOTE(review): annotated ``jnp.ndarray`` but ``jnp.stack`` also accepts a
    sequence of arrays — confirm the intended input type against callers.
    """
    return jnp.stack(arrays, axis=axis)

@jax.jit
def kron(self, tensor1: jnp.ndarray, tensor2: jnp.ndarray):
    r"""Kronecker product of ``tensor1`` and ``tensor2``."""
    return jnp.kron(tensor1, tensor2)
Expand Down
25 changes: 25 additions & 0 deletions mrmustard/math/backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,18 @@ def block(self, blocks: list[list[Tensor]], axes=(-2, -1)) -> Tensor:
"""
return self._apply("block", (blocks, axes))

def broadcast_to(self, array: Tensor, shape: tuple[int, ...]) -> Tensor:
    r"""Broadcasts an array to a new shape.

    Dispatches to the active backend's ``broadcast_to`` implementation.

    Args:
        array: The array to broadcast.
        shape: The shape to broadcast to.

    Returns:
        The broadcasted array.
    """
    return self._apply("broadcast_to", (array, shape))

def cast(self, array: Tensor, dtype=None) -> Tensor:
r"""Casts ``array`` to ``dtype``.

Expand Down Expand Up @@ -1145,6 +1157,19 @@ def sqrtm(self, tensor: Tensor, dtype=None) -> Tensor:
The square root of ``x``"""
return self._apply("sqrtm", (tensor, dtype))

def stack(self, arrays: Sequence[Tensor], axis: int = 0) -> Tensor:
    r"""Stack arrays in sequence along a new axis.

    Args:
        arrays: Sequence of tensors to stack.
        axis: The axis along which to stack the arrays.

    Returns:
        The stacked array.
    """
    # Convert the whole sequence into one backend tensor first, so the
    # backend's ``stack`` receives a single array-like argument.
    arrays = self.astensor(arrays)
    return self._apply("stack", (arrays, axis))

def sum(self, array: Tensor, axis: int | Sequence[int] | None = None):
r"""The sum of array.

Expand Down
10 changes: 6 additions & 4 deletions mrmustard/math/backend_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ def abs(self, array: np.ndarray) -> np.ndarray:
return np.abs(array)

def allclose(self, array1: np.array, array2: np.array, atol: float, rtol: float) -> bool:
array1 = self.asnumpy(array1)
array2 = self.asnumpy(array2)
if array1.shape != array2.shape:
raise ValueError("Cannot compare arrays of different shapes.")
return np.allclose(array1, array2, atol=atol, rtol=rtol)

def any(self, array: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -104,6 +100,9 @@ def block(self, blocks: list[list[np.ndarray]], axes=(-2, -1)) -> np.ndarray:
rows = [self.concat(row, axis=axes[1]) for row in blocks]
return self.concat(rows, axis=axes[0])

def broadcast_to(self, array: np.ndarray, shape: tuple[int]) -> np.ndarray:
    r"""Broadcasts ``array`` to the target ``shape``."""
    broadcasted = np.broadcast_to(array, shape)
    return broadcasted

def block_diag(self, *blocks: list[np.ndarray]) -> np.ndarray:
    r"""Assembles the given ``blocks`` into one block-diagonal matrix."""
    result = sp.linalg.block_diag(*blocks)
    return result

Expand Down Expand Up @@ -400,6 +399,9 @@ def sort(self, array: np.ndarray, axis: int = -1) -> np.ndarray:
def sqrt(self, x: np.ndarray, dtype=None) -> np.ndarray:
    r"""Elementwise square root of ``x``, after casting it to ``dtype``
    via ``self.cast`` (no cast when ``dtype`` is ``None``)."""
    return np.sqrt(self.cast(x, dtype))

def stack(self, arrays: np.ndarray, axis: int = 0) -> np.ndarray:
    r"""Joins a sequence of arrays along a new axis at position ``axis``."""
    stacked = np.stack(arrays, axis)
    return stacked

def sum(self, array: np.ndarray, axis: int | tuple[int] | None = None):
return np.sum(array, axis=axis)

Expand Down
10 changes: 6 additions & 4 deletions mrmustard/math/backend_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,6 @@ def abs(self, array: tf.Tensor) -> tf.Tensor:
return tf.abs(array)

def allclose(self, array1: np.array, array2: np.array, atol: float, rtol: float) -> bool:
array1 = self.astensor(array1)
array2 = self.astensor(array2)
if array1.shape != array2.shape:
raise ValueError("Cannot compare arrays of different shapes.")
return tf.experimental.numpy.allclose(array1, array2, atol=atol, rtol=rtol)

def any(self, array: tf.Tensor) -> tf.Tensor:
Expand Down Expand Up @@ -110,6 +106,9 @@ def block(self, blocks: list[list[tf.Tensor]], axes=(-2, -1)) -> tf.Tensor:
rows = [self.concat(row, axis=axes[1]) for row in blocks]
return self.concat(rows, axis=axes[0])

def broadcast_to(self, array: tf.Tensor, shape: tuple[int]) -> tf.Tensor:
    r"""Broadcasts ``array`` to the target ``shape``."""
    broadcasted = tf.broadcast_to(array, shape)
    return broadcasted

def boolean_mask(self, tensor: tf.Tensor, mask: tf.Tensor) -> Tensor:
    r"""Selects the elements of ``tensor`` where ``mask`` is ``True``."""
    selected = tf.boolean_mask(tensor, mask)
    return selected

Expand Down Expand Up @@ -345,6 +344,9 @@ def sort(self, array: tf.Tensor, axis: int = -1) -> tf.Tensor:
def sqrt(self, x: tf.Tensor, dtype=None) -> tf.Tensor:
    r"""Elementwise square root of ``x``, after casting it to ``dtype``
    via ``self.cast`` (no cast when ``dtype`` is ``None``)."""
    return tf.sqrt(self.cast(x, dtype))

def stack(self, arrays: tf.Tensor, axis: int = 0) -> tf.Tensor:
    r"""Stacks ``arrays`` along a new axis at position ``axis``."""
    stacked = tf.stack(arrays, axis=axis)
    return stacked

def sum(self, array: tf.Tensor, axis: int | tuple[int] | None = None):
    r"""Sum of the elements of ``array`` over ``axis``
    (all elements when ``axis`` is ``None``)."""
    return tf.reduce_sum(array, axis=axis)

Expand Down
24 changes: 21 additions & 3 deletions tests/test_math/test_backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def test_allclose_error(self):
arr1 = math.astensor([1, 2, 3])
arr2 = math.astensor([[1, 2], [1, 2]])

if math.backend_name != "jax":
with pytest.raises(ValueError, match="Cannot compare"):
if math.backend_name == "numpy":
with pytest.raises(ValueError, match="could not be broadcast"):
math.allclose(arr1, arr2)
else:
elif math.backend_name == "jax":
with pytest.raises(ValueError, match="Incompatible shapes"):
math.allclose(arr2, arr1)

Expand Down Expand Up @@ -262,6 +262,14 @@ def test_block_diag(self):
assert R.shape == (8, 8)
assert math.allclose(math.block([[I, O], [O, 1j * I]]), R)

def test_broadcast_to(self):
    r"""
    Tests the ``broadcast_to`` method.
    """
    expected = math.astensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
    res = math.broadcast_to(math.astensor([1, 2, 3]), (3, 3))
    assert math.allclose(res, expected)

@pytest.mark.parametrize("t", types)
def test_cast(self, t):
r"""
Expand Down Expand Up @@ -674,6 +682,16 @@ def test_sqrtm(self):
res = math.asnumpy(math.sqrtm(arr))
assert math.allclose(res, 2 * np.eye(3))

def test_stack(self):
    r"""
    Tests the ``stack`` method.
    """
    mats = [np.eye(3), 2 * np.eye(3)]
    res = math.asnumpy(math.stack(mats, axis=0))
    assert np.allclose(res, np.stack(mats, axis=0))

def test_sum(self):
r"""
Tests the ``sum`` method.
Expand Down