Skip to content

Commit

Permalink
torchify custom callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrava committed Feb 14, 2025
1 parent 922b66c commit c84c786
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions emu_sv/custom_callback_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
def qubit_density_sv_impl(
self: QubitDensity, config: BackendConfig, t: int, state: StateVector, H: Operator
) -> Any:

"""
Custom implementation of the qubit density ❬ψ|nᵢ|ψ❭
for the state vector solver.
"""
num_qubits = int(math.log2(len(state.vector)))
state_tensor = state.vector.reshape((2,) * num_qubits)
return [(state_tensor.select(i, 1).norm() ** 2).item() for i in range(num_qubits)]
return [state_tensor.select(i, 1).norm() ** 2 for i in range(num_qubits)]


def correlation_matrix_sv_impl(
Expand All @@ -32,7 +35,12 @@ def correlation_matrix_sv_impl(
state: StateVector,
H: Operator,
) -> Any:
"""'Sparse' implementation of <𝜓| nᵢ nⱼ | 𝜓 >"""
"""
Custom implementation of the density-density correlation ❬ψ|nᵢnⱼ|ψ❭
for the state vector solver.
TODO: extend to arbitrary two-point correlation ❬ψ|AᵢBⱼ|ψ❭
"""
num_qubits = int(math.log2(len(state.vector)))
state_tensor = state.vector.reshape((2,) * num_qubits)

Expand All @@ -42,9 +50,9 @@ def correlation_matrix_sv_impl(
select_i = state_tensor.select(numi, 1)
for numj in range(numi, num_qubits): # select the upper triangle
if numi == numj:
value = (select_i.norm() ** 2).item()
value = select_i.norm() ** 2
else:
value = (select_i.select(numj - 1, 1).norm() ** 2).item()
value = select_i.select(numj - 1, 1).norm() ** 2

correlation_matrix[numi][numj] = value
correlation_matrix[numj][numi] = value
Expand All @@ -58,10 +66,14 @@ def energy_variance_sv_impl(
state: StateVector,
H: RydbergHamiltonian,
) -> Any:
"""
Custom implementation of the energy variance ❬ψ|H²|ψ❭-❬ψ|H|ψ❭²
for the state vector solver.
"""
hstate = H * state.vector
h_squared = torch.vdot(hstate, hstate)
h_state = torch.vdot(state.vector, hstate)
return (h_squared.real - h_state.real**2).item()
return h_squared.real - h_state.real**2


def second_moment_sv_impl(
Expand All @@ -71,7 +83,10 @@ def second_moment_sv_impl(
state: StateVector,
H: RydbergHamiltonian,
) -> Any:

"""
Custom implementation of the second moment of energy ❬ψ|H²|ψ❭
for the state vector solver.
"""
hstate = H * state.vector
h_squared = torch.vdot(hstate, hstate)
return (h_squared.real).item()
return h_squared.real

0 comments on commit c84c786

Please sign in to comment.