Skip to content

Commit 71c755d

Browse files
apchytrziofil
authored andcommitted
Backend stack and broadcast to (#561)
1 parent c067c51 commit 71c755d

File tree

5 files changed

+56
-19
lines changed

5 files changed

+56
-19
lines changed

mrmustard/math/backend_jax.py

+8
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ def block(self, blocks: list[list[jnp.ndarray]], axes=(-2, -1)) -> jnp.ndarray:
8787
rows = [self.concat(row, axis=axes[1]) for row in blocks]
8888
return self.concat(rows, axis=axes[0])
8989

90+
@partial(jax.jit, static_argnames=["shape"])
91+
def broadcast_to(self, array: jnp.ndarray, shape: tuple[int]) -> jnp.ndarray:
92+
return jnp.broadcast_to(array, shape)
93+
9094
@partial(jax.jit, static_argnames=["axis"])
9195
def prod(self, x: jnp.ndarray, axis: int | None):
9296
return jnp.prod(x, axis=axis)
@@ -417,6 +421,10 @@ def solve(self, matrix: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray:
417421
def sqrt(self, x: jnp.ndarray, dtype=None) -> jnp.ndarray:
418422
return jnp.sqrt(self.cast(x, dtype))
419423

424+
@partial(jax.jit, static_argnames=["axis"])
425+
def stack(self, arrays: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
426+
return jnp.stack(arrays, axis=axis)
427+
420428
@jax.jit
421429
def kron(self, tensor1: jnp.ndarray, tensor2: jnp.ndarray):
422430
return jnp.kron(tensor1, tensor2)

mrmustard/math/backend_manager.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,18 @@ def block(self, blocks: list[list[Tensor]], axes=(-2, -1)) -> Tensor:
387387
"""
388388
return self._apply("block", (blocks, axes))
389389

390+
def broadcast_to(self, array: Tensor, shape: tuple[int]) -> Tensor:
391+
r"""Broadcasts an array to a new shape.
392+
393+
Args:
394+
array: The array to broadcast.
395+
shape: The shape to broadcast to.
396+
397+
Returns:
398+
The broadcasted array.
399+
"""
400+
return self._apply("broadcast_to", (array, shape))
401+
390402
def cast(self, array: Tensor, dtype=None) -> Tensor:
391403
r"""Casts ``array`` to ``dtype``.
392404
@@ -1145,17 +1157,18 @@ def sqrtm(self, tensor: Tensor, dtype=None) -> Tensor:
11451157
The square root of ``x``"""
11461158
return self._apply("sqrtm", (tensor, dtype))
11471159

1148-
def stack(self, values: Tensor, axis: int = 0) -> Tensor:
1149-
r"""Stacks a list of tensors along a new axis.
1160+
def stack(self, arrays: Sequence[Tensor], axis: int = 0) -> Tensor:
1161+
r"""Stack arrays in sequence along a new axis.
11501162
11511163
Args:
1152-
values: A list of tensors to stack.
1153-
axis: The axis along which to introduce the new dimension.
1164+
arrays: Sequence of tensors to stack
1165+
axis: The axis along which to stack the arrays
11541166
11551167
Returns:
1156-
The stacked tensor.
1168+
The stacked array
11571169
"""
1158-
return self._apply("stack", (values, axis))
1170+
arrays = self.astensor(arrays)
1171+
return self._apply("stack", (arrays, axis))
11591172

11601173
def sum(self, array: Tensor, axis: int | Sequence[int] | None = None):
11611174
r"""The sum of array.

mrmustard/math/backend_numpy.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,6 @@ def abs(self, array: np.ndarray) -> np.ndarray:
7070
return np.abs(array)
7171

7272
def allclose(self, array1: np.array, array2: np.array, atol: float, rtol: float) -> bool:
73-
array1 = self.asnumpy(array1)
74-
array2 = self.asnumpy(array2)
75-
if array1.shape != array2.shape:
76-
raise ValueError("Cannot compare arrays of different shapes.")
7773
return np.allclose(array1, array2, atol=atol, rtol=rtol)
7874

7975
def any(self, array: np.ndarray) -> np.ndarray:
@@ -104,6 +100,9 @@ def block(self, blocks: list[list[np.ndarray]], axes=(-2, -1)) -> np.ndarray:
104100
rows = [self.concat(row, axis=axes[1]) for row in blocks]
105101
return self.concat(rows, axis=axes[0])
106102

103+
def broadcast_to(self, array: np.ndarray, shape: tuple[int]) -> np.ndarray:
104+
return np.broadcast_to(array, shape)
105+
107106
def block_diag(self, *blocks: list[np.ndarray]) -> np.ndarray:
108107
return sp.linalg.block_diag(*blocks)
109108

@@ -400,8 +399,8 @@ def sort(self, array: np.ndarray, axis: int = -1) -> np.ndarray:
400399
def sqrt(self, x: np.ndarray, dtype=None) -> np.ndarray:
401400
return np.sqrt(self.cast(x, dtype))
402401

403-
def stack(self, values: list[np.ndarray], axis: int = 0) -> np.ndarray:
404-
return np.stack(values, axis=axis)
402+
def stack(self, arrays: np.ndarray, axis: int = 0) -> np.ndarray:
403+
return np.stack(arrays, axis=axis)
405404

406405
def sum(self, array: np.ndarray, axis: int | tuple[int] | None = None):
407406
return np.sum(array, axis=axis)

mrmustard/math/backend_tensorflow.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,6 @@ def abs(self, array: tf.Tensor) -> tf.Tensor:
7171
return tf.abs(array)
7272

7373
def allclose(self, array1: np.array, array2: np.array, atol: float, rtol: float) -> bool:
74-
array1 = self.astensor(array1)
75-
array2 = self.astensor(array2)
76-
if array1.shape != array2.shape:
77-
raise ValueError("Cannot compare arrays of different shapes.")
7874
return tf.experimental.numpy.allclose(array1, array2, atol=atol, rtol=rtol)
7975

8076
def any(self, array: tf.Tensor) -> tf.Tensor:
@@ -110,6 +106,9 @@ def block(self, blocks: list[list[tf.Tensor]], axes=(-2, -1)) -> tf.Tensor:
110106
rows = [self.concat(row, axis=axes[1]) for row in blocks]
111107
return self.concat(rows, axis=axes[0])
112108

109+
def broadcast_to(self, array: tf.Tensor, shape: tuple[int]) -> tf.Tensor:
110+
return tf.broadcast_to(array, shape)
111+
113112
def boolean_mask(self, tensor: tf.Tensor, mask: tf.Tensor) -> Tensor:
114113
return tf.boolean_mask(tensor, mask)
115114

tests/test_math/test_backend_manager.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ def test_allclose_error(self):
8989
arr1 = math.astensor([1, 2, 3])
9090
arr2 = math.astensor([[1, 2], [1, 2]])
9191

92-
if math.backend_name != "jax":
93-
with pytest.raises(ValueError, match="Cannot compare"):
92+
if math.backend_name == "numpy":
93+
with pytest.raises(ValueError, match="could not be broadcast"):
9494
math.allclose(arr1, arr2)
95-
else:
95+
elif math.backend_name == "jax":
9696
with pytest.raises(ValueError, match="Incompatible shapes"):
9797
math.allclose(arr2, arr1)
9898

@@ -262,6 +262,14 @@ def test_block_diag(self):
262262
assert R.shape == (8, 8)
263263
assert math.allclose(math.block([[I, O], [O, 1j * I]]), R)
264264

265+
def test_broadcast_to(self):
266+
r"""
267+
Tests the ``broadcast_to`` method.
268+
"""
269+
arr = math.astensor([1, 2, 3])
270+
res = math.broadcast_to(arr, (3, 3))
271+
assert math.allclose(res, math.astensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]))
272+
265273
@pytest.mark.parametrize("t", types)
266274
def test_cast(self, t):
267275
r"""
@@ -674,6 +682,16 @@ def test_sqrtm(self):
674682
res = math.asnumpy(math.sqrtm(arr))
675683
assert math.allclose(res, 2 * np.eye(3))
676684

685+
def test_stack(self):
686+
r"""
687+
Tests the ``stack`` method.
688+
"""
689+
arr1 = np.eye(3)
690+
arr2 = 2 * np.eye(3)
691+
res = math.asnumpy(math.stack([arr1, arr2], axis=0))
692+
exp = np.stack([arr1, arr2], axis=0)
693+
assert np.allclose(res, exp)
694+
677695
def test_sum(self):
678696
r"""
679697
Tests the ``sum`` method.

0 commit comments

Comments
 (0)