Skip to content

Commit

Permalink
[Bugfix] Tensor representation of controlled operations with qubit su…
Browse files Browse the repository at this point in the history
…pport order (#40)

* add single dispatch control

* add tensor method

* merge conflict

* docstr tensor

* values description

* check tensor on a crx gate

* add _

* make unitary private

* change options selection mkdocs

* fix docstr expand operatr

---------
  • Loading branch information
chMoussa authored Feb 12, 2025
1 parent c61e84d commit f8d0cae
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 16 deletions.
2 changes: 1 addition & 1 deletion horqrux/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class _HamiltonianEvolution(Primitive):
target: QubitSupport
control: QubitSupport

def unitary(self, values: dict[str, Array] = dict()) -> Array:
def _unitary(self, values: dict[str, Array] = dict()) -> Array:
return expm(values["hamiltonian"] * (-1j * values["time_evolution"]))


Expand Down
4 changes: 2 additions & 2 deletions horqrux/parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __iter__(self) -> Iterable:
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
return cls(*children, *aux_data)

def unitary(self, values: dict[str, float] = dict()) -> Array:
def _unitary(self, values: dict[str, float] = dict()) -> Array:
return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values))

def jacobian(self, values: dict[str, float] = dict()) -> Array:
Expand Down Expand Up @@ -141,7 +141,7 @@ def RZ(


class _PHASE(Parametric):
def unitary(self, values: dict[str, float] = dict()) -> Array:
def _unitary(self, values: dict[str, float] = dict()) -> Array:
u = jnp.eye(2, 2, dtype=default_dtype)
u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(values)))
return u
Expand Down
35 changes: 33 additions & 2 deletions horqrux/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
QubitSupport,
TargetQubits,
_dagger,
controlled,
is_controlled,
none_like,
)
Expand Down Expand Up @@ -60,11 +61,41 @@ def tree_flatten(self) -> tuple[tuple, tuple[str, TargetQubits, ControlQubits, N
def tree_unflatten(cls, aux_data: Any, children: Any) -> Any:
return cls(*children, *aux_data)

def unitary(self, values: dict[str, float] = dict()) -> Array:
def _unitary(self, values: dict[str, float] = dict()) -> Array:
"""Obtain the base unitary from `generator_name`.
Args:
values (dict[str, float], optional): Parameter values. Defaults to dict().
Returns:
Array: The base unitary from `generator_name`.
"""
return OPERATIONS_DICT[self.generator_name]

def dagger(self, values: dict[str, float] = dict()) -> Array:
return _dagger(self.unitary(values))
"""Obtain the dagger of the base unitary from `generator_name`.
Args:
values (dict[str, float], optional): Parameter values. Defaults to dict().
Returns:
Array: The base unitary daggered from `generator_name`.
"""
return _dagger(self._unitary(values))

def tensor(self, values: dict[str, float] = dict()) -> Array:
"""Obtain the unitary taking into account the qubit support for controlled operations.
Args:
values (dict[str, float], optional): Parameter values. Defaults to dict().
Returns:
Array: Unitary representation taking into account the qubit support.
"""
base_unitary = self._unitary(values)
if is_controlled(self.control):
return controlled(base_unitary, self.target, self.control)
return base_unitary

@property
def name(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion horqrux/shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def to_matrix(
observable.control == observable.parse_idx(none_like(observable.target)),
"Controlled gates cannot be promoted from observables to operations on the whole state vector",
)
unitary = observable.unitary(values=values)
unitary = observable._unitary(values=values)
target = observable.target[0][0]
identity = jnp.eye(2, dtype=unitary.dtype)
ops = [identity for _ in range(n_qubits)]
Expand Down
103 changes: 101 additions & 2 deletions horqrux/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from functools import singledispatch
from functools import reduce, singledispatch
from math import log
from typing import Any, Iterable, Union

import jax
Expand All @@ -15,6 +16,7 @@
from numpy import log2

from ._misc import default_complex_dtype
from .matrices import _I

default_dtype = default_complex_dtype()

Expand Down Expand Up @@ -97,7 +99,7 @@ def list(cls) -> list[str]:


class OperationType(StrEnum):
UNITARY = "unitary"
UNITARY = "_unitary"
DAGGER = "dagger"
JACOBIAN = "jacobian"

Expand Down Expand Up @@ -151,12 +153,109 @@ def _jacobian(generator: Array, theta: float) -> Array:


def _controlled(operator: Array, n_control: int) -> Array:
"""
Create a controlled quantum operator with specified number of control qubits.
Args:
operator (jnp.ndarray): The base quantum operator to be controlled.
n_control (int): Number of control qubits.
Returns:
jnp.ndarray: The controlled quantum operator matrix
"""
n_qubits = int(log2(operator.shape[0]))
control = jnp.eye(2 ** (n_control + n_qubits), dtype=default_dtype)
control = control.at[-(2**n_qubits) :, -(2**n_qubits) :].set(operator)
return control


def controlled(
operator: jnp.ndarray,
target_qubits: TargetQubits,
control_qubits: ControlQubits,
) -> jnp.ndarray:
"""
Create a controlled quantum operator with specified control and target qubit indices.
Args:
operator (jnp.ndarray): The base quantum operator to be controlled.
Note the operator is defined only on `target_qubits`.
control_qubits (int or tuple of ints): Index or indices of control qubits
target_qubits (int or tuple of ints): Index or indices of target qubits
Returns:
jnp.ndarray: The controlled quantum operator matrix
"""
controls: tuple = tuple()
targets: tuple = tuple()
if isinstance(control_qubits[0], tuple):
controls = control_qubits[0]
if isinstance(target_qubits[0], tuple):
targets = target_qubits[0]
n_qop = int(log(operator.shape[0], 2))
n_targets = len(targets)
if n_qop != n_targets:
raise ValueError("`target_qubits` length should match the shape of operator.")
# Determine the total number of qubits and order of controls
ntotal_qubits = len(controls) + n_targets
qubit_support = sorted(controls + targets)
control_ind_support = tuple(i for i, q in enumerate(qubit_support) if q in controls)

# Create the full Hilbert space dimension
full_dim = 2**ntotal_qubits

# Initialize the controlled operator as an identity matrix
controlled_op = jnp.eye(full_dim, dtype=operator.dtype)

# Compute the control mask using bit manipulation
control_mask = jnp.sum(
jnp.array(
[1 << (ntotal_qubits - control_qubit - 1) for control_qubit in control_ind_support]
)
)

# Create indices for the controlled subspace
indices = jnp.arange(full_dim)
controlled_indices = indices[(indices & control_mask) == control_mask]

# Set the controlled subspace to the operator
controlled_op = controlled_op.at[jnp.ix_(controlled_indices, controlled_indices)].set(operator)

return controlled_op


def expand_operator(
operator: Array, qubit_support: TargetQubits, full_support: TargetQubits
) -> Array:
"""
Expands an operator acting on a given qubit_support to act on a larger full_support
by explicitly filling in identity matrices on all remaining qubits.
Args:
operator (Array): Operator to expand
qubit_support (TargetQubits): Qubit support the operator is initially defined over.
full_support (TargetQubits): Qubit support the operator will be defined over.
Raises:
ValueError: When `full_support` larger than or equal to the `qubit_support`
Returns:
Array: Expanded operator.
"""
full_support = tuple(sorted(full_support))
qubit_support = tuple(sorted(qubit_support))
if not set(qubit_support).issubset(set(full_support)):
raise ValueError(
"Expanding tensor operation requires a `full_support` argument "
"larger than or equal to the `qubit_support`."
)

kron_qubits = set(full_support) - set(qubit_support)
kron_operator = reduce(jnp.kron, [operator] + [_I] * len(kron_qubits))
# TODO: Add permute_basis
return kron_operator


def product_state(bitstring: str) -> Array:
"""Generates a state of shape [2 for _ in range(len(bitstring))].
Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ plugins:
default_handler: python
handlers:
python:
selection:
options:
filters:
- "!^_" # exlude all members starting with _
- "^__init__$" # but always include __init__ modules and methods
Expand Down
86 changes: 79 additions & 7 deletions tests/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from horqrux.apply import apply_gate, apply_operator
from horqrux.parametric import PHASE, RX, RY, RZ
from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z
from horqrux.utils import density_mat, equivalent_state, product_state, random_state
from horqrux.utils import OperationType, density_mat, equivalent_state, product_state, random_state

MAX_QUBITS = 7
PARAMETRIC_GATES = (RX, RY, RZ, PHASE)
Expand All @@ -31,7 +31,7 @@ def test_primitive(gate_fn: Callable) -> None:
# test density matrix is similar to pure state
dm = apply_operator(
density_mat(orig_state),
gate.unitary(),
gate._unitary(),
gate.target[0],
gate.control[0],
)
Expand All @@ -54,7 +54,7 @@ def test_controlled_primitive(gate_fn: Callable) -> None:
# test density matrix is similar to pure state
dm = apply_operator(
density_mat(orig_state),
gate.unitary(),
gate._unitary(),
gate.target[0],
gate.control[0],
)
Expand All @@ -75,7 +75,7 @@ def test_parametric(gate_fn: Callable) -> None:
# test density matrix is similar to pure state
dm = apply_operator(
density_mat(orig_state),
gate.unitary(values),
gate._unitary(values),
gate.target[0],
gate.control[0],
)
Expand All @@ -99,7 +99,7 @@ def test_controlled_parametric(gate_fn: Callable) -> None:
# test density matrix is similar to pure state
dm = apply_operator(
density_mat(orig_state),
gate.unitary(values),
gate._unitary(values),
gate.target[0],
gate.control[0],
)
Expand Down Expand Up @@ -149,9 +149,81 @@ def test_merge_gates() -> None:
"c": np.random.uniform(0.1, 2 * np.pi),
}
state_grouped = apply_gate(
product_state("0000"), gates, values, "unitary", group_gates=True, merge_ops=True
product_state("0000"),
gates,
values,
OperationType.UNITARY,
group_gates=True,
merge_ops=True,
)
state = apply_gate(
product_state("0000"), gates, values, "unitary", group_gates=False, merge_ops=False
product_state("0000"),
gates,
values,
OperationType.UNITARY,
group_gates=False,
merge_ops=False,
)
assert jnp.allclose(state_grouped, state)


def flip_bit_wrt_control(bitstring: str, control: int, target: int) -> str:
# Convert bitstring to list for easier manipulation
bits = list(bitstring)

# Flip the bit at the specified index
if bits[control] == "1":
bits[target] = "0" if bits[target] == "1" else "1"

# Convert back to string
return "".join(bits)


@pytest.mark.parametrize(
"bitstring",
[
"00",
"01",
"11",
"10",
],
)
def test_cnot_product_state(bitstring: str):
cnot0 = NOT(target=1, control=0)
state = product_state(bitstring)
state = apply_gate(state, cnot0)
expected_state = product_state(flip_bit_wrt_control(bitstring, 0, 1))
assert jnp.allclose(state, expected_state)

# reverse control and target
cnot1 = NOT(target=0, control=1)
state = product_state(bitstring)
state = apply_gate(state, cnot1)
expected_state = product_state(flip_bit_wrt_control(bitstring, 1, 0))
assert jnp.allclose(state, expected_state)


def test_cnot_tensor() -> None:
cnot0 = NOT(target=1, control=0)
cnot1 = NOT(target=0, control=1)
assert jnp.allclose(
cnot0.tensor(), jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])
)
assert jnp.allclose(
cnot1.tensor(), jnp.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]])
)


def test_crx_tensor() -> None:
crx0 = RX(0.2, target=1, control=0)
crx1 = RX(0.2, target=0, control=1)
assert jnp.allclose(
crx0.tensor(),
jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0.9950, -0.0998j], [0, 0, -0.0998j, 0.9950]]),
atol=1e-3,
)
assert jnp.allclose(
crx1.tensor(),
jnp.array([[1, 0, 0, 0], [0, 0.9950, 0, -0.0998j], [0, 0, 1, 0], [0, -0.0998j, 0, 0.9950]]),
atol=1e-3,
)

0 comments on commit f8d0cae

Please sign in to comment.