From 3e85d6daa2414d140e7b00281e6aece2a21ab52f Mon Sep 17 00:00:00 2001 From: "Stefan J. Wernli" Date: Tue, 23 Apr 2024 09:02:03 -0700 Subject: [PATCH] Add `check_eq` for `StateDump` in Python (#1372) This adds a utility to the `StateDump` object in Python to help with writing tests that verify quantum state. The check ignores global phase, so allows for passing in any dictionary where the states differ from the dump by a constant factor, including unnormalized states. --- pip/qsharp/_qsharp.py | 35 ++++++++++++++++++++++++++++++++++- pip/tests/test_qsharp.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/pip/qsharp/_qsharp.py b/pip/qsharp/_qsharp.py index 75d148bb30..48eb06e984 100644 --- a/pip/qsharp/_qsharp.py +++ b/pip/qsharp/_qsharp.py @@ -10,7 +10,7 @@ Circuit, ) from warnings import warn -from typing import Any, Callable, Dict, Optional, TypedDict, Union, List +from typing import Any, Callable, Dict, Optional, Tuple, TypedDict, Union, List from .estimator._estimator import EstimatorResult, EstimatorParams import json @@ -349,6 +349,39 @@ def __str__(self) -> str: def _repr_html_(self) -> str: return self.__data._repr_html_() + def check_eq( + self, state: Union[Dict[int, complex], List[complex]], tolerance: float = 1e-10 + ) -> bool: + """ + Checks if the state dump is equal to the given state. This is not mathematical equality, + as the check ignores global phase. + + :param state: The state to check against, provided either as a dictionary of state indices to complex amplitudes, + or as a list of real amplitudes. + :param tolerance: The tolerance for the check. Defaults to 1e-10. + """ + phase = None + # Convert a dense list of real amplitudes to a dictionary of state indices to complex amplitudes + if isinstance(state, list): + state = {i: state[i] for i in range(len(state))} + # Filter out zero states from the state dump and the given state based on tolerance + state = {k: v for k, v in state.items() if abs(v) > tolerance} + inner_state = {k: v for k, v in self.__inner.items() if abs(v) > tolerance} + if len(state) != len(inner_state): + return False + for key in state: + if key not in inner_state: + return False + if phase is None: + # Calculate the phase based on the first state pair encountered. + # Every pair of states after this must have the same phase for the states to be equivalent. + phase = inner_state[key] / state[key] + elif abs(phase - inner_state[key] / state[key]) > tolerance: + # This pair of states does not have the same phase, + # within tolerance, so the equivalence check fails. + return False + return True + def dump_machine() -> StateDump: """ diff --git a/pip/tests/test_qsharp.py b/pip/tests/test_qsharp.py index 4cd2316d84..719ac290d5 100644 --- a/pip/tests/test_qsharp.py +++ b/pip/tests/test_qsharp.py @@ -98,6 +98,34 @@ def test_dump_machine() -> None: # Check that the state dump correctly supports iteration and membership checks for idx in state_dump: assert idx in state_dump + # Check that the state dump is correct and equivalence check ignores global phase, allowing passing + # in of different, potentially unnormalized states. The state should be + # |01⟩: 0.7071+0.0000𝑖, |11⟩: −0.7071+0.0000𝑖 + assert state_dump.check_eq({1: complex(0.7071, 0.0), 3: complex(-0.7071, 0.0)}) + assert state_dump.check_eq({1: complex(0.0, 0.7071), 3: complex(0.0, -0.7071)}) + assert state_dump.check_eq({1: complex(0.5, 0.0), 3: complex(-0.5, 0.0)}) + assert state_dump.check_eq( + {1: complex(0.7071, 0.0), 3: complex(-0.7071, 0.0), 0: complex(0.0, 0.0)} + ) + assert state_dump.check_eq([0.0, 0.5, 0.0, -0.5]) + assert state_dump.check_eq([0.0, 0.5001, 0.00001, -0.5], tolerance=1e-3) + assert state_dump.check_eq( + [complex(0.0, 0.0), complex(0.0, -0.5), complex(0.0, 0.0), complex(0.0, 0.5)] + ) + assert not state_dump.check_eq({1: complex(0.7071, 0.0), 3: complex(0.7071, 0.0)}) + assert not state_dump.check_eq({1: complex(0.5, 0.0), 3: complex(0.0, 0.5)}) + assert not state_dump.check_eq({2: complex(0.5, 0.0), 3: complex(-0.5, 0.0)}) + assert not state_dump.check_eq([0.0, 0.5001, 0.0, -0.5], tolerance=1e-6) + # Reset the qubits and apply a small rotation to q1, to confirm that tolerance applies to the dump + # itself and not just the state. + qsharp.eval("ResetAll([q1, q2]);") + qsharp.eval("Ry(0.0001, q1);") + state_dump = qsharp.dump_machine() + assert state_dump.qubit_count == 2 + assert len(state_dump) == 2 + assert not state_dump.check_eq([1.0]) + assert state_dump.check_eq([0.99999999875, 0.0, 4.999999997916667e-05]) + assert state_dump.check_eq([1.0], tolerance=1e-4) def test_dump_operation() -> None: