diff --git a/horqrux/analog.py b/horqrux/analog.py index c8acb4e..c78dab4 100644 --- a/horqrux/analog.py +++ b/horqrux/analog.py @@ -21,7 +21,7 @@ class _HamiltonianEvolution(Primitive): target: QubitSupport control: QubitSupport - def unitary(self, values: dict[str, Array] = dict()) -> Array: + def _unitary(self, values: dict[str, Array] = dict()) -> Array: return expm(values["hamiltonian"] * (-1j * values["time_evolution"])) diff --git a/horqrux/parametric.py b/horqrux/parametric.py index 0f37914..2cc66a0 100644 --- a/horqrux/parametric.py +++ b/horqrux/parametric.py @@ -65,7 +65,7 @@ def __iter__(self) -> Iterable: def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: return cls(*children, *aux_data) - def unitary(self, values: dict[str, float] = dict()) -> Array: + def _unitary(self, values: dict[str, float] = dict()) -> Array: return _unitary(OPERATIONS_DICT[self.generator_name], self.parse_values(values)) def jacobian(self, values: dict[str, float] = dict()) -> Array: @@ -141,7 +141,7 @@ def RZ( class _PHASE(Parametric): - def unitary(self, values: dict[str, float] = dict()) -> Array: + def _unitary(self, values: dict[str, float] = dict()) -> Array: u = jnp.eye(2, 2, dtype=default_dtype) u = u.at[(1, 1)].set(jnp.exp(1.0j * self.parse_values(values))) return u diff --git a/horqrux/primitive.py b/horqrux/primitive.py index 5f5aedf..11f75b7 100644 --- a/horqrux/primitive.py +++ b/horqrux/primitive.py @@ -14,6 +14,7 @@ QubitSupport, TargetQubits, _dagger, + controlled, is_controlled, none_like, ) @@ -60,11 +61,41 @@ def tree_flatten(self) -> tuple[tuple, tuple[str, TargetQubits, ControlQubits, N def tree_unflatten(cls, aux_data: Any, children: Any) -> Any: return cls(*children, *aux_data) - def unitary(self, values: dict[str, float] = dict()) -> Array: + def _unitary(self, values: dict[str, float] = dict()) -> Array: + """Obtain the base unitary from `generator_name`. + + Args: + values (dict[str, float], optional): Parameter values. Defaults to dict(). + + Returns: + Array: The base unitary from `generator_name`. + """ return OPERATIONS_DICT[self.generator_name] def dagger(self, values: dict[str, float] = dict()) -> Array: - return _dagger(self.unitary(values)) + """Obtain the dagger of the base unitary from `generator_name`. + + Args: + values (dict[str, float], optional): Parameter values. Defaults to dict(). + + Returns: + Array: The base unitary daggered from `generator_name`. + """ + return _dagger(self._unitary(values)) + + def tensor(self, values: dict[str, float] = dict()) -> Array: + """Obtain the unitary taking into account the qubit support for controlled operations. + + Args: + values (dict[str, float], optional): Parameter values. Defaults to dict(). + + Returns: + Array: Unitary representation taking into account the qubit support. + """ + base_unitary = self._unitary(values) + if is_controlled(self.control): + return controlled(base_unitary, self.target, self.control) + return base_unitary @property def name(self) -> str: diff --git a/horqrux/shots.py b/horqrux/shots.py index 38a7c14..3918946 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -29,7 +29,7 @@ def to_matrix( observable.control == observable.parse_idx(none_like(observable.target)), "Controlled gates cannot be promoted from observables to operations on the whole state vector", ) - unitary = observable.unitary(values=values) + unitary = observable._unitary(values=values) target = observable.target[0][0] identity = jnp.eye(2, dtype=unitary.dtype) ops = [identity for _ in range(n_qubits)] diff --git a/horqrux/utils.py b/horqrux/utils.py index 1b04833..ac9b724 100644 --- a/horqrux/utils.py +++ b/horqrux/utils.py @@ -3,7 +3,8 @@ from collections import Counter from dataclasses import dataclass from enum import Enum -from functools import singledispatch +from functools import reduce, singledispatch +from math import log from typing import Any, Iterable, Union import jax @@ -15,6 +16,7 @@ from numpy import log2 from ._misc import default_complex_dtype +from .matrices import _I default_dtype = default_complex_dtype() @@ -97,7 +99,7 @@ def list(cls) -> list[str]: class OperationType(StrEnum): - UNITARY = "unitary" + UNITARY = "_unitary" DAGGER = "dagger" JACOBIAN = "jacobian" @@ -151,12 +153,109 @@ def _jacobian(generator: Array, theta: float) -> Array: def _controlled(operator: Array, n_control: int) -> Array: + """ + Create a controlled quantum operator with specified number of control qubits. + + Args: + operator (jnp.ndarray): The base quantum operator to be controlled. + n_control (int): Number of control qubits. + + Returns: + jnp.ndarray: The controlled quantum operator matrix + """ n_qubits = int(log2(operator.shape[0])) control = jnp.eye(2 ** (n_control + n_qubits), dtype=default_dtype) control = control.at[-(2**n_qubits) :, -(2**n_qubits) :].set(operator) return control +def controlled( + operator: jnp.ndarray, + target_qubits: TargetQubits, + control_qubits: ControlQubits, +) -> jnp.ndarray: + """ + Create a controlled quantum operator with specified control and target qubit indices. + + Args: + operator (jnp.ndarray): The base quantum operator to be controlled. + Note the operator is defined only on `target_qubits`. + control_qubits (int or tuple of ints): Index or indices of control qubits + target_qubits (int or tuple of ints): Index or indices of target qubits + + Returns: + jnp.ndarray: The controlled quantum operator matrix + """ + controls: tuple = tuple() + targets: tuple = tuple() + if isinstance(control_qubits[0], tuple): + controls = control_qubits[0] + if isinstance(target_qubits[0], tuple): + targets = target_qubits[0] + n_qop = int(log(operator.shape[0], 2)) + n_targets = len(targets) + if n_qop != n_targets: + raise ValueError("`target_qubits` length should match the shape of operator.") + # Determine the total number of qubits and order of controls + ntotal_qubits = len(controls) + n_targets + qubit_support = sorted(controls + targets) + control_ind_support = tuple(i for i, q in enumerate(qubit_support) if q in controls) + + # Create the full Hilbert space dimension + full_dim = 2**ntotal_qubits + + # Initialize the controlled operator as an identity matrix + controlled_op = jnp.eye(full_dim, dtype=operator.dtype) + + # Compute the control mask using bit manipulation + control_mask = jnp.sum( + jnp.array( + [1 << (ntotal_qubits - control_qubit - 1) for control_qubit in control_ind_support] + ) + ) + + # Create indices for the controlled subspace + indices = jnp.arange(full_dim) + controlled_indices = indices[(indices & control_mask) == control_mask] + + # Set the controlled subspace to the operator + controlled_op = controlled_op.at[jnp.ix_(controlled_indices, controlled_indices)].set(operator) + + return controlled_op + + +def expand_operator( + operator: Array, qubit_support: TargetQubits, full_support: TargetQubits +) -> Array: + """ + Expands an operator acting on a given qubit_support to act on a larger full_support + by explicitly filling in identity matrices on all remaining qubits. + + Args: + operator (Array): Operator to expand + qubit_support (TargetQubits): Qubit support the operator is initially defined over. + full_support (TargetQubits): Qubit support the operator will be defined over. + + Raises: + ValueError: When `full_support` larger than or equal to the `qubit_support` + + Returns: + Array: Expanded operator. + """ + full_support = tuple(sorted(full_support)) + qubit_support = tuple(sorted(qubit_support)) + if not set(qubit_support).issubset(set(full_support)): + raise ValueError( + "Expanding tensor operation requires a `full_support` argument " + "larger than or equal to the `qubit_support`." + ) + + kron_qubits = set(full_support) - set(qubit_support) + kron_operator = reduce(jnp.kron, [operator] + [_I] * len(kron_qubits)) + # TODO: Add permute_basis + return kron_operator + + def product_state(bitstring: str) -> Array: """Generates a state of shape [2 for _ in range(len(bitstring))]. diff --git a/mkdocs.yml b/mkdocs.yml index 0f214f4..272b3e0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -49,7 +49,7 @@ plugins: default_handler: python handlers: python: - selection: + options: filters: - "!^_" # exlude all members starting with _ - "^__init__$" # but always include __init__ modules and methods diff --git a/tests/test_gates.py b/tests/test_gates.py index afb737d..57c94cb 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -10,7 +10,7 @@ from horqrux.apply import apply_gate, apply_operator from horqrux.parametric import PHASE, RX, RY, RZ from horqrux.primitive import NOT, SWAP, H, I, S, T, X, Y, Z -from horqrux.utils import density_mat, equivalent_state, product_state, random_state +from horqrux.utils import OperationType, density_mat, equivalent_state, product_state, random_state MAX_QUBITS = 7 PARAMETRIC_GATES = (RX, RY, RZ, PHASE) @@ -31,7 +31,7 @@ def test_primitive(gate_fn: Callable) -> None: # test density matrix is similar to pure state dm = apply_operator( density_mat(orig_state), - gate.unitary(), + gate._unitary(), gate.target[0], gate.control[0], ) @@ -54,7 +54,7 @@ def test_controlled_primitive(gate_fn: Callable) -> None: # test density matrix is similar to pure state dm = apply_operator( density_mat(orig_state), - gate.unitary(), + gate._unitary(), gate.target[0], gate.control[0], ) @@ -75,7 +75,7 @@ def test_parametric(gate_fn: Callable) -> None: # test density matrix is similar to pure state dm = apply_operator( density_mat(orig_state), - gate.unitary(values), + gate._unitary(values), gate.target[0], gate.control[0], ) @@ -99,7 +99,7 @@ def test_controlled_parametric(gate_fn: Callable) -> None: # test density matrix is similar to pure state dm = apply_operator( density_mat(orig_state), - gate.unitary(values), + gate._unitary(values), gate.target[0], gate.control[0], ) @@ -149,9 +149,81 @@ def test_merge_gates() -> None: "c": np.random.uniform(0.1, 2 * np.pi), } state_grouped = apply_gate( - product_state("0000"), gates, values, "unitary", group_gates=True, merge_ops=True + product_state("0000"), + gates, + values, + OperationType.UNITARY, + group_gates=True, + merge_ops=True, ) state = apply_gate( - product_state("0000"), gates, values, "unitary", group_gates=False, merge_ops=False + product_state("0000"), + gates, + values, + OperationType.UNITARY, + group_gates=False, + merge_ops=False, ) assert jnp.allclose(state_grouped, state) + + +def flip_bit_wrt_control(bitstring: str, control: int, target: int) -> str: + # Convert bitstring to list for easier manipulation + bits = list(bitstring) + + # Flip the bit at the specified index + if bits[control] == "1": + bits[target] = "0" if bits[target] == "1" else "1" + + # Convert back to string + return "".join(bits) + + +@pytest.mark.parametrize( + "bitstring", + [ + "00", + "01", + "11", + "10", + ], +) +def test_cnot_product_state(bitstring: str): + cnot0 = NOT(target=1, control=0) + state = product_state(bitstring) + state = apply_gate(state, cnot0) + expected_state = product_state(flip_bit_wrt_control(bitstring, 0, 1)) + assert jnp.allclose(state, expected_state) + + # reverse control and target + cnot1 = NOT(target=0, control=1) + state = product_state(bitstring) + state = apply_gate(state, cnot1) + expected_state = product_state(flip_bit_wrt_control(bitstring, 1, 0)) + assert jnp.allclose(state, expected_state) + + +def test_cnot_tensor() -> None: + cnot0 = NOT(target=1, control=0) + cnot1 = NOT(target=0, control=1) + assert jnp.allclose( + cnot0.tensor(), jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + ) + assert jnp.allclose( + cnot1.tensor(), jnp.array([[1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0]]) + ) + + +def test_crx_tensor() -> None: + crx0 = RX(0.2, target=1, control=0) + crx1 = RX(0.2, target=0, control=1) + assert jnp.allclose( + crx0.tensor(), + jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0.9950, -0.0998j], [0, 0, -0.0998j, 0.9950]]), + atol=1e-3, + ) + assert jnp.allclose( + crx1.tensor(), + jnp.array([[1, 0, 0, 0], [0, 0.9950, 0, -0.0998j], [0, 0, 1, 0], [0, -0.0998j, 0, 0.9950]]), + atol=1e-3, + )