Skip to content

Commit 5f63a89

Browse files
sacpisbettinaheim
authored andcommitted
Get dtype from the set target (#2022)
* Determine precision based on the set target * adding a test * fixing spelling in comment
1 parent 59a0843 commit 5f63a89

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

python/cudaq/runtime/state.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# This source code and the accompanying materials are made available under #
66
# the terms of the Apache License 2.0 which accompanies this distribution. #
77
# ============================================================================ #
8+
from ..mlir._mlir_libs._quakeDialects import cudaq_runtime
89

910

1011
def to_cupy(state, dtype=None):
@@ -17,8 +18,11 @@ def to_cupy(state, dtype=None):
1718
except ImportError:
1819
print('to_cupy not supported, CuPy not available. Please install CuPy.')
1920

20-
if dtype == None:
21-
dtype = cp.complex64
21+
if dtype is None:
22+
# Determine the correct data type based on the cudaq target's precision
23+
target = cudaq_runtime.get_target()
24+
precision = target.get_precision()
25+
dtype = cp.complex128 if precision == cudaq_runtime.SimulationPrecision.fp64 else cp.complex64
2226

2327
if not state.is_on_gpu():
2428
raise RuntimeError(

python/tests/builder/test_cupy_integration.py

+14
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,20 @@ def test_cupy_to_state():
107107
assert np.isclose(result, 1.0, atol=1e-3)
108108

109109

110+
def test_cupy_to_state_without_dtype():
111+
cp_data = cp.array([.707107, 0, 0, .707107])
112+
state_from_cupy = cudaq.State.from_data(cp_data)
113+
state_from_cupy.dump()
114+
kernel = cudaq.make_kernel()
115+
q = kernel.qalloc(2)
116+
kernel.h(q[0])
117+
kernel.cx(q[0], q[1])
118+
# State is on the GPU, this is nvidia target
119+
state = cudaq.get_state(kernel)
120+
result = state.overlap(state_from_cupy)
121+
assert np.isclose(result, 1.0, atol=1e-3)
122+
123+
110124
# leave for gdb debugging
111125
if __name__ == "__main__":
112126
loc = os.path.abspath(__file__)

0 commit comments

Comments
 (0)