20
20
from ferminet import networks
21
21
from ferminet import pseudopotential as pp
22
22
from ferminet .utils import utils
23
+ import folx
23
24
import jax
24
25
from jax import lax
25
26
import jax .numpy as jnp
@@ -80,6 +81,7 @@ def local_kinetic_energy(
80
81
f : networks .FermiNetLike ,
81
82
use_scan : bool = False ,
82
83
complex_output : bool = False ,
84
+ laplacian_method : str = 'default' ,
83
85
) -> KineticEnergy :
84
86
r"""Creates a function to for the local kinetic energy, -1/2 \nabla^2 ln|f|.
85
87
@@ -88,6 +90,9 @@ def local_kinetic_energy(
88
90
(sign or phase, log magnitude) tuple.
89
91
use_scan: Whether to use a `lax.scan` for computing the laplacian.
90
92
complex_output: If true, the output of f is complex-valued.
93
+ laplacian_method: Laplacian calculation method. One of:
94
+ 'default': take jvp(grad), looping over inputs
95
+ 'folx': use Microsoft's implementation of forward laplacian
91
96
92
97
Returns:
93
98
Callable which evaluates the local kinetic energy,
@@ -97,51 +102,77 @@ def local_kinetic_energy(
97
102
phase_f = utils .select_output (f , 0 )
98
103
logabs_f = utils .select_output (f , 1 )
99
104
100
- def _lapl_over_f (params , data ):
101
- n = data .positions .shape [0 ]
102
- eye = jnp .eye (n )
103
- grad_f = jax .grad (logabs_f , argnums = 1 )
104
- def grad_f_closure (x ):
105
- return grad_f (params , x , data .spins , data .atoms , data .charges )
106
-
107
- primal , dgrad_f = jax .linearize (grad_f_closure , data .positions )
108
-
105
+ if laplacian_method == 'default' :
106
+
107
+ def _lapl_over_f (params , data ):
108
+ n = data .positions .shape [0 ]
109
+ eye = jnp .eye (n )
110
+ grad_f = jax .grad (logabs_f , argnums = 1 )
111
+ def grad_f_closure (x ):
112
+ return grad_f (params , x , data .spins , data .atoms , data .charges )
113
+
114
+ primal , dgrad_f = jax .linearize (grad_f_closure , data .positions )
115
+
116
+ if complex_output :
117
+ grad_phase = jax .grad (phase_f , argnums = 1 )
118
+ def grad_phase_closure (x ):
119
+ return grad_phase (params , x , data .spins , data .atoms , data .charges )
120
+ phase_primal , dgrad_phase = jax .linearize (
121
+ grad_phase_closure , data .positions )
122
+ hessian_diagonal = (
123
+ lambda i : dgrad_f (eye [i ])[i ] + 1.j * dgrad_phase (eye [i ])[i ]
124
+ )
125
+ else :
126
+ hessian_diagonal = lambda i : dgrad_f (eye [i ])[i ]
127
+
128
+ if use_scan :
129
+ _ , diagonal = lax .scan (
130
+ lambda i , _ : (i + 1 , hessian_diagonal (i )), 0 , None , length = n )
131
+ result = - 0.5 * jnp .sum (diagonal )
132
+ else :
133
+ result = - 0.5 * lax .fori_loop (
134
+ 0 , n , lambda i , val : val + hessian_diagonal (i ), 0.0 )
135
+ result -= 0.5 * jnp .sum (primal ** 2 )
136
+ if complex_output :
137
+ result += 0.5 * jnp .sum (phase_primal ** 2 )
138
+ result -= 1.j * jnp .sum (primal * phase_primal )
139
+ return result
140
+
141
+ elif laplacian_method == 'folx' :
109
142
if complex_output :
110
- grad_phase = jax .grad (phase_f , argnums = 1 )
111
- def grad_phase_closure (x ):
112
- return grad_phase (params , x , data .spins , data .atoms , data .charges )
113
- phase_primal , dgrad_phase = jax .linearize (
114
- grad_phase_closure , data .positions )
115
- hessian_diagonal = (
116
- lambda i : dgrad_f (eye [i ])[i ] + 1.j * dgrad_phase (eye [i ])[i ]
117
- )
143
+ raise NotImplementedError ('Forward laplacian not yet supported for'
144
+ 'complex-valued outputs.' )
118
145
else :
119
- hessian_diagonal = lambda i : dgrad_f (eye [i ])[i ]
120
-
121
- if use_scan :
122
- _ , diagonal = lax .scan (
123
- lambda i , _ : (i + 1 , hessian_diagonal (i )), 0 , None , length = n )
124
- result = - 0.5 * jnp .sum (diagonal )
125
- else :
126
- result = - 0.5 * lax .fori_loop (
127
- 0 , n , lambda i , val : val + hessian_diagonal (i ), 0.0 )
128
- result -= 0.5 * jnp .sum (primal ** 2 )
129
- if complex_output :
130
- result += 0.5 * jnp .sum (phase_primal ** 2 )
131
- result -= 1.j * jnp .sum (primal * phase_primal )
132
- return result
146
+ def _lapl_over_f (params , data ):
147
+ f_closure = lambda x : logabs_f (params ,
148
+ x ,
149
+ data .spins ,
150
+ data .atoms ,
151
+ data .charges )
152
+ f_wrapped = folx .forward_laplacian (f_closure , sparsity_threshold = 6 )
153
+ output = f_wrapped (data .positions )
154
+ return - (output .laplacian +
155
+ jnp .sum (output .jacobian .dense_array ** 2 )) / 2
156
+ else :
157
+ raise NotImplementedError (f'Laplacian method { laplacian_method } '
158
+ 'not implemented.' )
133
159
134
160
return _lapl_over_f
135
161
136
162
137
- def excited_kinetic_energy_matrix (f : networks .FermiNetLike ,
138
- states : int ) -> KineticEnergy :
163
+ def excited_kinetic_energy_matrix (
164
+ f : networks .FermiNetLike ,
165
+ states : int ,
166
+ laplacian_method : str = 'default' ) -> KineticEnergy :
139
167
"""Creates a f'n which evaluates the matrix of local kinetic energies.
140
168
141
169
Args:
142
170
f: A network which returns a tuple of sign(psi) and log(|psi|) arrays, where
143
171
each array contains one element per excited state.
144
172
states: the number of excited states
173
+ laplacian_method: Laplacian calculation method. One of:
174
+ 'default': take jvp(grad), looping over inputs
175
+ 'folx': use Microsoft's implementation of forward laplacian
145
176
146
177
Returns:
147
178
A function which computes the matrices (psi) and (K psi), which are the
@@ -166,11 +197,24 @@ def _lapl_over_f(params, data):
166
197
"""Return the kinetic energy (divided by psi) summed over excited states."""
167
198
pos_ = jnp .reshape (data .positions , [states , - 1 ])
168
199
spins_ = jnp .reshape (data .spins , [states , - 1 ])
169
- vmap_f = jax .vmap (f , (None , 0 , 0 , None , None ))
170
- sign_mat , log_mat = vmap_f (params , pos_ , spins_ , data .atoms , data .charges )
171
- vmap_lapl = jax .vmap (_lapl_all_states , (None , 0 , 0 , None , None ))
172
- lapl = vmap_lapl (params , pos_ , spins_ , data .atoms ,
173
- data .charges ) # K psi_i(r_j) / psi_i(r_j)
200
+
201
+ if laplacian_method == 'default' :
202
+ vmap_f = jax .vmap (f , (None , 0 , 0 , None , None ))
203
+ sign_mat , log_mat = vmap_f (params , pos_ , spins_ , data .atoms , data .charges )
204
+ vmap_lapl = jax .vmap (_lapl_all_states , (None , 0 , 0 , None , None ))
205
+ lapl = vmap_lapl (params , pos_ , spins_ , data .atoms ,
206
+ data .charges ) # K psi_i(r_j) / psi_i(r_j)
207
+ elif laplacian_method == 'folx' :
208
+ # CAUTION!! Only the first array of spins is being passed!
209
+ f_closure = lambda x : f (params , x , spins_ [0 ], data .atoms , data .charges )
210
+ f_wrapped = folx .forward_laplacian (f_closure , sparsity_threshold = 6 )
211
+ sign_mat , log_out = folx .batched_vmap (f_wrapped , 1 )(pos_ )
212
+ log_mat = log_out .x
213
+ lapl = - (log_out .laplacian +
214
+ jnp .sum (log_out .jacobian .dense_array ** 2 , axis = - 2 )) / 2
215
+ else :
216
+ raise NotImplementedError (f'Laplacian method { laplacian_method } '
217
+ 'not implemented with excited states.' )
174
218
175
219
# subtract off largest value to avoid under/overflow
176
220
psi_mat = sign_mat * jnp .exp (log_mat - jnp .max (log_mat )) # psi_i(r_j)
@@ -239,6 +283,7 @@ def local_energy(
239
283
nspins : Sequence [int ],
240
284
use_scan : bool = False ,
241
285
complex_output : bool = False ,
286
+ laplacian_method : str = 'default' ,
242
287
states : int = 0 ,
243
288
pp_type : str = 'ccecp' ,
244
289
pp_symbols : Sequence [str ] | None = None ,
@@ -252,6 +297,9 @@ def local_energy(
252
297
nspins: Number of particles of each spin.
253
298
use_scan: Whether to use a `lax.scan` for computing the laplacian.
254
299
complex_output: If true, the output of f is complex-valued.
300
+ laplacian_method: Laplacian calculation method. One of:
301
+ 'default': take jvp(grad), looping over inputs
302
+ 'folx': use Microsoft's implementation of forward laplacian
255
303
states: Number of excited states to compute. If 0, compute ground state with
256
304
default machinery. If 1, compute ground state with excited state machinery
257
305
pp_type: type of pseudopotential to use. Only used if ecp_symbols is
@@ -270,11 +318,12 @@ def local_energy(
270
318
del nspins
271
319
272
320
if states :
273
- ke = excited_kinetic_energy_matrix (f , states )
321
+ ke = excited_kinetic_energy_matrix (f , states , laplacian_method )
274
322
else :
275
323
ke = local_kinetic_energy (f ,
276
324
use_scan = use_scan ,
277
- complex_output = complex_output )
325
+ complex_output = complex_output ,
326
+ laplacian_method = laplacian_method )
278
327
279
328
if not pp_symbols :
280
329
effective_charges = charges
0 commit comments