Skip to content

Commit e94435b

Browse files
dpfaujsspencer
authored and committed
Integrate folx for forward Laplacian computation
PiperOrigin-RevId: 592595762 Change-Id: I2209521359899fe2c2fda232c9e0a2c68359331d
1 parent bb979a0 commit e94435b

File tree

6 files changed

+127
-50
lines changed

6 files changed

+127
-50
lines changed

ferminet/base_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def default() -> ml_collections.ConfigDict:
5656
'objective': 'vmc', # objective type. Either 'vmc' or 'wqmc'
5757
'iterations': 1000000, # number of iterations
5858
'optimizer': 'kfac', # one of adam, kfac, lamb, none
59+
'laplacian': 'default', # one of default or folx (for forward lapl)
5960
'lr': {
6061
'rate': 0.05, # learning rate
6162
'decay': 1.0, # exponent of learning rate decay

ferminet/hamiltonian.py

+89-40
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ferminet import networks
2121
from ferminet import pseudopotential as pp
2222
from ferminet.utils import utils
23+
import folx
2324
import jax
2425
from jax import lax
2526
import jax.numpy as jnp
@@ -80,6 +81,7 @@ def local_kinetic_energy(
8081
f: networks.FermiNetLike,
8182
use_scan: bool = False,
8283
complex_output: bool = False,
84+
laplacian_method: str = 'default',
8385
) -> KineticEnergy:
8486
r"""Creates a function for the local kinetic energy, -1/2 \nabla^2 ln|f|.
8587
@@ -88,6 +90,9 @@ def local_kinetic_energy(
8890
(sign or phase, log magnitude) tuple.
8991
use_scan: Whether to use a `lax.scan` for computing the laplacian.
9092
complex_output: If true, the output of f is complex-valued.
93+
laplacian_method: Laplacian calculation method. One of:
94+
'default': take jvp(grad), looping over inputs
95+
'folx': use Microsoft's implementation of forward laplacian
9196
9297
Returns:
9398
Callable which evaluates the local kinetic energy,
@@ -97,51 +102,77 @@ def local_kinetic_energy(
97102
phase_f = utils.select_output(f, 0)
98103
logabs_f = utils.select_output(f, 1)
99104

100-
def _lapl_over_f(params, data):
101-
n = data.positions.shape[0]
102-
eye = jnp.eye(n)
103-
grad_f = jax.grad(logabs_f, argnums=1)
104-
def grad_f_closure(x):
105-
return grad_f(params, x, data.spins, data.atoms, data.charges)
106-
107-
primal, dgrad_f = jax.linearize(grad_f_closure, data.positions)
108-
105+
if laplacian_method == 'default':
106+
107+
def _lapl_over_f(params, data):
108+
n = data.positions.shape[0]
109+
eye = jnp.eye(n)
110+
grad_f = jax.grad(logabs_f, argnums=1)
111+
def grad_f_closure(x):
112+
return grad_f(params, x, data.spins, data.atoms, data.charges)
113+
114+
primal, dgrad_f = jax.linearize(grad_f_closure, data.positions)
115+
116+
if complex_output:
117+
grad_phase = jax.grad(phase_f, argnums=1)
118+
def grad_phase_closure(x):
119+
return grad_phase(params, x, data.spins, data.atoms, data.charges)
120+
phase_primal, dgrad_phase = jax.linearize(
121+
grad_phase_closure, data.positions)
122+
hessian_diagonal = (
123+
lambda i: dgrad_f(eye[i])[i] + 1.j * dgrad_phase(eye[i])[i]
124+
)
125+
else:
126+
hessian_diagonal = lambda i: dgrad_f(eye[i])[i]
127+
128+
if use_scan:
129+
_, diagonal = lax.scan(
130+
lambda i, _: (i + 1, hessian_diagonal(i)), 0, None, length=n)
131+
result = -0.5 * jnp.sum(diagonal)
132+
else:
133+
result = -0.5 * lax.fori_loop(
134+
0, n, lambda i, val: val + hessian_diagonal(i), 0.0)
135+
result -= 0.5 * jnp.sum(primal ** 2)
136+
if complex_output:
137+
result += 0.5 * jnp.sum(phase_primal ** 2)
138+
result -= 1.j * jnp.sum(primal * phase_primal)
139+
return result
140+
141+
elif laplacian_method == 'folx':
109142
if complex_output:
110-
grad_phase = jax.grad(phase_f, argnums=1)
111-
def grad_phase_closure(x):
112-
return grad_phase(params, x, data.spins, data.atoms, data.charges)
113-
phase_primal, dgrad_phase = jax.linearize(
114-
grad_phase_closure, data.positions)
115-
hessian_diagonal = (
116-
lambda i: dgrad_f(eye[i])[i] + 1.j * dgrad_phase(eye[i])[i]
117-
)
143+
raise NotImplementedError('Forward laplacian not yet supported for'
144+
'complex-valued outputs.')
118145
else:
119-
hessian_diagonal = lambda i: dgrad_f(eye[i])[i]
120-
121-
if use_scan:
122-
_, diagonal = lax.scan(
123-
lambda i, _: (i + 1, hessian_diagonal(i)), 0, None, length=n)
124-
result = -0.5 * jnp.sum(diagonal)
125-
else:
126-
result = -0.5 * lax.fori_loop(
127-
0, n, lambda i, val: val + hessian_diagonal(i), 0.0)
128-
result -= 0.5 * jnp.sum(primal ** 2)
129-
if complex_output:
130-
result += 0.5 * jnp.sum(phase_primal ** 2)
131-
result -= 1.j * jnp.sum(primal * phase_primal)
132-
return result
146+
def _lapl_over_f(params, data):
147+
f_closure = lambda x: logabs_f(params,
148+
x,
149+
data.spins,
150+
data.atoms,
151+
data.charges)
152+
f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6)
153+
output = f_wrapped(data.positions)
154+
return - (output.laplacian +
155+
jnp.sum(output.jacobian.dense_array ** 2)) / 2
156+
else:
157+
raise NotImplementedError(f'Laplacian method {laplacian_method} '
158+
'not implemented.')
133159

134160
return _lapl_over_f
135161

136162

137-
def excited_kinetic_energy_matrix(f: networks.FermiNetLike,
138-
states: int) -> KineticEnergy:
163+
def excited_kinetic_energy_matrix(
164+
f: networks.FermiNetLike,
165+
states: int,
166+
laplacian_method: str = 'default') -> KineticEnergy:
139167
"""Creates a f'n which evaluates the matrix of local kinetic energies.
140168
141169
Args:
142170
f: A network which returns a tuple of sign(psi) and log(|psi|) arrays, where
143171
each array contains one element per excited state.
144172
states: the number of excited states
173+
laplacian_method: Laplacian calculation method. One of:
174+
'default': take jvp(grad), looping over inputs
175+
'folx': use Microsoft's implementation of forward laplacian
145176
146177
Returns:
147178
A function which computes the matrices (psi) and (K psi), which are the
@@ -166,11 +197,24 @@ def _lapl_over_f(params, data):
166197
"""Return the kinetic energy (divided by psi) summed over excited states."""
167198
pos_ = jnp.reshape(data.positions, [states, -1])
168199
spins_ = jnp.reshape(data.spins, [states, -1])
169-
vmap_f = jax.vmap(f, (None, 0, 0, None, None))
170-
sign_mat, log_mat = vmap_f(params, pos_, spins_, data.atoms, data.charges)
171-
vmap_lapl = jax.vmap(_lapl_all_states, (None, 0, 0, None, None))
172-
lapl = vmap_lapl(params, pos_, spins_, data.atoms,
173-
data.charges) # K psi_i(r_j) / psi_i(r_j)
200+
201+
if laplacian_method == 'default':
202+
vmap_f = jax.vmap(f, (None, 0, 0, None, None))
203+
sign_mat, log_mat = vmap_f(params, pos_, spins_, data.atoms, data.charges)
204+
vmap_lapl = jax.vmap(_lapl_all_states, (None, 0, 0, None, None))
205+
lapl = vmap_lapl(params, pos_, spins_, data.atoms,
206+
data.charges) # K psi_i(r_j) / psi_i(r_j)
207+
elif laplacian_method == 'folx':
208+
# CAUTION!! Only the first array of spins is being passed!
209+
f_closure = lambda x: f(params, x, spins_[0], data.atoms, data.charges)
210+
f_wrapped = folx.forward_laplacian(f_closure, sparsity_threshold=6)
211+
sign_mat, log_out = folx.batched_vmap(f_wrapped, 1)(pos_)
212+
log_mat = log_out.x
213+
lapl = -(log_out.laplacian +
214+
jnp.sum(log_out.jacobian.dense_array ** 2, axis=-2)) / 2
215+
else:
216+
raise NotImplementedError(f'Laplacian method {laplacian_method} '
217+
'not implemented with excited states.')
174218

175219
# subtract off largest value to avoid under/overflow
176220
psi_mat = sign_mat * jnp.exp(log_mat - jnp.max(log_mat)) # psi_i(r_j)
@@ -239,6 +283,7 @@ def local_energy(
239283
nspins: Sequence[int],
240284
use_scan: bool = False,
241285
complex_output: bool = False,
286+
laplacian_method: str = 'default',
242287
states: int = 0,
243288
pp_type: str = 'ccecp',
244289
pp_symbols: Sequence[str] | None = None,
@@ -252,6 +297,9 @@ def local_energy(
252297
nspins: Number of particles of each spin.
253298
use_scan: Whether to use a `lax.scan` for computing the laplacian.
254299
complex_output: If true, the output of f is complex-valued.
300+
laplacian_method: Laplacian calculation method. One of:
301+
'default': take jvp(grad), looping over inputs
302+
'folx': use Microsoft's implementation of forward laplacian
255303
states: Number of excited states to compute. If 0, compute ground state with
256304
default machinery. If 1, compute ground state with excited state machinery
257305
pp_type: type of pseudopotential to use. Only used if ecp_symbols is
@@ -270,11 +318,12 @@ def local_energy(
270318
del nspins
271319

272320
if states:
273-
ke = excited_kinetic_energy_matrix(f, states)
321+
ke = excited_kinetic_energy_matrix(f, states, laplacian_method)
274322
else:
275323
ke = local_kinetic_energy(f,
276324
use_scan=use_scan,
277-
complex_output=complex_output)
325+
complex_output=complex_output,
326+
laplacian_method=laplacian_method)
278327

279328
if not pp_symbols:
280329
effective_charges = charges

ferminet/tests/hamiltonian_test.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Tests for ferminet.hamiltonian."""
1616

17+
import itertools
18+
1719
from absl.testing import absltest
1820
from absl.testing import parameterized
1921
from ferminet import base_config
@@ -59,7 +61,8 @@ def kinetic_operator(params, pos, spins, atoms, charges):
5961

6062
class HamiltonianTest(parameterized.TestCase):
6163

62-
def test_local_kinetic_energy(self):
64+
@parameterized.parameters(['default', 'folx'])
65+
def test_local_kinetic_energy(self, laplacian):
6366

6467
dummy_params = {}
6568
xs = np.random.normal(size=(3,))
@@ -68,7 +71,8 @@ def test_local_kinetic_energy(self):
6871
charges = 2 * np.ones(shape=(1,))
6972
expected_kinetic_energy = -(1 - 2 / np.abs(np.linalg.norm(xs))) / 2
7073

71-
kinetic = hamiltonian.local_kinetic_energy(h_atom_log_psi_signed)
74+
kinetic = hamiltonian.local_kinetic_energy(h_atom_log_psi_signed,
75+
laplacian_method=laplacian)
7276
kinetic_energy = kinetic(
7377
dummy_params,
7478
networks.FermiNetData(
@@ -152,7 +156,8 @@ def test_local_energy(self):
152156

153157
class LaplacianTest(parameterized.TestCase):
154158

155-
def test_laplacian(self):
159+
@parameterized.parameters(['default', 'folx'])
160+
def test_laplacian(self, laplacian):
156161

157162
xs = np.random.uniform(size=(100, 3))
158163
spins = np.ones(shape=(1,))
@@ -163,7 +168,8 @@ def test_laplacian(self):
163168
)
164169
dummy_params = {}
165170
t_l_fn = jax.vmap(
166-
hamiltonian.local_kinetic_energy(h_atom_log_psi_signed),
171+
hamiltonian.local_kinetic_energy(h_atom_log_psi_signed,
172+
laplacian_method=laplacian),
167173
in_axes=(
168174
None,
169175
networks.FermiNetData(
@@ -178,8 +184,10 @@ def test_laplacian(self):
178184
)(dummy_params, xs, spins, atoms, charges)
179185
np.testing.assert_allclose(t_l, hess_t, rtol=1E-5)
180186

181-
@parameterized.parameters([True, False])
182-
def test_fermi_net_laplacian(self, full_det):
187+
@parameterized.parameters(
188+
itertools.product([True, False], ['default', 'folx'])
189+
)
190+
def test_fermi_net_laplacian(self, full_det, laplacian):
183191
natoms = 2
184192
np.random.seed(12)
185193
atoms = np.random.uniform(low=-5.0, high=5.0, size=(natoms, 3))
@@ -209,7 +217,8 @@ def test_fermi_net_laplacian(self, full_det):
209217
spins = np.sign(np.random.normal(scale=1, size=(batch, sum(nspins))))
210218
t_l_fn = jax.jit(
211219
jax.vmap(
212-
hamiltonian.local_kinetic_energy(network.apply),
220+
hamiltonian.local_kinetic_energy(network.apply,
221+
laplacian_method=laplacian),
213222
in_axes=(
214223
None,
215224
networks.FermiNetData(

ferminet/tests/train_test.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,23 @@ def _config_params():
5555
yield {'system': system,
5656
'optimizer': optimizer,
5757
'complex_': complex_,
58-
'states': states}
58+
'states': states,
59+
'laplacian': 'default'}
5960
for optimizer in ('kfac', 'adam', 'lamb', 'none'):
6061
yield {
6162
'system': 'H' if optimizer in ('kfac', 'adam') else 'Li',
6263
'optimizer': optimizer,
6364
'complex_': False,
6465
'states': 0,
66+
'laplacian': 'default',
67+
}
68+
for states, laplacian in itertools.product((0, 2), ('default', 'folx')):
69+
yield {
70+
'system': 'Li',
71+
'optimizer': 'kfac',
72+
'complex_': False,
73+
'states': states,
74+
'laplacian': laplacian
6575
}
6676

6777

@@ -75,7 +85,7 @@ def setUp(self):
7585
pyscf.lib.param.TMPDIR = None
7686

7787
@parameterized.parameters(_config_params())
78-
def test_training_step(self, system, optimizer, complex_, states):
88+
def test_training_step(self, system, optimizer, complex_, states, laplacian):
7989
if system in ('H', 'Li'):
8090
cfg = atom.get_config()
8191
cfg.system.atom = system
@@ -90,6 +100,7 @@ def test_training_step(self, system, optimizer, complex_, states):
90100
cfg.pretrain.iterations = 10
91101
cfg.mcmc.burn_in = 10
92102
cfg.optim.optimizer = optimizer
103+
cfg.optim.laplacian = laplacian
93104
cfg.optim.iterations = 3
94105
cfg.debug.check_nan = True
95106
cfg.observables.s2 = True

ferminet/train.py

+5
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,11 @@ def log_network(*args, **kwargs):
672672
blocks=cfg.mcmc.blocks * num_states,
673673
)
674674
# Construct loss and optimizer
675+
laplacian_method = cfg.optim.get('laplacian', 'default')
675676
if cfg.system.make_local_energy_fn:
677+
if laplacian_method != 'default':
678+
raise NotImplementedError(f'Laplacian method {laplacian_method}'
679+
'not yet supported by custom local energy fns.')
676680
local_energy_module, local_energy_fn = (
677681
cfg.system.make_local_energy_fn.rsplit('.', maxsplit=1))
678682
local_energy_module = importlib.import_module(local_energy_module)
@@ -692,6 +696,7 @@ def log_network(*args, **kwargs):
692696
nspins=nspins,
693697
use_scan=False,
694698
complex_output=cfg.network.get('complex', False),
699+
laplacian_method=laplacian_method,
695700
states=cfg.system.get('states', 0),
696701
pp_type=cfg.system.get('pp', {'type': 'ccecp'}).get('type'),
697702
pp_symbols=pp_symbols if cfg.system.get('use_pp') else None)

setup.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
'attrs',
2525
'chex',
2626
'h5py',
27+
'folx @ git+https://github.com/microsoft/folx',
2728
'jax',
2829
'jaxlib',
2930
# TODO(b/230487443) - use released version of kfac.
@@ -49,7 +50,8 @@ def ferminet_test_suite():
4950
setup(
5051
name='ferminet',
5152
version='0.2',
52-
description='A library to train networks to represent ground state wavefunctions of fermionic systems',
53+
description=('A library to train networks to represent ground '
54+
'state wavefunctions of fermionic systems'),
5355
url='https://github.com/deepmind/ferminet',
5456
author='DeepMind',
5557
author_email='[email protected]',

0 commit comments

Comments
 (0)