Skip to content

Commit 039ce9a

Browse files
committed
fix: validate global_shift when creating eigen_gate.
1 parent 0f21e0a commit 039ce9a

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

cirq-core/cirq/ops/parity_gates.py

+30
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,36 @@ class XXPowGate(gate_features.InterchangeableQubitsGate, eigen_gate.EigenGate):
6969
implemented via this class.
7070
"""
7171

72+
73+
def __init__(
74+
self, *, exponent: value.TParamVal = 1.0, global_shift: float = 0.0
75+
):
76+
"""Initialize an XXPowGate.
77+
78+
Args:
79+
exponent: The t in gate**t. Determines how much the eigenvalues of
80+
the gate are phased by. For example, eigenvectors phased by -1
81+
when `gate**1` is applied will gain a relative phase of
82+
e^{i pi exponent} when `gate**exponent` is applied (relative to
83+
eigenvectors unaffected by `gate**1`).
84+
global_shift: Offsets the eigenvalues of the gate at exponent=1.
85+
In effect, this controls a global phase factor on the gate's
86+
unitary matrix. The factor is:
87+
88+
exp(i * pi * global_shift * exponent)
89+
90+
For example, `cirq.X**t` uses a `global_shift` of 0 but
91+
`cirq.rx(t)` uses a `global_shift` of -0.5, which is why
92+
`cirq.unitary(cirq.rx(pi))` equals -iX instead of X.
93+
94+
Raises:
95+
ValueError: If the supplied exponent is a complex number with an
96+
imaginary component or global_shift is out of range.
97+
"""
98+
super().__init__(exponent=exponent, global_shift=global_shift)
99+
if global_shift <= -2.0 or global_shift >= 2.0:
100+
raise ValueError(f"Gate global shift must be in the range (-2,2). Invalid Value: {global_shift}")
101+
72102
def _num_qubits_(self) -> int:
73103
return 2
74104

cirq-core/cirq/ops/parity_gates_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def test_xx_init():
3030
assert cirq.XXPowGate(exponent=1).exponent == 1
3131
v = cirq.XXPowGate(exponent=0.5)
3232
assert v.exponent == 0.5
33-
33+
with pytest.raises(ValueError, match="in the range"):
34+
assert cirq.XXPowGate(exponent=0.5, global_shift=4)
35+
assert cirq.XXPowGate(exponent=0.5, global_shift=-0.5).global_shift == -0.5
3436

3537
def test_xx_eq():
3638
eq = cirq.testing.EqualsTester()

0 commit comments

Comments
 (0)