Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batching triples.py and symplectics.py #563

Open
wants to merge 84 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 81 commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
23cfb15
beamsplitter
apchytr Feb 25, 2025
4d589dd
stack
apchytr Feb 25, 2025
966bede
bsgate working
apchytr Feb 25, 2025
f508253
separating out to triples_batched
apchytr Feb 25, 2025
192e497
broadcast to
apchytr Feb 26, 2025
7fac07b
two mode squeezing
apchytr Feb 26, 2025
3ca743a
squeezing gate
apchytr Feb 26, 2025
12e46ba
displacement gate
apchytr Feb 26, 2025
8c49858
unitaries done
apchytr Feb 26, 2025
0d2bb36
some pure states
apchytr Feb 26, 2025
f0ab28d
thermal state
apchytr Feb 26, 2025
8f12d61
progress
apchytr Feb 26, 2025
11cd5e2
two mode sq vac
apchytr Feb 26, 2025
07ed8d5
fix
apchytr Feb 26, 2025
9ec3617
rem reshape
apchytr Feb 26, 2025
7692556
rem import
apchytr Feb 26, 2025
c30d4a3
attenuator
apchytr Feb 27, 2025
d8f29f0
amplifier
apchytr Feb 27, 2025
5d96793
fock damping
apchytr Feb 27, 2025
217da59
starting symplectics
apchytr Feb 27, 2025
735bbb0
starting symplectics
apchytr Feb 27, 2025
669d62c
symplectic progress
apchytr Feb 27, 2025
88596ca
Merge branch 'develop' into npStack
apchytr Feb 27, 2025
63944fa
progress
apchytr Feb 27, 2025
fff6550
jax broken
apchytr Feb 28, 2025
9684521
jax fix
apchytr Feb 28, 2025
b929850
jit broadcast
apchytr Feb 28, 2025
99f67bb
jit stack
apchytr Feb 28, 2025
4691aa6
Merge branch 'develop' into npStack
apchytr Feb 28, 2025
badb4d4
mzgate
apchytr Feb 28, 2025
2a4d96f
symplectic2Au
apchytr Feb 28, 2025
a763575
attenuator_kraus_Abc
apchytr Feb 28, 2025
3d1bf87
bargmann_to_quadrature_Abc
apchytr Feb 28, 2025
8aa9cea
gket_state_Abc
apchytr Feb 28, 2025
c227abe
pgate and symplectics.py
apchytr Feb 28, 2025
d630020
symplectics done
apchytr Feb 28, 2025
fc10871
displacement_map_s_parametrized_Abc
apchytr Feb 28, 2025
9762b6b
displacement_map_s_parametrized_Abc
apchytr Feb 28, 2025
7cf7e16
gaussian_random_noise_Abc
apchytr Mar 1, 2025
0c93c95
todos
apchytr Mar 1, 2025
bd05de9
rem print
apchytr Mar 1, 2025
876d19e
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Mar 1, 2025
5a4f735
merge dev
apchytr Mar 1, 2025
cbd86e1
Merge branch 'npStack' of https://github.com/XanaduAI/MrMustard into …
apchytr Mar 1, 2025
57c1580
fix tests
apchytr Mar 1, 2025
378338f
rev
apchytr Mar 1, 2025
66c4375
rem batch slice
apchytr Mar 1, 2025
4710861
codefactor
apchytr Mar 1, 2025
7494617
small changes to joinAbc to help generalize
apchytr Mar 1, 2025
7361cb4
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Mar 5, 2025
561ea36
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Mar 5, 2025
677dfaa
progress toward integrating batching in
apchytr Mar 5, 2025
d4c83c5
states done
apchytr Mar 5, 2025
43a4b8f
CCs done
apchytr Mar 6, 2025
29ace61
everything up to test_aussian_integrals.py and test_triples.py
apchytr Mar 6, 2025
d4729b2
progress
apchytr Mar 6, 2025
b5d19b7
test_triples done
apchytr Mar 6, 2025
fe9900c
almost there
apchytr Mar 6, 2025
626be4b
broadcast to
apchytr Mar 6, 2025
9e99cac
ugh tensorflow
apchytr Mar 6, 2025
fa5994e
tensorflow working finally
apchytr Mar 6, 2025
64961c5
codefactor
apchytr Mar 6, 2025
1982e29
backend changes
apchytr Mar 6, 2025
f31b885
cleanup allclose
apchytr Mar 6, 2025
7a56cac
fix
apchytr Mar 6, 2025
0cde783
almost there
apchytr Mar 6, 2025
7881ba0
done gauss_int tests
apchytr Mar 6, 2025
2ba288e
triples_batched -> triples
apchytr Mar 6, 2025
b0ac492
tensorflow is great
apchytr Mar 6, 2025
50f11ae
tensorflow is great
apchytr Mar 6, 2025
81fc885
docs
apchytr Mar 6, 2025
2db8cd5
tf again
apchytr Mar 6, 2025
56d43a2
fix
apchytr Mar 6, 2025
0f9fbce
Merge branch 'npStack' of https://github.com/XanaduAI/MrMustard into …
apchytr Mar 6, 2025
0607873
compute batch size into utils
apchytr Mar 6, 2025
23e7cb9
tf
apchytr Mar 6, 2025
7bfddb2
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Mar 7, 2025
210ef24
rev
apchytr Mar 7, 2025
362056a
Update tests/test_physics/test_gaussian_integrals.py
apchytr Mar 7, 2025
f83f50f
broadcast_to dtype and astensor
apchytr Mar 7, 2025
5a8af62
Merge branch 'develop' into triplesBatched
apchytr Mar 7, 2025
ee89833
cr and starting some cleanup
apchytr Mar 26, 2025
9352bec
fix tf
apchytr Mar 27, 2025
fffe36f
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Mar 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mrmustard/lab_dev/circuit_components_utils/b_to_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def __init__(
modes_in=modes,
modes_out=modes,
ansatz=PolyExpAnsatz.from_function(
fn=triples.bargmann_to_quadrature_Abc, n_modes=len(modes), phi=self.parameters.phi
fn=triples.bargmann_to_quadrature_Abc,
n_modes=len(modes),
phi=self.parameters.phi,
),
).representation
for w in self.representation.wires.input.wires:
Expand Down
2 changes: 1 addition & 1 deletion mrmustard/lab_dev/states/bargmann_eigenstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ def __init__(
self._representation = self.from_ansatz(
modes=(mode,),
ansatz=PolyExpAnsatz.from_function(
fn=triples.bargmann_eigenstate_Abc, x=self.parameters.alpha
fn=triples.bargmann_eigenstate_Abc, alpha=self.parameters.alpha
),
).representation
4 changes: 2 additions & 2 deletions mrmustard/lab_dev/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def from_bargmann(
>>> from mrmustard.physics.triples import coherent_state_Abc
>>> from mrmustard.lab_dev.states.ket import Ket

>>> modes = (0, 1)
>>> triple = coherent_state_Abc(x=[0.1, 0.2]) # parallel coherent states
>>> modes = (0,)
>>> triple = coherent_state_Abc(x=0.1)

>>> coh = Ket.from_bargmann(modes, triple)
>>> assert coh.modes == modes
Expand Down
4 changes: 3 additions & 1 deletion mrmustard/lab_dev/states/quadrature_eigenstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def __init__(
self._representation = self.from_ansatz(
modes=(mode,),
ansatz=PolyExpAnsatz.from_function(
fn=triples.quadrature_eigenstates_Abc, x=self.parameters.x, phi=self.parameters.phi
fn=triples.quadrature_eigenstates_Abc,
x=self.parameters.x,
phi=self.parameters.phi,
),
).representation

Expand Down
4 changes: 3 additions & 1 deletion mrmustard/lab_dev/states/squeezed_vacuum.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def __init__(
self._representation = self.from_ansatz(
modes=(mode,),
ansatz=PolyExpAnsatz.from_function(
fn=triples.squeezed_vacuum_state_Abc, r=self.parameters.r, phi=self.parameters.phi
fn=triples.squeezed_vacuum_state_Abc,
r=self.parameters.r,
phi=self.parameters.phi,
),
).representation
4 changes: 3 additions & 1 deletion mrmustard/lab_dev/transformations/s2gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def __init__(
modes_in=modes,
modes_out=modes,
ansatz=PolyExpAnsatz.from_function(
fn=triples.twomode_squeezing_gate_Abc, r=self.parameters.r, phi=self.parameters.phi
fn=triples.twomode_squeezing_gate_Abc,
r=self.parameters.r,
phi=self.parameters.phi,
),
).representation
4 changes: 3 additions & 1 deletion mrmustard/lab_dev/transformations/sgate.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def __init__(
modes_in=(mode,),
modes_out=(mode,),
ansatz=PolyExpAnsatz.from_function(
fn=triples.squeezing_gate_Abc, r=self.parameters.r, delta=self.parameters.phi
fn=triples.squeezing_gate_Abc,
r=self.parameters.r,
delta=self.parameters.phi,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should use the same name

),
).representation
5 changes: 4 additions & 1 deletion mrmustard/math/backend_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,16 +387,19 @@ def block(self, blocks: list[list[Tensor]], axes=(-2, -1)) -> Tensor:
"""
return self._apply("block", (blocks, axes))

def broadcast_to(self, array: Tensor, shape: tuple[int]) -> Tensor:
def broadcast_to(self, array: Tensor, shape: tuple[int], dtype=None) -> Tensor:
r"""Broadcasts an array to a new shape.

Args:
array: The array to broadcast.
shape: The shape to broadcast to.
dtype: The data type to cast to. If ``None``, the returned array
is of the same type as the given one.

Returns:
The broadcasted array.
"""
array = self.astensor(array, dtype)
return self._apply("broadcast_to", (array, shape))

def cast(self, array: Tensor, dtype=None) -> Tensor:
Expand Down
32 changes: 20 additions & 12 deletions mrmustard/physics/bargmann_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from mrmustard import math, settings
from mrmustard.physics.husimi import pq_to_aadag, wigner_to_husimi
from mrmustard.utils.typing import ComplexMatrix, Matrix, Vector, Scalar
from mrmustard.utils.typing import RealMatrix, ComplexMatrix, Matrix, Vector, Scalar


def bargmann_Abc_to_phasespace_cov_means(
Expand Down Expand Up @@ -217,33 +217,41 @@ def au2Symplectic(A):
return math.real(transformation @ S @ math.conj(math.transpose(transformation)))


def symplectic2Au(S):
def symplectic2Au(symplectic: RealMatrix) -> ComplexMatrix:
r"""
The inverse of au2Symplectic i.e., returns symplectic, given Au

S: symplectic in XXPP order
symplectic: symplectic in XXPP order
"""
m = S.shape[-1]
batch_size = symplectic.shape[:-2]
batch_shape = batch_size or (1,)
batch_dim = len(batch_shape)

symp_batch = math.broadcast_to(symplectic, batch_shape + symplectic.shape[-2:])

m = symp_batch.shape[-1]
m = m // 2
# the following lines of code transform the quadrature symplectic matrix to
# the annihilation one
R = math.rotmat(m)
S = R @ S @ math.dagger(R)
S = R @ symp_batch @ math.dagger(R)
# identifying blocks of S
S_1 = S[:m, :m]
S_2 = S[:m, m:]

# TODO: broadcasting/batch stuff consider a batch dimension
S_1 = S[..., :m, :m]
S_2 = S[..., :m, m:]

perm = tuple(range(len(S_1.shape)))
perm = perm[:batch_dim] + perm[batch_dim:][::-1]

# the formula to apply comes here
A_1 = S_2 @ math.conj(math.inv(S_1)) # use solve for inverse
A_2 = math.conj(math.inv(math.transpose(S_1)))
A_3 = math.transpose(A_2)
A_2 = math.conj(math.inv(math.transpose(S_1, perm)))
A_3 = math.transpose(A_2, perm)
A_4 = -math.conj(math.solve(S_1, S_2))

A = math.block([[A_1, A_2], [A_3, A_4]])
A = math.concat([math.concat([A_1, A_2], -1), math.concat([A_3, A_4], -1)], -2)

return A
return A if batch_size else A[0]


def XY_of_channel(A: ComplexMatrix):
Expand Down
140 changes: 107 additions & 33 deletions mrmustard/physics/symplectics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
"""
from __future__ import annotations

from typing import Iterable

from mrmustard import math
from mrmustard.utils.typing import Matrix
from .utils import compute_batch_size


def cxgate_symplectic(s: float) -> Matrix:
def cxgate_symplectic(s: float | Iterable[float]) -> Matrix:
r"""
The symplectic matrix of a controlled X gate.

Expand All @@ -32,12 +35,27 @@ def cxgate_symplectic(s: float) -> Matrix:
Returns:
The symplectic matrix of a CX gate.
"""
return math.astensor(
[[1, 0, 0, 0], [s, 1, 0, 0], [0, 0, 1, -s], [0, 0, 0, 1]], dtype="complex128"
batch_size, batch_dim = compute_batch_size(s)
batch_shape = batch_size or (1,)

s_batch = math.broadcast_to(math.cast(s, math.complex128), batch_shape)

O_matrix = math.zeros(batch_shape, math.complex128)
I_matrix = math.ones(batch_shape, math.complex128)

symplectic = math.stack(
[
math.stack([I_matrix, O_matrix, O_matrix, O_matrix], batch_dim),
math.stack([s_batch, I_matrix, O_matrix, O_matrix], batch_dim),
math.stack([O_matrix, O_matrix, I_matrix, -s_batch], batch_dim),
math.stack([O_matrix, O_matrix, O_matrix, I_matrix], batch_dim),
],
batch_dim,
)
return symplectic if batch_size else symplectic[0]


def czgate_symplectic(s: float) -> Matrix:
def czgate_symplectic(s: float | Iterable[float]) -> Matrix:
r"""
The symplectic matrix of a controlled Z gate.

Expand All @@ -47,7 +65,24 @@ def czgate_symplectic(s: float) -> Matrix:
Returns:
The symplectic matrix of a CZ gate.
"""
return math.astensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, s, 1, 0], [s, 0, 0, 1]])
batch_size, batch_dim = compute_batch_size(s)
batch_shape = batch_size or (1,)

s_batch = math.broadcast_to(math.cast(s, math.complex128), batch_shape)

O_matrix = math.zeros(batch_shape, math.complex128)
I_matrix = math.ones(batch_shape, math.complex128)

symplectic = math.stack(
[
math.stack([I_matrix, O_matrix, O_matrix, O_matrix], batch_dim),
math.stack([O_matrix, I_matrix, O_matrix, O_matrix], batch_dim),
math.stack([O_matrix, s_batch, I_matrix, O_matrix], batch_dim),
math.stack([s_batch, O_matrix, O_matrix, I_matrix], batch_dim),
],
batch_dim,
)
return symplectic if batch_size else symplectic[0]


def interferometer_symplectic(unitary: Matrix) -> Matrix:
Expand All @@ -60,12 +95,22 @@ def interferometer_symplectic(unitary: Matrix) -> Matrix:
Returns:
The symplectic matrix of an N-mode interferometer.
"""
return math.block(
[[math.real(unitary), -math.imag(unitary)], [math.imag(unitary), math.real(unitary)]]
batch_size = unitary.shape[:-2]
batch_shape = batch_size or (1,)
unitary_batch = math.broadcast_to(unitary, batch_shape + unitary.shape[-2:])
symplectic = math.concat(
[
math.concat([math.real(unitary_batch), -math.imag(unitary_batch)], -1),
math.concat([math.imag(unitary_batch), math.real(unitary_batch)], -1),
],
-2,
)
return symplectic if batch_size else symplectic[0]


def mzgate_symplectic(phi_a: float, phi_b: float, internal: bool) -> Matrix:
def mzgate_symplectic(
phi_a: float | Iterable[float], phi_b: float | Iterable[float], internal: bool
) -> Matrix:
r"""
The symplectic matrix of a Mach-Zehnder gate.

Expand All @@ -81,48 +126,69 @@ def mzgate_symplectic(phi_a: float, phi_b: float, internal: bool) -> Matrix:
Returns:
The symplectic matrix of a Mach-Zehnder gate.
"""
ca = math.cos(complex(phi_a))
sa = math.sin(complex(phi_a))
cb = math.cos(complex(phi_b))
sb = math.sin(complex(phi_b))
cp = math.cos(complex(phi_a + phi_b))
sp = math.sin(complex(phi_a + phi_b))
batch_size, batch_dim = compute_batch_size(phi_a, phi_b)
batch_shape = batch_size or (1,)

phi_a_batch = math.broadcast_to(phi_a, batch_shape)
phi_b_batch = math.broadcast_to(phi_b, batch_shape)

ca = math.cos(phi_a_batch)
sa = math.sin(phi_a_batch)
cb = math.cos(phi_b_batch)
sb = math.sin(phi_b_batch)
cp = math.cos(phi_a_batch + phi_b_batch)
sp = math.sin(phi_a_batch + phi_b_batch)
if internal:
return 0.5 * math.astensor(
symplectic = math.stack(
[
[ca - cb, -sa - sb, sb - sa, -ca - cb],
[-sa - sb, cb - ca, -ca - cb, sa - sb],
[sa - sb, ca + cb, ca - cb, -sa - sb],
[ca + cb, sb - sa, -sa - sb, cb - ca],
]
math.stack([ca - cb, -sa - sb, sb - sa, -ca - cb], batch_dim),
math.stack([-sa - sb, cb - ca, -ca - cb, sa - sb], batch_dim),
math.stack([sa - sb, ca + cb, ca - cb, -sa - sb], batch_dim),
math.stack([ca + cb, sb - sa, -sa - sb, cb - ca], batch_dim),
],
batch_dim,
)
else:
return 0.5 * math.astensor(
symplectic = math.stack(
[
[cp - ca, -sb, sa - sp, -1 - cb],
[-sa - sp, 1 - cb, -ca - cp, sb],
[sp - sa, 1 + cb, cp - ca, -sb],
[cp + ca, -sb, -sa - sp, 1 - cb],
]
math.stack([cp - ca, -sb, sa - sp, -1 - cb], batch_dim),
math.stack([-sa - sp, 1 - cb, -ca - cp, sb], batch_dim),
math.stack([sp - sa, 1 + cb, cp - ca, -sb], batch_dim),
math.stack([cp + ca, -sb, -sa - sp, 1 - cb], batch_dim),
],
batch_dim,
)
symplectic = math.cast(0.5 * symplectic, math.complex128)
return symplectic if batch_size else symplectic[0]


def pgate_symplectic(n_modes: int, shearing: float) -> Matrix:
def pgate_symplectic(n_modes: int, shearing: float | Iterable[float]) -> Matrix:
r"""
The symplectic matrix of a quadratic phase gate.

Args:
n_modes: The number of modes.
shearing: The shearing parameter.

Returns:
The symplectic matrix of a phase gate.
"""
return math.block(
batch_size, _ = compute_batch_size(shearing)
batch_shape = batch_size or (1,)

shearing_batch = math.broadcast_to(shearing, batch_shape)

I_matrix = math.broadcast_to(math.eye(n_modes), batch_shape + (n_modes, n_modes))
O_matrix = math.zeros(batch_shape + (n_modes, n_modes))

symplectic = math.concat(
[
[math.eye(n_modes), math.zeros((n_modes, n_modes))],
[math.eye(n_modes) * shearing, math.eye(n_modes)],
]
math.concat([I_matrix, O_matrix], -1),
math.concat([math.eye(n_modes) * shearing_batch[..., None, None], I_matrix], -1),
],
-2,
)
return symplectic if batch_size else symplectic[0]


def realinterferometer_symplectic(orthogonal: Matrix) -> Matrix:
Expand All @@ -135,6 +201,14 @@ def realinterferometer_symplectic(orthogonal: Matrix) -> Matrix:
Returns:
The symplectic matrix of an N-mode interferometer.
"""
return math.block(
[[orthogonal, -math.zeros_like(orthogonal)], [math.zeros_like(orthogonal), orthogonal]]
batch_size = orthogonal.shape[:-2]
batch_shape = batch_size or (1,)
orthogonal_batch = math.broadcast_to(orthogonal, batch_shape + orthogonal.shape[-2:])
symplectic = math.concat(
[
math.concat([orthogonal_batch, -math.zeros_like(orthogonal_batch)], -1),
math.concat([math.zeros_like(orthogonal_batch), orthogonal_batch], -1),
],
-2,
)
return symplectic if batch_size else symplectic[0]
Loading