Skip to content

Commit

Permalink
fix shots dm
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles MOUSSA committed Dec 18, 2024
1 parent 3a403a9 commit 49172fc
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 10 deletions.
62 changes: 55 additions & 7 deletions horqrux/shots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from functools import partial, reduce
from functools import partial, reduce, singledispatch
from typing import Any

import jax
Expand Down Expand Up @@ -37,8 +37,57 @@ def observable_to_matrix(
return reduce(lambda x, y: jnp.kron(x, y), ops[1:], ops[0])


@singledispatch
def probs_from_eigenvectors_state(state: Array, eigvecs: Array) -> Array:
"""Obtain the probabilities using an input state and the eigenvectors decomposition
of an observable.
Args:
state (Array): Input array.
eigvecs (Array): Eigenvectors of the observables.
Returns:
Array: The probabilities.
"""
raise NotImplementedError("prod_eigenvectors_state is not implemented")


@probs_from_eigenvectors_state.register
def _(state: Array, eigvecs: Array) -> Array:
"""Obtain the probabilities using an input quantum state vector
and the eigenvectors decomposition
of an observable.
Args:
state (Array): Input array.
eigvecs (Array): Eigenvectors of the observables.
Returns:
Array: The probabilities.
"""
inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten())
return jnp.abs(inner_prod) ** 2


@probs_from_eigenvectors_state.register
def _(state: DensityMatrix, eigvecs: Array) -> Array:
"""Obtain the probabilities using an input quantum density matrix
and the eigenvectors decomposition
of an observable.
Args:
state (DensityMatrix): Input array.
eigvecs (Array): Eigenvectors of the observables.
Returns:
Array: The probabilities.
"""
mat_prob = jnp.conjugate(eigvecs.T) @ state.array @ eigvecs
return mat_prob.diagonal().real


def eigenval_decomposition_sampling(
state: Array,
state: Array | DensityMatrix,
observables: list[Primitive],
values: dict[str, float],
n_qubits: int,
Expand All @@ -48,8 +97,7 @@ def eigenval_decomposition_sampling(
mat_obs = [observable_to_matrix(observable, n_qubits, values) for observable in observables]
eigs = [jnp.linalg.eigh(mat) for mat in mat_obs]
eigvecs, eigvals = align_eigenvectors(eigs)
inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten())
probs = jnp.abs(inner_prod) ** 2
probs = probs_from_eigenvectors_state(state, eigvecs)
return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0)


Expand All @@ -67,10 +115,10 @@ def finite_shots_fwd(
and compute the expectation given an observable.
"""
if isinstance(state, DensityMatrix):
output_gates = apply_gate(state, gates, values).array
n_qubits = len(output_gates.shape) // 2
output_gates = apply_gate(state, gates, values)
n_qubits = len(output_gates.array.shape) // 2
d = 2**n_qubits
output_gates = output_gates.reshape((d, d))
output_gates.array = output_gates.array.reshape((d, d))
else:
output_gates = apply_gate(state, gates, values)
n_qubits = len(state.shape)
Expand Down
5 changes: 2 additions & 3 deletions tests/test_shots.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ def shots_dm(x):
assert jnp.allclose(exp_exact, exp_exact_dm)

exp_shots = shots(x)
# FIXME: DM expectation not working
# exp_shots_dm = shots_dm(x)
exp_shots_dm = shots_dm(x)

assert jnp.allclose(exp_exact, exp_shots, atol=SHOTS_ATOL)
# assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL)
assert jnp.allclose(exp_exact, exp_shots_dm, atol=SHOTS_ATOL)

d_exact = jax.grad(lambda x: exact(x).sum())
d_shots = jax.grad(lambda x: shots(x).sum())
Expand Down

0 comments on commit 49172fc

Please sign in to comment.