Skip to content

Commit e0b2174

Browse files
Fixes Wires objects as wire labels bug (#6933)
**Context:** Prior to this PR we had the following behaviour, ```python a = qml.wires.Wires("a") b = qml.wires.Wires("b") wires = [a,b,"c"] >>> print(qml.wires.Wires(wires)) Wires([Wires(['a']), Wires(['b']), 'c']) >>> print(qml.wires.Wires([0, 1, 2, 3, qml.wires.Wires([4, 5]), None])) Wires([0, 1, 2, 3, Wires([4, 5]), None]) >>> print(qml.wires.Wires([qml.wires.Wires([(0,0), (0,1)])])) Wires([Wires([(0, 0), (0, 1)])]) ``` **Description of the Change:** Add handling to `_process` that dissolves any `Wires` objects that survived to that point. This results in improved behaviour like, ```python a = qml.wires.Wires("a") b = qml.wires.Wires("b") wires = [a,b,"c"] >>> print(qml.wires.Wires(wires)) Wires(['a', 'b', 'c']) >>> print(qml.wires.Wires([0, 1, 2, 3, qml.wires.Wires([4, 5]), None])) Wires([0, 1, 2, 3, 4, 5, None]) >>> print(qml.wires.Wires([qml.wires.Wires([(0,0), (0,1)])])) Wires([(0, 0), (0, 1)]) ``` **Benefits:** Better visualization of wire objects. **Possible Drawbacks:** Wire objects will look slightly different depending on your workflow. **Related GitHub Issues:** Fixes #6669 [sc-79749] --------- Co-authored-by: Isaac De Vlugt <[email protected]>
1 parent 2bb3e4d commit e0b2174

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

doc/releases/changelog-dev.md

+4
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@
282282

283283
<h3>Bug fixes 🐛</h3>
284284

285+
* Fixed `qml.wires.Wires` initialization to disallow `Wires` objects as wires labels.
286+
Now, `Wires` is idempotent, e.g. `Wires([Wires([0]), Wires([1])])==Wires([0, 1])`.
287+
[(#6933)](https://github.com/PennyLaneAI/pennylane/pull/6933)
288+
285289
* `qml.capture.PlxprInterpreter` now correctly handles propagation of constants when interpreting higher-order primitives
286290
[(#6913)](https://github.com/PennyLaneAI/pennylane/pull/6913)
287291

pennylane/wires.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def _process(wires):
9393
if len(set_of_wires) != len(tuple_of_wires):
9494
raise WireError(f"Wires must be unique; got {wires}.")
9595

96-
return tuple_of_wires
96+
# required to make `Wires` object idempotent
97+
return tuple(itertools.chain(*(_flatten_wires_object(x) for x in tuple_of_wires)))
9798

9899

99100
class Wires(Sequence):
@@ -120,7 +121,7 @@ class Wires(Sequence):
120121
"""
121122

122123
def _flatten(self):
123-
"""Serialize Wires into a flattened representation according to the PyTree convension."""
124+
"""Serialize Wires into a flattened representation according to the PyTree convention."""
124125
return self._labels, ()
125126

126127
@classmethod
@@ -731,5 +732,13 @@ def __rxor__(self, other):
731732

732733
WiresLike = Union[Wires, Iterable[Hashable], Hashable]
733734

735+
736+
def _flatten_wires_object(wire_label):
737+
"""Converts the input to a tuple of wire labels."""
738+
if isinstance(wire_label, Wires):
739+
return wire_label.labels
740+
return [wire_label]
741+
742+
734743
# Register Wires as a PyTree-serializable class
735744
register_pytree(Wires, Wires._flatten, Wires._unflatten) # pylint: disable=protected-access

tests/test_wires.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@
3434
class TestWires:
3535
"""Tests for the ``Wires`` class."""
3636

37+
def test_wires_object_as_label(self):
38+
"""Tests that a Wires object can be used as a label for another Wires object."""
39+
assert Wires([0, 1]) == Wires([Wires([0]), Wires([1])])
40+
assert Wires(["a", "b", 1]) == Wires([Wires(["a", "b"]), Wires([1])])
41+
assert Wires([Wires([(0, 0), (0, 1)])]) == Wires([(0, 0), (0, 1)])
42+
3743
def test_error_if_wires_none(self):
3844
"""Tests that a TypeError is raised if None is given as wires."""
3945
with pytest.raises(TypeError, match="Must specify a set of wires."):
@@ -74,7 +80,7 @@ def test_creation_from_wires_lists(self):
7480
"""Tests that a Wires object can be created from a list of Wires."""
7581

7682
wires = Wires([Wires([0]), Wires([1]), Wires([2])])
77-
assert wires.labels == (Wires([0]), Wires([1]), Wires([2]))
83+
assert wires.labels == (0, 1, 2)
7884

7985
@pytest.mark.parametrize(
8086
"iterable", [[1, 0, 4], ["a", "b", "c"], [0, 1, None], ["a", 1, "ancilla"]]
@@ -148,7 +154,7 @@ def test_contains(
148154
wires = Wires([0, 1, 2, 3, Wires([4, 5]), None])
149155

150156
assert 0 in wires
151-
assert Wires([4, 5]) in wires
157+
assert Wires([4, 5]) not in wires
152158
assert None in wires
153159
assert Wires([1]) not in wires
154160
assert Wires([0, 3]) not in wires
@@ -170,7 +176,7 @@ def test_contains_wires(
170176

171177
assert not wires.contains_wires(0) # wrong type
172178
assert not wires.contains_wires([0, 1]) # wrong type
173-
assert not wires.contains_wires(
179+
assert wires.contains_wires(
174180
Wires([4, 5])
175181
) # looks up 4 and 5 in wires, which are not present
176182

0 commit comments

Comments
 (0)