Skip to content

Commit

Permalink
Add experimental Apple Silicon support (#35)
Browse files Browse the repository at this point in the history
Adds experimental Apple Silicon support.

---------

Co-authored-by: Daniel Zuegner <[email protected]>
  • Loading branch information
danielzuegner and Daniel Zuegner authored Jan 22, 2025
1 parent d3be408 commit 826d109
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 65 deletions.
40 changes: 35 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

MatterGen is a generative model for inorganic materials design across the periodic table that can be fine-tuned to steer the generation towards a wide range of property constraints.


## Table of Contents
- [Installation](#installation)
- [Get started with a pre-trained model](#get-started-with-a-pre-trained-model)
Expand All @@ -28,9 +29,10 @@ MatterGen is a generative model for inorganic materials design across the period

## Installation


The easiest way to install prerequisites is via [uv](https://docs.astral.sh/uv/), a fast Python package and project manager.

The MatterGen environment can be installed via the following command:
The MatterGen environment can be installed via the following command (assumes you are running Linux and have a CUDA GPU):
```bash
pip install uv
uv venv .venv --python 3.10
Expand All @@ -44,6 +46,24 @@ git lfs --version
```
If this prints some version like `git-lfs/3.0.2 (GitHub; linux amd64; go 1.18.1)`, you can skip the following step.

### Apple Silicon
> [!WARNING]
> Running MatterGen on Apple Silicon is **experimental**. Use at your own risk.
> Further, you need to run `export PYTORCH_ENABLE_MPS_FALLBACK=1` before any training or generation run.
To install the environment for Apple Silicon, run these commands:
```bash
cp pyproject.toml pyproject.linux.toml
mv pyproject_apple_silicon.toml pyproject.toml
pip install uv
uv venv .venv --python 3.10
source .venv/bin/activate
uv pip install -e .
export PYTORCH_ENABLE_MPS_FALLBACK=1 # required to run MatterGen on Apple Silicon
```



### Install Git LFS
If Git LFS was not installed before you cloned this repo, you can install it and download the missing files via:
```bash
Expand Down Expand Up @@ -156,6 +176,8 @@ You can train the MatterGen base model on `mp_20` using the following command.
```bash
python scripts/run.py data_module=mp_20 ~trainer.logger
```
> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.
The validation loss (`loss_val`) should reach 0.4 after 360 epochs (about 80k steps). The output checkpoints can be found at `outputs/singlerun/${now:%Y-%m-%d}/${now:%H-%M-%S}`. We call this folder `$MODEL_PATH` for future reference.
> [!NOTE]
Expand All @@ -168,6 +190,9 @@ To train the MatterGen base model on `alex_mp_20`, use the following command:
```bash
python scripts/run.py data_module=alex_mp_20 ~trainer.logger trainer.accumulate_grad_batches=4
```
> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.
> [!TIP]
> Note that a single GPU's memory usually is not enough for the batch size of 512, hence we accumulate gradients over 4 batches. If you still run out of memory, increase this further.
Expand All @@ -186,8 +211,10 @@ export PROPERTY=dft_mag_density
export MODEL_PATH=checkpoints/mattergen_base
python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/[email protected]_embeddings_adapt.$PROPERTY=$PROPERTY ~trainer.logger data_module.properties=["$PROPERTY"]
```

`dft_mag_density` denotes the target property for fine-tuning.
> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.

> [!TIP]
> You can select any property that is available in the dataset. See [`mattergen/conf/data_module/mp_20.yaml`](mattergen/conf/data_module/mp_20.yaml) or [`mattergen/conf/data_module/alex_mp_20.yaml`](mattergen/conf/data_module/alex_mp_20.yaml) for the list of supported properties. You can also add your own custom property data. See [below](#fine-tune-on-your-own-property-data) for instructions.
Expand All @@ -199,12 +226,15 @@ You can also fine-tune MatterGen on multiple properties. For instance, to fine-t
export PROPERTY1=dft_mag_density
export PROPERTY2=dft_band_gap
export MODEL_PATH=checkpoints/mattergen_base
python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/[email protected]_embeddings_adapt.$PROPERTY1=$PROPERTY1 +lightning_module/diffusion_module/model/[email protected]_embeddings_adapt.$PROPERTY2=$PROPERTY2 ~trainer.logger data_module.properties=["$PROPERTY1", "$PROPERTY2"]
python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/[email protected]_embeddings_adapt.$PROPERTY1=$PROPERTY1 +lightning_module/diffusion_module/model/[email protected]_embeddings_adapt.$PROPERTY2=$PROPERTY2 ~trainer.logger data_module.properties=["$PROPERTY1","$PROPERTY2"]
```
> [!TIP]
> Add more properties analogously by adding these overrides:
> 1. `+lightning_module/diffusion_module/model/[email protected]_embeddings_adapt.<my_property>=<my_property>`
> 2. Add `<my_property>` to the `data_module.properties=["$PROPERTY1", "$PROPERTY2", ..., <my_property>]` override.
> 2. Add `<my_property>` to the `data_module.properties=["$PROPERTY1","$PROPERTY2",...,<my_property>]` override.
> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.
#### Fine-tune on your own property data
You may also fine-tune MatterGen on your own property data. Essentially what you need is a property value (typically `float`) for a subset of the data you want to train on (e.g., `alex_mp_20`). Proceed as follows:
Expand Down Expand Up @@ -261,4 +291,4 @@ If you have any questions not covered here, please create an issue or contact th
Materials Design team at [[email protected]](mailto:[email protected]).

We would appreciate your feedback and would like to know how MatterGen has been beneficial to your research.
Please share your experiences with us at [[email protected]](mailto:[email protected]).
Please share your experiences with us at [[email protected]](mailto:[email protected]).
20 changes: 11 additions & 9 deletions mattergen/common/gemnet/gemnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
lattice_params_to_matrix_torch,
radius_graph_pbc,
)
from mattergen.common.utils.globals import MODELS_PROJECT_ROOT
from mattergen.common.utils.globals import MODELS_PROJECT_ROOT, get_device, get_pyg_device
from mattergen.common.utils.lattice_score import edge_score_to_lattice_score_frac_symmetric


Expand Down Expand Up @@ -381,17 +381,19 @@ def get_triplets(

value = torch.arange(idx_s.size(0), device=idx_s.device, dtype=idx_s.dtype)
# Possibly contains multiple copies of the same edge (for periodic interactions)
pyg_device = get_pyg_device()
torch_device = get_device()
adj = SparseTensor(
row=idx_t,
col=idx_s,
value=value,
sparse_sizes=(num_atoms, num_atoms),
row=idx_t.to(pyg_device),
col=idx_s.to(pyg_device),
value=value.to(pyg_device),
sparse_sizes=(num_atoms.to(pyg_device), num_atoms.to(pyg_device)),
)
adj_edges = adj[idx_t]
adj_edges = adj[idx_t.to(pyg_device)].to(torch_device)

# Edge indices (b->a, c->a) for triplets.
id3_ba = adj_edges.storage.value()
id3_ca = adj_edges.storage.row()
id3_ba = adj_edges.storage.value().to(torch_device)
id3_ca = adj_edges.storage.row().to(torch_device)

# Remove self-loop triplets
# Compare edge indices, not atom indices to correctly handle periodic interactions
Expand Down Expand Up @@ -773,4 +775,4 @@ def forward(

@property
def num_params(self):
return sum(p.numel() for p in self.parameters())
return sum(p.numel() for p in self.parameters())
14 changes: 1 addition & 13 deletions mattergen/common/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,19 +361,7 @@ def compute_lattice_polar_decomposition(lattice_matrix: torch.Tensor) -> torch.T
# lattice_matrix: [batch_size, 3, 3]
# Computes the (unique) symmetric lattice matrix that is equivalent (up to rotation) to the input lattice.

if lattice_matrix.device.type == "cuda":
# there is an issue running torch.linalg.svd on cuda tensors with driver version 450.*

try:
W, S, V_transp = torch.linalg.svd(lattice_matrix)
except torch._C._LinAlgError:
# move to cpu and try again
W, S, V_transp = torch.linalg.svd(lattice_matrix.to("cpu"))
W = W.to(lattice_matrix.device.type)
S = S.to(lattice_matrix.device.type)
V_transp = V_transp.to(lattice_matrix.device.type)
else:
W, S, V_transp = torch.linalg.svd(lattice_matrix)
W, S, V_transp = torch.linalg.svd(lattice_matrix)
S_square = torch.diag_embed(S)
V = V_transp.transpose(1, 2)
U = W @ V_transp
Expand Down
3 changes: 2 additions & 1 deletion mattergen/common/utils/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
GENERATED_CRYSTALS_ZIP_FILE_NAME,
)
from mattergen.common.utils.data_classes import MatterGenCheckpointInfo
from mattergen.common.utils.globals import get_device
from mattergen.diffusion.lightning_module import DiffusionLightningModule

# logging
Expand Down Expand Up @@ -54,7 +55,7 @@ def load_model_diffusion(
try:
model, incompatible_keys = DiffusionLightningModule.load_from_checkpoint_and_config(
ckpt,
map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
map_location=get_device(),
config=cfg.lightning_module,
strict=args.strict_checkpoint_loading,
)
Expand Down
24 changes: 24 additions & 0 deletions mattergen/common/utils/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,32 @@
2. It registers a new resolver for OmegaConf, `eval`, which allows us to use `eval` in our config files.
"""
import os
from functools import lru_cache
from pathlib import Path

import torch
from omegaconf import OmegaConf


@lru_cache
def get_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda")
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")


@lru_cache
def get_pyg_device() -> torch.device:
"""
Some operations of pyg don't work on MPS, so fall back to CPU.
"""
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")


MODELS_PROJECT_ROOT = Path(__file__).resolve().parents[2]
print(f"MODELS_PROJECT_ROOT: {MODELS_PROJECT_ROOT}")

Expand Down Expand Up @@ -109,6 +131,8 @@
83,
]
MAX_ATOMIC_NUM = 100


# Set `eval` resolver
def try_eval(s):
"""This is a custom resolver for OmegaConf that allows us to use `eval` in our config files
Expand Down
13 changes: 11 additions & 2 deletions mattergen/common/utils/ocp_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import torch
from torch_scatter import segment_coo, segment_csr

from mattergen.common.utils.globals import get_pyg_device


def get_pbc_distances(
pos: torch.Tensor,
Expand Down Expand Up @@ -272,14 +274,21 @@ def get_max_neighbors_mask(
# Get number of neighbors
# segment_coo assumes sorted index
ones = index.new_ones(1).expand_as(index)
num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
# required because PyG does not support MPS for the segment_coo operation yet.
pyg_device = get_pyg_device()
device_before = ones.device
num_neighbors = segment_coo(ones.to(pyg_device), index.to(pyg_device), dim_size=num_atoms).to(
device_before
)
max_num_neighbors = num_neighbors.max()
num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold)

# Get number of (thresholded) neighbors per image
image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
image_indptr[1:] = torch.cumsum(natoms, dim=0)
num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)
num_neighbors_image = segment_csr(
num_neighbors_thresholded.to(pyg_device), image_indptr.to(pyg_device)
).to(device_before)

# If max_num_neighbors is below the threshold, return early
if max_num_neighbors <= max_num_neighbors_threshold or max_num_neighbors_threshold <= 0:
Expand Down
22 changes: 13 additions & 9 deletions mattergen/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pymatgen.core.structure import Structure

from mattergen.common.utils.globals import get_device
from mattergen.evaluation.metrics.evaluator import MetricsEvaluator
from mattergen.evaluation.reference.reference_dataset import ReferenceDataset
from mattergen.evaluation.utils.relaxation import relax_structures
Expand All @@ -18,11 +19,12 @@ def evaluate(
relax: bool = True,
energies: list[float] | None = None,
reference: ReferenceDataset | None = None,
structure_matcher: OrderedStructureMatcher
| DisorderedStructureMatcher = DefaultDisorderedStructureMatcher(),
structure_matcher: (
OrderedStructureMatcher | DisorderedStructureMatcher
) = DefaultDisorderedStructureMatcher(),
save_as: str | None = None,
potential_load_path: str | None = None,
device: str = "cuda",
device: str = str(get_device()),
) -> dict[str, float | int]:
"""Evaluate the structures against a reference dataset.
Expand All @@ -41,18 +43,20 @@ def evaluate(
if relax and energies is not None:
raise ValueError("Cannot accept energies if relax is True.")
if relax:
relaxed_structures, energies = relax_structures(structures, device=device, load_path=potential_load_path)
relaxed_structures, energies = relax_structures(
structures, device=device, load_path=potential_load_path
)
else:
relaxed_structures = structures
evaluator = MetricsEvaluator.from_structures_and_energies(
structures=relaxed_structures,
structures=relaxed_structures,
energies=energies,
original_structures=structures,
reference=reference,
structure_matcher=structure_matcher
structure_matcher=structure_matcher,
)
return evaluator.compute_metrics(
metrics = evaluator.available_metrics,
save_as = save_as,
pretty_print = True,
metrics=evaluator.available_metrics,
save_as=save_as,
pretty_print=True,
)
20 changes: 16 additions & 4 deletions mattergen/evaluation/utils/relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,31 @@
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor

from mattergen.common.utils.globals import get_device

logger = get_logger()
logger.level("ERROR")

def relax_atoms(atoms: list[Atoms], device: str = "cuda", load_path: str = None, **kwargs) -> tuple[list[Atoms], np.ndarray]:
potential = Potential.from_checkpoint(device=device, load_path=load_path, load_training_state=False)

def relax_atoms(
atoms: list[Atoms], device: str = str(get_device()), load_path: str = None, **kwargs
) -> tuple[list[Atoms], np.ndarray]:
potential = Potential.from_checkpoint(
device=device, load_path=load_path, load_training_state=False
)
batch_relaxer = BatchRelaxer(potential=potential, filter="EXPCELLFILTER", **kwargs)
relaxation_trajectories = batch_relaxer.relax(atoms)
relaxed_atoms = [t[-1] for t in relaxation_trajectories.values()]
total_energies = np.array([a.info['total_energy'] for a in relaxed_atoms])
total_energies = np.array([a.info["total_energy"] for a in relaxed_atoms])
return relaxed_atoms, total_energies


def relax_structures(structures: Structure | list[Structure], device: str = "cuda", load_path: str = None, **kwargs) -> tuple[list[Structure], np.ndarray]:
def relax_structures(
structures: Structure | list[Structure],
device: str = str(get_device()),
load_path: str = None,
**kwargs
) -> tuple[list[Structure], np.ndarray]:
if isinstance(structures, Structure):
structures = [structures]
atoms = [AseAtomsAdaptor.get_atoms(s) for s in structures]
Expand Down
4 changes: 2 additions & 2 deletions mattergen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
make_structure,
save_structures,
)
from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH
from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH, get_device
from mattergen.diffusion.lightning_module import DiffusionLightningModule
from mattergen.diffusion.sampling.pc_sampler import PredictorCorrector

Expand Down Expand Up @@ -332,7 +332,7 @@ def prepare(self) -> None:
if self._model is not None:
return
model = load_model_diffusion(self.checkpoint_info)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(get_device())
self._model = model
self._cfg = self.checkpoint_info.config

Expand Down
24 changes: 7 additions & 17 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ dependencies = [
"setuptools",
"SMACT",
"sympy>=1.11.1",
"torch==2.2.1+cu118",
"torchvision==0.17.1+cu118",
"torchaudio==2.2.1+cu118",
"torch==2.4.1",
"torchvision==0.19.1",
"torchaudio==2.4.1",
"torch_cluster",
"torch_geometric>=2.5",
"torch_scatter",
Expand All @@ -67,17 +67,7 @@ dependencies = [
include = ["mattergen*"]

[tool.uv.sources]
torch = { index = "pytorch" }
torchvision = { index = "pytorch" }
torchaudio = { index = "pytorch" }
pyg-lib = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/pyg_lib-0.4.0%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }
torch_cluster = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_cluster-1.6.3%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }
torch_scatter = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_scatter-2.1.2%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }
torch_sparse = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_sparse-0.6.18%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }

[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu118"
explicit = true


pyg-lib = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/pyg_lib-0.4.0%2Bpt24-cp310-cp310-macosx_14_0_universal2.whl" }
torch_cluster = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_cluster-1.6.3-cp310-cp310-macosx_10_9_universal2.whl" }
torch_sparse = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_sparse-0.6.18-cp310-cp310-macosx_11_0_universal2.whl" }
torch_scatter = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_scatter-2.1.2-cp310-cp310-macosx_10_9_universal2.whl" }
Loading

0 comments on commit 826d109

Please sign in to comment.