Skip to content

Commit dcc4da9

Browse files
authored
Tweak swap update algorithm to ensure termination (#164)
* add a test case that reproduces the swap updater infinite loop, and fix it by keeping track of previously used swaps to prevent cycles
1 parent 2c2a514 commit dcc4da9

File tree

2 files changed

+76
-6
lines changed

2 files changed

+76
-6
lines changed

recirq/quantum_chess/swap_updater.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,11 @@ def __init__(self,
106106
self.dlists = mcpe.DependencyLists(circuit)
107107
self.mapping = mcpe.QubitMapping(initial_mapping)
108108
self.swap_factory = swap_factory
109-
self.pairwise_distances = _pairwise_shortest_distances(self.device_qubits)
109+
self.pairwise_distances = _pairwise_shortest_distances(
110+
self.device_qubits)
111+
# Tracks swaps that have been made since the last circuit gate was
112+
# output.
113+
self.prev_swaps = set()
110114

111115
def _distance_between(self, q1: cirq.GridQubit, q2: cirq.GridQubit) -> int:
112116
"""Returns the precomputed length of the shortest path between two qubits."""
@@ -123,11 +127,12 @@ def generate_candidate_swaps(
123127
"""
124128
for gate in gates:
125129
for gate_q in gate.qubits:
126-
yield from (
127-
(gate_q, swap_q)
128-
for swap_q in gate_q.neighbors(self.device_qubits)
129-
if mcpe.effect_of_swap((gate_q, swap_q), gate.qubits,
130-
self._distance_between) > 0)
130+
for swap_q in gate_q.neighbors(self.device_qubits):
131+
swap_qubits = (gate_q, swap_q)
132+
effect = mcpe.effect_of_swap(swap_qubits, gate.qubits,
133+
self._distance_between)
134+
if swap_qubits not in self.prev_swaps and effect > 0:
135+
yield swap_qubits
131136

132137
def _mcpe(self, swap_q1: cirq.GridQubit, swap_q2: cirq.GridQubit) -> int:
133138
"""Returns the maximum consecutive positive effect of swapping two qubits."""
@@ -152,6 +157,7 @@ def update_iteration(self) -> Generator[cirq.Operation, None, None]:
152157
# and added to the final circuit.
153158
for q in gate.qubits:
154159
self.dlists.pop_front(q)
160+
self.prev_swaps.clear()
155161
yield physical_gate
156162
else:
157163
# physical_gate needs to be fixed up with some swaps.
@@ -171,6 +177,8 @@ def update_iteration(self) -> Generator[cirq.Operation, None, None]:
171177
# connectivity graph.
172178
raise ValueError("no swaps founds that will improve the circuit")
173179
chosen_swap = max(candidates, key=lambda swap: self._mcpe(*swap))
180+
self.prev_swaps.add(chosen_swap)
181+
self.prev_swaps.add(tuple(reversed(chosen_swap)))
174182
self.mapping.swap_physical(*chosen_swap)
175183
yield from self.swap_factory(*chosen_swap)
176184

recirq/quantum_chess/swap_updater_test.py

+62
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,65 @@ def test_disconnected_components():
196196
cirq.Circuit(
197197
SwapUpdater(circuit, allowed_qubits,
198198
initial_mapping).add_swaps()))
199+
200+
201+
def test_termination_in_local_minimum():
202+
grid_2x5 = cirq.GridQubit.rect(2, 5)
203+
q = list(cirq.NamedQubit(f'q{i}') for i in range(6))
204+
# The initial mapping looks like:
205+
# _|_0_|_1_|_2_|_3_|_4_|
206+
# 0|q0 | | | |q4 |
207+
# 1|q1 |q2 | |q3 |q5 |
208+
initial_mapping = {
209+
q[0]: cirq.GridQubit(0, 0),
210+
q[1]: cirq.GridQubit(1, 0),
211+
q[2]: cirq.GridQubit(1, 1),
212+
q[3]: cirq.GridQubit(1, 3),
213+
q[4]: cirq.GridQubit(0, 4),
214+
q[5]: cirq.GridQubit(1, 4),
215+
}
216+
# Here's the idea:
217+
# * there are two "clumps" of qubits: (q0,q1,q2) and (q3,q4,q5)
218+
# * the active gate(s) span the two clumps
219+
# * there are also gates on qubits within clumps
220+
# * intra-clump gate cost contribution outweighs inter-clump gate cost
221+
# In that case, we need to swap qubits away from their respective clumps in
222+
# order to progress beyond any of the active gates. But we never will
223+
# because doing so would increase the overall cost due to the intra-clump
224+
# contributions. In fact, no greedy algorithm would be able to make progress
225+
# in this case.
226+
circuit = cirq.Circuit()
227+
# Cross-clump active gates
228+
circuit.append(
229+
[cirq.CNOT(q[0], q[3]),
230+
cirq.CNOT(q[1], q[4]),
231+
cirq.CNOT(q[2], q[5])])
232+
# Intra-clump q0,q1,q2
233+
circuit.append(
234+
[cirq.CNOT(q[0], q[1]),
235+
cirq.CNOT(q[0], q[2]),
236+
cirq.CNOT(q[1], q[2])])
237+
# Intra-clump q3,q4,q5
238+
circuit.append(
239+
[cirq.CNOT(q[3], q[4]),
240+
cirq.CNOT(q[3], q[5]),
241+
cirq.CNOT(q[4], q[5])])
242+
243+
updater = SwapUpdater(circuit, grid_2x5, initial_mapping,
244+
lambda q1, q2: [cirq.SWAP(q1, q2)])
245+
# Iterate until the SwapUpdater is finished or an assertion fails, keeping
246+
# track of the ops generated by the previous iteration.
247+
prev_it = list(updater.update_iteration())
248+
while not updater.dlists.all_empty():
249+
cur_it = list(updater.update_iteration())
250+
251+
# If the current iteration adds a SWAP, it better not be the same SWAP
252+
# as the previous iteration...
253+
# If we pick the same SWAP twice in a row, then we're going in a loop
254+
# forever without making any progress!
255+
def _is_swap(ops):
256+
return len(ops) == 1 and ops[0] == cirq.SWAP(*ops[0].qubits)
257+
258+
if _is_swap(prev_it) and _is_swap(cur_it):
259+
assert set(prev_it[0].qubits) != set(cur_it[0].qubits)
260+
prev_it = cur_it

0 commit comments

Comments
 (0)