Skip to content

Commit 502ae34

Browse files
authored
Make it easier to access dumps and matrices from Python (#2042)
This adds some convenience accessors for getting dumps, messages, and matrices as separate entries from an invocation of `qsharp.run` and `qsharp.eval`. Previously, the only way to get a state dump from a run was the awkward: ```python state = qsharp.StateDump(qsharp.run("DumpMachine()", shots=1, save_events=True)[0]["events"][0].state_dump()) ``` This change preserves the existings "events" entry in the saved output, which has everything intermingled in the order from each shot, but also introduces dumps, messages, and matrices that will keep just the ordered output of that type. This makes the above pattern slightly better (and more discoverable): ```python state = qsharp.run("DumpMachine()", shots=1, save_events=True)[0]["dumps"][0] ``` This adds similar functionality to `qsharp.eval` which now supports `save_events=True` to capture output, so for single shot execution you can use: ```python state = qsharp.eval("DumpMachine()", save_events=True)["dumps"][0] ```
1 parent 609bef5 commit 502ae34

File tree

4 files changed

+180
-91
lines changed

4 files changed

+180
-91
lines changed

pip/qsharp/_native.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,9 @@ class Output:
262262
def __str__(self) -> str: ...
263263
def _repr_markdown_(self) -> Optional[str]: ...
264264
def state_dump(self) -> Optional[StateDumpData]: ...
265+
def is_state_dump(self) -> bool: ...
266+
def is_matrix(self) -> bool: ...
267+
def is_message(self) -> bool: ...
265268

266269
class StateDumpData:
267270
"""

pip/qsharp/_qsharp.py

+133-90
Original file line numberDiff line numberDiff line change
@@ -219,18 +219,133 @@ def get_interpreter() -> Interpreter:
219219
return _interpreter
220220

221221

222-
def eval(source: str) -> Any:
222+
class StateDump:
223+
"""
224+
A state dump returned from the Q# interpreter.
225+
"""
226+
227+
"""
228+
The number of allocated qubits at the time of the dump.
229+
"""
230+
qubit_count: int
231+
232+
__inner: dict
233+
__data: StateDumpData
234+
235+
def __init__(self, data: StateDumpData):
236+
self.__data = data
237+
self.__inner = data.get_dict()
238+
self.qubit_count = data.qubit_count
239+
240+
def __getitem__(self, index: int) -> complex:
241+
return self.__inner.__getitem__(index)
242+
243+
def __iter__(self):
244+
return self.__inner.__iter__()
245+
246+
def __len__(self) -> int:
247+
return len(self.__inner)
248+
249+
def __repr__(self) -> str:
250+
return self.__data.__repr__()
251+
252+
def __str__(self) -> str:
253+
return self.__data.__str__()
254+
255+
def _repr_markdown_(self) -> str:
256+
return self.__data._repr_markdown_()
257+
258+
def check_eq(
259+
self, state: Union[Dict[int, complex], List[complex]], tolerance: float = 1e-10
260+
) -> bool:
261+
"""
262+
Checks if the state dump is equal to the given state. This is not mathematical equality,
263+
as the check ignores global phase.
264+
265+
:param state: The state to check against, provided either as a dictionary of state indices to complex amplitudes,
266+
or as a list of real amplitudes.
267+
:param tolerance: The tolerance for the check. Defaults to 1e-10.
268+
"""
269+
phase = None
270+
# Convert a dense list of real amplitudes to a dictionary of state indices to complex amplitudes
271+
if isinstance(state, list):
272+
state = {i: val for i, val in enumerate(state)}
273+
# Filter out zero states from the state dump and the given state based on tolerance
274+
state = {k: v for k, v in state.items() if abs(v) > tolerance}
275+
inner_state = {k: v for k, v in self.__inner.items() if abs(v) > tolerance}
276+
if len(state) != len(inner_state):
277+
return False
278+
for key in state:
279+
if key not in inner_state:
280+
return False
281+
if phase is None:
282+
# Calculate the phase based on the first state pair encountered.
283+
# Every pair of states after this must have the same phase for the states to be equivalent.
284+
phase = inner_state[key] / state[key]
285+
elif abs(phase - inner_state[key] / state[key]) > tolerance:
286+
# This pair of states does not have the same phase,
287+
# within tolerance, so the equivalence check fails.
288+
return False
289+
return True
290+
291+
def as_dense_state(self) -> List[complex]:
292+
"""
293+
Returns the state dump as a dense list of complex amplitudes. This will include zero amplitudes.
294+
"""
295+
return [self.__inner.get(i, complex(0)) for i in range(2**self.qubit_count)]
296+
297+
298+
class ShotResult(TypedDict):
299+
"""
300+
A single result of a shot.
301+
"""
302+
303+
events: List[Output]
304+
result: Any
305+
messages: List[str]
306+
matrices: List[Output]
307+
dumps: List[StateDump]
308+
309+
310+
def eval(
311+
source: str,
312+
*,
313+
save_events: bool = False,
314+
) -> Any:
223315
"""
224316
Evaluates Q# source code.
225317
226318
Output is printed to console.
227319
228320
:param source: The Q# source code to evaluate.
229-
:returns value: The value returned by the last statement in the source code.
321+
:param save_events: If true, all output will be saved and returned. If false, they will be printed.
322+
:returns value: The value returned by the last statement in the source code or the saved output if `save_events` is true.
230323
:raises QSharpError: If there is an error evaluating the source code.
231324
"""
232325
ipython_helper()
233326

327+
results: ShotResult = {
328+
"events": [],
329+
"result": None,
330+
"messages": [],
331+
"matrices": [],
332+
"dumps": [],
333+
}
334+
335+
def on_save_events(output: Output) -> None:
336+
# Append the output to the last shot's output list
337+
if output.is_matrix():
338+
results["events"].append(output)
339+
results["matrices"].append(output)
340+
elif output.is_state_dump():
341+
state_dump = StateDump(output.state_dump())
342+
results["events"].append(state_dump)
343+
results["dumps"].append(state_dump)
344+
elif output.is_message():
345+
stringified = str(output)
346+
results["events"].append(stringified)
347+
results["messages"].append(stringified)
348+
234349
def callback(output: Output) -> None:
235350
if _in_jupyter:
236351
try:
@@ -244,21 +359,17 @@ def callback(output: Output) -> None:
244359
telemetry_events.on_eval()
245360
start_time = monotonic()
246361

247-
results = get_interpreter().interpret(source, callback)
362+
results["result"] = get_interpreter().interpret(
363+
source, on_save_events if save_events else callback
364+
)
248365

249366
durationMs = (monotonic() - start_time) * 1000
250367
telemetry_events.on_eval_end(durationMs)
251368

252-
return results
253-
254-
255-
class ShotResult(TypedDict):
256-
"""
257-
A single result of a shot.
258-
"""
259-
260-
events: List[Output]
261-
result: Any
369+
if save_events:
370+
return results
371+
else:
372+
return results["result"]
262373

263374

264375
def run(
@@ -315,9 +426,17 @@ def print_output(output: Output) -> None:
315426
def on_save_events(output: Output) -> None:
316427
# Append the output to the last shot's output list
317428
results[-1]["events"].append(output)
429+
if output.is_matrix():
430+
results[-1]["matrices"].append(output)
431+
elif output.is_state_dump():
432+
results[-1]["dumps"].append(StateDump(output.state_dump()))
433+
elif output.is_message():
434+
results[-1]["messages"].append(str(output))
318435

319436
for shot in range(shots):
320-
results.append({"result": None, "events": []})
437+
results.append(
438+
{"result": None, "events": [], "messages": [], "matrices": [], "dumps": []}
439+
)
321440
run_results = get_interpreter().run(
322441
entry_expr,
323442
on_save_events if save_events else print_output,
@@ -482,82 +601,6 @@ def set_classical_seed(seed: Optional[int]) -> None:
482601
get_interpreter().set_classical_seed(seed)
483602

484603

485-
class StateDump:
486-
"""
487-
A state dump returned from the Q# interpreter.
488-
"""
489-
490-
"""
491-
The number of allocated qubits at the time of the dump.
492-
"""
493-
qubit_count: int
494-
495-
__inner: dict
496-
__data: StateDumpData
497-
498-
def __init__(self, data: StateDumpData):
499-
self.__data = data
500-
self.__inner = data.get_dict()
501-
self.qubit_count = data.qubit_count
502-
503-
def __getitem__(self, index: int) -> complex:
504-
return self.__inner.__getitem__(index)
505-
506-
def __iter__(self):
507-
return self.__inner.__iter__()
508-
509-
def __len__(self) -> int:
510-
return len(self.__inner)
511-
512-
def __repr__(self) -> str:
513-
return self.__data.__repr__()
514-
515-
def __str__(self) -> str:
516-
return self.__data.__str__()
517-
518-
def _repr_markdown_(self) -> str:
519-
return self.__data._repr_markdown_()
520-
521-
def check_eq(
522-
self, state: Union[Dict[int, complex], List[complex]], tolerance: float = 1e-10
523-
) -> bool:
524-
"""
525-
Checks if the state dump is equal to the given state. This is not mathematical equality,
526-
as the check ignores global phase.
527-
528-
:param state: The state to check against, provided either as a dictionary of state indices to complex amplitudes,
529-
or as a list of real amplitudes.
530-
:param tolerance: The tolerance for the check. Defaults to 1e-10.
531-
"""
532-
phase = None
533-
# Convert a dense list of real amplitudes to a dictionary of state indices to complex amplitudes
534-
if isinstance(state, list):
535-
state = {i: state[i] for i in range(len(state))}
536-
# Filter out zero states from the state dump and the given state based on tolerance
537-
state = {k: v for k, v in state.items() if abs(v) > tolerance}
538-
inner_state = {k: v for k, v in self.__inner.items() if abs(v) > tolerance}
539-
if len(state) != len(inner_state):
540-
return False
541-
for key in state:
542-
if key not in inner_state:
543-
return False
544-
if phase is None:
545-
# Calculate the phase based on the first state pair encountered.
546-
# Every pair of states after this must have the same phase for the states to be equivalent.
547-
phase = inner_state[key] / state[key]
548-
elif abs(phase - inner_state[key] / state[key]) > tolerance:
549-
# This pair of states does not have the same phase,
550-
# within tolerance, so the equivalence check fails.
551-
return False
552-
return True
553-
554-
def as_dense_state(self) -> List[complex]:
555-
"""
556-
Returns the state dump as a dense list of complex amplitudes. This will include zero amplitudes.
557-
"""
558-
return [self.__inner.get(i, complex(0)) for i in range(2**self.qubit_count)]
559-
560-
561604
def dump_machine() -> StateDump:
562605
"""
563606
Returns the sparse state vector of the simulator as a StateDump object.

pip/src/interpreter.rs

+12
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,18 @@ impl Output {
587587
DisplayableOutput::Matrix(_) | DisplayableOutput::Message(_) => None,
588588
}
589589
}
590+
591+
fn is_state_dump(&self) -> bool {
592+
matches!(&self.0, DisplayableOutput::State(_))
593+
}
594+
595+
fn is_matrix(&self) -> bool {
596+
matches!(&self.0, DisplayableOutput::Matrix(_))
597+
}
598+
599+
fn is_message(&self) -> bool {
600+
matches!(&self.0, DisplayableOutput::Message(_))
601+
}
590602
}
591603

592604
#[pyclass]

pip/tests/test_qsharp.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,35 @@ def test_stdout_multiple_lines() -> None:
3535
assert f.getvalue() == "STATE:\n|0⟩: 1.0000+0.0000𝑖\nHello!\n"
3636

3737

38+
def test_captured_stdout() -> None:
39+
qsharp.init(target_profile=qsharp.TargetProfile.Unrestricted)
40+
f = io.StringIO()
41+
with redirect_stdout(f):
42+
result = qsharp.eval(
43+
'{Message("Hello, world!"); Message("Goodbye!")}', save_events=True
44+
)
45+
assert f.getvalue() == ""
46+
assert len(result["messages"]) == 2
47+
assert result["messages"][0] == "Hello, world!"
48+
assert result["messages"][1] == "Goodbye!"
49+
50+
51+
def test_captured_matrix() -> None:
52+
qsharp.init(target_profile=qsharp.TargetProfile.Unrestricted)
53+
f = io.StringIO()
54+
with redirect_stdout(f):
55+
result = qsharp.eval(
56+
"Std.Diagnostics.DumpOperation(1, qs => H(qs[0]))",
57+
save_events=True,
58+
)
59+
assert f.getvalue() == ""
60+
assert len(result["matrices"]) == 1
61+
assert (
62+
str(result["matrices"][0])
63+
== "MATRIX:\n 0.7071+0.0000𝑖 0.7071+0.0000𝑖\n 0.7071+0.0000𝑖 −0.7071+0.0000𝑖"
64+
)
65+
66+
3867
def test_quantum_seed() -> None:
3968
qsharp.init(target_profile=qsharp.TargetProfile.Unrestricted)
4069
qsharp.set_quantum_seed(42)
@@ -257,6 +286,7 @@ def test_dump_operation() -> None:
257286
else:
258287
assert res[i][j] == complex(0.0, 0.0)
259288

289+
260290
def test_run_with_noise_produces_noisy_results() -> None:
261291
qsharp.init()
262292
qsharp.set_quantum_seed(0)
@@ -273,6 +303,7 @@ def test_run_with_noise_produces_noisy_results() -> None:
273303
)
274304
assert result[0] > 5
275305

306+
276307
def test_compile_qir_input_data() -> None:
277308
qsharp.init(target_profile=qsharp.TargetProfile.Base)
278309
qsharp.eval("operation Program() : Result { use q = Qubit(); return M(q) }")
@@ -324,7 +355,7 @@ def on_result(result):
324355
results = qsharp.run("Foo()", 3, on_result=on_result, save_events=True)
325356
assert (
326357
str(results)
327-
== "[{'result': Zero, 'events': [Hello, world!]}, {'result': Zero, 'events': [Hello, world!]}, {'result': Zero, 'events': [Hello, world!]}]"
358+
== "[{'result': Zero, 'events': [Hello, world!], 'messages': ['Hello, world!'], 'matrices': [], 'dumps': []}, {'result': Zero, 'events': [Hello, world!], 'messages': ['Hello, world!'], 'matrices': [], 'dumps': []}, {'result': Zero, 'events': [Hello, world!], 'messages': ['Hello, world!'], 'matrices': [], 'dumps': []}]"
328359
)
329360
stdout = capsys.readouterr().out
330361
assert stdout == ""

0 commit comments

Comments
 (0)