Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2334205

Browse files
author
Daniel Zuegner
committedJan 22, 2025·
experimental apple silicon
1 parent d3be408 commit 2334205

File tree

11 files changed

+79
-64
lines changed

11 files changed

+79
-64
lines changed
 

‎README.md

+13-4
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
MatterGen is a generative model for inorganic materials design across the periodic table that can be fine-tuned to steer the generation towards a wide range of property constraints.
1616

17+
> [!WARNING]
18+
> This branch adds **experimental** support to run MatterGen on Apple Silicon. Use at your own risk.
19+
1720
## Table of Contents
1821
- [Installation](#installation)
1922
- [Get started with a pre-trained model](#get-started-with-a-pre-trained-model)
@@ -74,6 +77,7 @@ To sample from the pre-trained base model, run the following command.
7477
```bash
7578
export MODEL_PATH=checkpoints/mattergen_base # Or provide your own model
7679
export RESULTS_PATH=results/ # Samples will be written to this directory
80+
export PYTORCH_ENABLE_MPS_FALLBACK=1
7781

7882
# generate batch_size * num_batches samples
7983
python scripts/generate.py $RESULTS_PATH $MODEL_PATH --batch_size=16 --num_batches 1
@@ -89,6 +93,7 @@ With a fine-tuned model, you can generate materials conditioned on a target prop
8993
For example, to sample from the model trained on magnetic density, you can run the following command.
9094
```bash
9195
export MODEL_NAME=dft_mag_density
96+
export PYTORCH_ENABLE_MPS_FALLBACK=1
9297
export MODEL_PATH="checkpoints/$MODEL_NAME" # Or provide your own model
9398
export RESULTS_PATH="results/$MODEL_NAME/" # Samples will be written to this directory, e.g., `results/dft_mag_density`
9499

@@ -102,6 +107,7 @@ python scripts/generate.py $RESULTS_PATH $MODEL_PATH --batch_size=16 --checkpoin
102107
You can also generate materials conditioned on more than one property. For instance, you can use the pre-trained model located at `checkpoints/chemical_system_energy_above_hull` to generate conditioned on chemical system and energy above the hull, or the model at `checkpoints/dft_mag_density_hhi_score` for joint conditioning on [HHI score](https://en.wikipedia.org/wiki/Herfindahl%E2%80%93Hirschman_index) and magnetic density.
103108
Adapt the following command to your specific needs:
104109
```bash
110+
export PYTORCH_ENABLE_MPS_FALLBACK=1
105111
export MODEL_NAME=chemical_system_energy_above_hull
106112
export MODEL_PATH="checkpoints/$MODEL_NAME" # Or provide your own model
107113
export RESULTS_PATH="results/$MODEL_NAME/" # Samples will be written to this directory, e.g., `results/dft_mag_density`
@@ -154,7 +160,8 @@ This will take some time (~1h). You will get preprocessed data files in `dataset
154160
You can train the MatterGen base model on `mp_20` using the following command.
155161

156162
```bash
157-
python scripts/run.py data_module=mp_20 ~trainer.logger
163+
export PYTORCH_ENABLE_MPS_FALLBACK=1
164+
python scripts/run.py data_module=mp_20 ~trainer.logger ~trainer.strategy trainer.accelerator=mps
158165
```
159166

160167
The validation loss (`loss_val`) should reach 0.4 after 360 epochs (about 80k steps). The output checkpoints can be found at `outputs/singlerun/${now:%Y-%m-%d}/${now:%H-%M-%S}`. We call this folder `$MODEL_PATH` for future reference.
@@ -166,7 +173,8 @@ The validation loss (`loss_val`) should reach 0.4 after 360 epochs (about 80k st
166173
167174
To train the MatterGen base model on `alex_mp_20`, use the following command:
168175
```bash
169-
python scripts/run.py data_module=alex_mp_20 ~trainer.logger trainer.accumulate_grad_batches=4
176+
export PYTORCH_ENABLE_MPS_FALLBACK=1
177+
python scripts/run.py data_module=alex_mp_20 ~trainer.logger trainer.accumulate_grad_batches=4 ~trainer.strategy trainer.accelerator=mps
170178
```
171179
> [!TIP]
172180
> Note that a single GPU's memory usually is not enough for the batch size of 512, hence we accumulate gradients over 4 batches. If you still run out of memory, increase this further.
@@ -184,7 +192,8 @@ Assume that you have a MatterGen base model at `$MODEL_PATH` (e.g., `checkpoints
184192
```bash
185193
export PROPERTY=dft_mag_density
186194
export MODEL_PATH=checkpoints/mattergen_base
187-
python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY=$PROPERTY ~trainer.logger data_module.properties=["$PROPERTY"]
195+
export PYTORCH_ENABLE_MPS_FALLBACK=1
196+
python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY=$PROPERTY ~trainer.logger data_module.properties=["$PROPERTY"] ~trainer.strategy trainer.accelerator=mps
188197
```
189198

190199
`dft_mag_density` denotes the target property for fine-tuning.
@@ -261,4 +270,4 @@ If you have any questions not covered here, please create an issue or contact th
261270
Materials Design team at [ai4s-materials@microsoft.com](mailto:ai4s-materials@microsoft.com).
262271

263272
We would appreciate your feedback and would like to know how MatterGen has been beneficial to your research.
264-
Please share your experiences with us at [ai4s-materials@microsoft.com](mailto:ai4s-materials@microsoft.com).
273+
Please share your experiences with us at [ai4s-materials@microsoft.com](mailto:ai4s-materials@microsoft.com).

‎mattergen/common/gemnet/gemnet.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
lattice_params_to_matrix_torch,
3535
radius_graph_pbc,
3636
)
37-
from mattergen.common.utils.globals import MODELS_PROJECT_ROOT
37+
from mattergen.common.utils.globals import MODELS_PROJECT_ROOT, get_device, get_pyg_device
3838
from mattergen.common.utils.lattice_score import edge_score_to_lattice_score_frac_symmetric
3939

4040

@@ -381,17 +381,19 @@ def get_triplets(
381381

382382
value = torch.arange(idx_s.size(0), device=idx_s.device, dtype=idx_s.dtype)
383383
# Possibly contains multiple copies of the same edge (for periodic interactions)
384+
pyg_device = get_pyg_device()
385+
torch_device = get_device()
384386
adj = SparseTensor(
385-
row=idx_t,
386-
col=idx_s,
387-
value=value,
388-
sparse_sizes=(num_atoms, num_atoms),
387+
row=idx_t.to(pyg_device),
388+
col=idx_s.to(pyg_device),
389+
value=value.to(pyg_device),
390+
sparse_sizes=(num_atoms.to(pyg_device), num_atoms.to(pyg_device)),
389391
)
390-
adj_edges = adj[idx_t]
392+
adj_edges = adj[idx_t.to(pyg_device)].to(torch_device)
391393

392394
# Edge indices (b->a, c->a) for triplets.
393-
id3_ba = adj_edges.storage.value()
394-
id3_ca = adj_edges.storage.row()
395+
id3_ba = adj_edges.storage.value().to(torch_device)
396+
id3_ca = adj_edges.storage.row().to(torch_device)
395397

396398
# Remove self-loop triplets
397399
# Compare edge indices, not atom indices to correctly handle periodic interactions
@@ -773,4 +775,4 @@ def forward(
773775

774776
@property
775777
def num_params(self):
776-
return sum(p.numel() for p in self.parameters())
778+
return sum(p.numel() for p in self.parameters())

‎mattergen/common/utils/data_utils.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -361,19 +361,7 @@ def compute_lattice_polar_decomposition(lattice_matrix: torch.Tensor) -> torch.T
361361
# lattice_matrix: [batch_size, 3, 3]
362362
# Computes the (unique) symmetric lattice matrix that is equivalent (up to rotation) to the input lattice.
363363

364-
if lattice_matrix.device.type == "cuda":
365-
# there is an issue running torch.linalg.svd on cuda tensors with driver version 450.*
366-
367-
try:
368-
W, S, V_transp = torch.linalg.svd(lattice_matrix)
369-
except torch._C._LinAlgError:
370-
# move to cpu and try again
371-
W, S, V_transp = torch.linalg.svd(lattice_matrix.to("cpu"))
372-
W = W.to(lattice_matrix.device.type)
373-
S = S.to(lattice_matrix.device.type)
374-
V_transp = V_transp.to(lattice_matrix.device.type)
375-
else:
376-
W, S, V_transp = torch.linalg.svd(lattice_matrix)
364+
W, S, V_transp = torch.linalg.svd(lattice_matrix)
377365
S_square = torch.diag_embed(S)
378366
V = V_transp.transpose(1, 2)
379367
U = W @ V_transp

‎mattergen/common/utils/eval_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
GENERATED_CRYSTALS_ZIP_FILE_NAME,
2121
)
2222
from mattergen.common.utils.data_classes import MatterGenCheckpointInfo
23+
from mattergen.common.utils.globals import get_device
2324
from mattergen.diffusion.lightning_module import DiffusionLightningModule
2425

2526
# logging
@@ -54,7 +55,7 @@ def load_model_diffusion(
5455
try:
5556
model, incompatible_keys = DiffusionLightningModule.load_from_checkpoint_and_config(
5657
ckpt,
57-
map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
58+
map_location=get_device(),
5859
config=cfg.lightning_module,
5960
strict=args.strict_checkpoint_loading,
6061
)

‎mattergen/common/utils/ocp_graph_utils.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import torch
1313
from torch_scatter import segment_coo, segment_csr
1414

15+
from mattergen.common.utils.globals import get_pyg_device
16+
1517

1618
def get_pbc_distances(
1719
pos: torch.Tensor,
@@ -272,14 +274,20 @@ def get_max_neighbors_mask(
272274
# Get number of neighbors
273275
# segment_coo assumes sorted index
274276
ones = index.new_ones(1).expand_as(index)
275-
num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
277+
pyg_device = get_pyg_device()
278+
device_before = ones.device
279+
num_neighbors = segment_coo(ones.to(pyg_device), index.to(pyg_device), dim_size=num_atoms).to(
280+
device_before
281+
)
276282
max_num_neighbors = num_neighbors.max()
277283
num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold)
278284

279285
# Get number of (thresholded) neighbors per image
280286
image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
281287
image_indptr[1:] = torch.cumsum(natoms, dim=0)
282-
num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)
288+
num_neighbors_image = segment_csr(
289+
num_neighbors_thresholded.to(pyg_device), image_indptr.to(pyg_device)
290+
).to(device_before)
283291

284292
# If max_num_neighbors is below the threshold, return early
285293
if max_num_neighbors <= max_num_neighbors_threshold or max_num_neighbors_threshold <= 0:

‎mattergen/evaluation/evaluate.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from pymatgen.core.structure import Structure
55

6+
from mattergen.common.utils.globals import get_device
67
from mattergen.evaluation.metrics.evaluator import MetricsEvaluator
78
from mattergen.evaluation.reference.reference_dataset import ReferenceDataset
89
from mattergen.evaluation.utils.relaxation import relax_structures
@@ -18,11 +19,12 @@ def evaluate(
1819
relax: bool = True,
1920
energies: list[float] | None = None,
2021
reference: ReferenceDataset | None = None,
21-
structure_matcher: OrderedStructureMatcher
22-
| DisorderedStructureMatcher = DefaultDisorderedStructureMatcher(),
22+
structure_matcher: (
23+
OrderedStructureMatcher | DisorderedStructureMatcher
24+
) = DefaultDisorderedStructureMatcher(),
2325
save_as: str | None = None,
2426
potential_load_path: str | None = None,
25-
device: str = "cuda",
27+
device: str = str(get_device()),
2628
) -> dict[str, float | int]:
2729
"""Evaluate the structures against a reference dataset.
2830
@@ -41,18 +43,20 @@ def evaluate(
4143
if relax and energies is not None:
4244
raise ValueError("Cannot accept energies if relax is True.")
4345
if relax:
44-
relaxed_structures, energies = relax_structures(structures, device=device, load_path=potential_load_path)
46+
relaxed_structures, energies = relax_structures(
47+
structures, device=device, load_path=potential_load_path
48+
)
4549
else:
4650
relaxed_structures = structures
4751
evaluator = MetricsEvaluator.from_structures_and_energies(
48-
structures=relaxed_structures,
52+
structures=relaxed_structures,
4953
energies=energies,
5054
original_structures=structures,
5155
reference=reference,
52-
structure_matcher=structure_matcher
56+
structure_matcher=structure_matcher,
5357
)
5458
return evaluator.compute_metrics(
55-
metrics = evaluator.available_metrics,
56-
save_as = save_as,
57-
pretty_print = True,
59+
metrics=evaluator.available_metrics,
60+
save_as=save_as,
61+
pretty_print=True,
5862
)

‎mattergen/evaluation/utils/relaxation.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,31 @@
99
from pymatgen.core import Structure
1010
from pymatgen.io.ase import AseAtomsAdaptor
1111

12+
from mattergen.common.utils.globals import get_device
13+
1214
logger = get_logger()
1315
logger.level("ERROR")
1416

15-
def relax_atoms(atoms: list[Atoms], device: str = "cuda", load_path: str = None, **kwargs) -> tuple[list[Atoms], np.ndarray]:
16-
potential = Potential.from_checkpoint(device=device, load_path=load_path, load_training_state=False)
17+
18+
def relax_atoms(
19+
atoms: list[Atoms], device: str = str(get_device()), load_path: str = None, **kwargs
20+
) -> tuple[list[Atoms], np.ndarray]:
21+
potential = Potential.from_checkpoint(
22+
device=device, load_path=load_path, load_training_state=False
23+
)
1724
batch_relaxer = BatchRelaxer(potential=potential, filter="EXPCELLFILTER", **kwargs)
1825
relaxation_trajectories = batch_relaxer.relax(atoms)
1926
relaxed_atoms = [t[-1] for t in relaxation_trajectories.values()]
20-
total_energies = np.array([a.info['total_energy'] for a in relaxed_atoms])
27+
total_energies = np.array([a.info["total_energy"] for a in relaxed_atoms])
2128
return relaxed_atoms, total_energies
2229

2330

24-
def relax_structures(structures: Structure | list[Structure], device: str = "cuda", load_path: str = None, **kwargs) -> tuple[list[Structure], np.ndarray]:
31+
def relax_structures(
32+
structures: Structure | list[Structure],
33+
device: str = str(get_device()),
34+
load_path: str = None,
35+
**kwargs
36+
) -> tuple[list[Structure], np.ndarray]:
2537
if isinstance(structures, Structure):
2638
structures = [structures]
2739
atoms = [AseAtomsAdaptor.get_atoms(s) for s in structures]

‎mattergen/generator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
make_structure,
3030
save_structures,
3131
)
32-
from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH
32+
from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH, get_device
3333
from mattergen.diffusion.lightning_module import DiffusionLightningModule
3434
from mattergen.diffusion.sampling.pc_sampler import PredictorCorrector
3535

@@ -332,7 +332,7 @@ def prepare(self) -> None:
332332
if self._model is not None:
333333
return
334334
model = load_model_diffusion(self.checkpoint_info)
335-
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
335+
model = model.to(get_device())
336336
self._model = model
337337
self._cfg = self.checkpoint_info.config
338338

‎pyproject.toml

+7-17
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ dependencies = [
5252
"setuptools",
5353
"SMACT",
5454
"sympy>=1.11.1",
55-
"torch==2.2.1+cu118",
56-
"torchvision==0.17.1+cu118",
57-
"torchaudio==2.2.1+cu118",
55+
"torch==2.4.1",
56+
"torchvision==0.19.1",
57+
"torchaudio==2.4.1",
5858
"torch_cluster",
5959
"torch_geometric>=2.5",
6060
"torch_scatter",
@@ -67,17 +67,7 @@ dependencies = [
6767
include = ["mattergen*"]
6868

6969
[tool.uv.sources]
70-
torch = { index = "pytorch" }
71-
torchvision = { index = "pytorch" }
72-
torchaudio = { index = "pytorch" }
73-
pyg-lib = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/pyg_lib-0.4.0%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }
74-
torch_cluster = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_cluster-1.6.3%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }
75-
torch_scatter = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_scatter-2.1.2%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }
76-
torch_sparse = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_sparse-0.6.18%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }
77-
78-
[[tool.uv.index]]
79-
name = "pytorch"
80-
url = "https://download.pytorch.org/whl/cu118"
81-
explicit = true
82-
83-
70+
pyg-lib = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/pyg_lib-0.4.0%2Bpt24-cp310-cp310-macosx_14_0_universal2.whl" }
71+
torch_cluster = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_cluster-1.6.3-cp310-cp310-macosx_10_9_universal2.whl" }
72+
torch_sparse = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_sparse-0.6.18-cp310-cp310-macosx_11_0_universal2.whl" }
73+
torch_scatter = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_scatter-2.1.2-cp310-cp310-macosx_10_9_universal2.whl" }

‎scripts/evaluate.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010

1111
from mattergen.common.utils.eval_utils import load_structures
12+
from mattergen.common.utils.globals import get_device
1213
from mattergen.evaluation.evaluate import evaluate
1314
from mattergen.evaluation.utils.structure_matcher import (
1415
DefaultDisorderedStructureMatcher,
@@ -25,7 +26,7 @@ def main(
2526
potential_load_path: (
2627
Literal["MatterSim-v1.0.0-1M.pth", "MatterSim-v1.0.0-5M.pth"] | None
2728
) = None,
28-
device: str = "cuda",
29+
device: str = str(get_device()),
2930
):
3031
structures = load_structures(Path(structures_path))
3132
energies = np.load(energies_path) if energies_path else None

‎scripts/finetune.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from pytorch_lightning.cli import SaveConfigCallback
1717

1818
from mattergen.common.utils.data_classes import MatterGenCheckpointInfo
19-
from mattergen.common.utils.globals import MODELS_PROJECT_ROOT
19+
from mattergen.common.utils.globals import MODELS_PROJECT_ROOT, get_device
2020
from mattergen.diffusion.run import AddConfigCallback, SimpleParser, maybe_instantiate
2121

2222
logger = logging.getLogger(__name__)
@@ -81,7 +81,7 @@ def init_adapter_lightningmodule_from_pretrained(
8181

8282
lightning_module = hydra.utils.instantiate(lightning_module_cfg)
8383

84-
ckpt: dict = torch.load(ckpt_path)
84+
ckpt: dict = torch.load(ckpt_path, map_location=get_device())
8585
pretrained_dict: OrderedDict = ckpt["state_dict"]
8686
scratch_dict: OrderedDict = lightning_module.state_dict()
8787
scratch_dict.update(

0 commit comments

Comments
 (0)
Please sign in to comment.