Add experimental Apple Silicon support #35

Merged · 7 commits · Jan 22, 2025
40 changes: 35 additions & 5 deletions README.md
@@ -14,6 +14,7 @@

MatterGen is a generative model for inorganic materials design across the periodic table that can be fine-tuned to steer the generation towards a wide range of property constraints.


## Table of Contents
- [Installation](#installation)
- [Get started with a pre-trained model](#get-started-with-a-pre-trained-model)
@@ -28,9 +29,10 @@ MatterGen is a generative model for inorganic materials design across the period

## Installation


The easiest way to install prerequisites is via [uv](https://docs.astral.sh/uv/), a fast Python package and project manager.

The MatterGen environment can be installed via the following command:
The MatterGen environment can be installed via the following command (assumes you are running Linux and have a CUDA GPU):
```bash
pip install uv
uv venv .venv --python 3.10
@@ -44,6 +46,24 @@ git lfs --version
```
If this prints a version string such as `git-lfs/3.0.2 (GitHub; linux amd64; go 1.18.1)`, you can skip the following step.

### Apple Silicon
> [!WARNING]
> Running MatterGen on Apple Silicon is **experimental**. Use at your own risk.
> Further, you need to run `export PYTORCH_ENABLE_MPS_FALLBACK=1` before any training or generation run.

To install the environment for Apple Silicon, run these commands:
```bash
cp pyproject.toml pyproject.linux.toml
mv pyproject_apple_silicon.toml pyproject.toml
pip install uv
uv venv .venv --python 3.10
source .venv/bin/activate
uv pip install -e .
export PYTORCH_ENABLE_MPS_FALLBACK=1 # required to run MatterGen on Apple Silicon
```
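After installation, you can optionally confirm that PyTorch sees the MPS backend (a quick sanity check, not part of the original instructions):
```bash
python -c "import torch; print(torch.backends.mps.is_available())"  # should print True on Apple Silicon
```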



### Install Git LFS
If Git LFS was not installed before you cloned this repo, you can install it and download the missing files via:
```bash
@@ -156,6 +176,8 @@ You can train the MatterGen base model on `mp_20` using the following command.
```bash
python scripts/run.py data_module=mp_20 ~trainer.logger
```
> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.
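
For example, on Apple Silicon the full `mp_20` training command would then look roughly like this (a sketch that simply combines the base command with the overrides above and the MPS fallback variable from the warning):
```bash
export PYTORCH_ENABLE_MPS_FALLBACK=1  # required before any run on Apple Silicon
python scripts/run.py data_module=mp_20 ~trainer.logger ~trainer.strategy trainer.accelerator=mps
```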

The validation loss (`loss_val`) should reach 0.4 after 360 epochs (about 80k steps). The output checkpoints can be found at `outputs/singlerun/${now:%Y-%m-%d}/${now:%H-%M-%S}`. We call this folder `$MODEL_PATH` for future reference.
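
For example, you might then set (the timestamp below is a hypothetical run folder; substitute your own):
```bash
export MODEL_PATH=outputs/singlerun/2025-01-22/12-34-56  # hypothetical example path
```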
> [!NOTE]
@@ -168,6 +190,9 @@ To train the MatterGen base model on `alex_mp_20`, use the following command:
```bash
python scripts/run.py data_module=alex_mp_20 ~trainer.logger trainer.accumulate_grad_batches=4
```
> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.

> [!TIP]
> Note that a single GPU's memory is usually not enough for a batch size of 512, hence we accumulate gradients over 4 batches. If you still run out of memory, increase this value further.

@@ -186,8 +211,10 @@ export PROPERTY=dft_mag_density
export MODEL_PATH=checkpoints/mattergen_base
python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY=$PROPERTY ~trainer.logger data_module.properties=["$PROPERTY"]
```

`dft_mag_density` denotes the target property for fine-tuning.
> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.


> [!TIP]
> You can select any property that is available in the dataset. See [`mattergen/conf/data_module/mp_20.yaml`](mattergen/conf/data_module/mp_20.yaml) or [`mattergen/conf/data_module/alex_mp_20.yaml`](mattergen/conf/data_module/alex_mp_20.yaml) for the list of supported properties. You can also add your own custom property data. See [below](#fine-tune-on-your-own-property-data) for instructions.
@@ -199,12 +226,15 @@ You can also fine-tune MatterGen on multiple properties. For instance, to fine-t
export PROPERTY1=dft_mag_density
export PROPERTY2=dft_band_gap
export MODEL_PATH=checkpoints/mattergen_base
python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY1=$PROPERTY1 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY2=$PROPERTY2 ~trainer.logger data_module.properties=["$PROPERTY1", "$PROPERTY2"]
python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY1=$PROPERTY1 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY2=$PROPERTY2 ~trainer.logger data_module.properties=["$PROPERTY1","$PROPERTY2"]
```
> [!TIP]
> Add more properties analogously by adding these overrides:
> 1. `+lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.<my_property>=<my_property>`
> 2. Add `<my_property>` to the `data_module.properties=["$PROPERTY1", "$PROPERTY2", ..., <my_property>]` override.
> 2. Add `<my_property>` to the `data_module.properties=["$PROPERTY1","$PROPERTY2",...,<my_property>]` override (see the sketch below).

> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.
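
Following the tip above, a sketch of what a three-property fine-tuning command could look like (the third property is a hypothetical placeholder and the override follows the same pattern as the documented ones):
```bash
export PROPERTY3=<my_property>  # hypothetical: must be a property available in the dataset
python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY1=$PROPERTY1 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY2=$PROPERTY2 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY3=$PROPERTY3 ~trainer.logger data_module.properties=["$PROPERTY1","$PROPERTY2","$PROPERTY3"]
```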

#### Fine-tune on your own property data
You may also fine-tune MatterGen on your own property data. Essentially what you need is a property value (typically `float`) for a subset of the data you want to train on (e.g., `alex_mp_20`). Proceed as follows:
@@ -261,4 +291,4 @@ If you have any questions not covered here, please create an issue or contact th
Materials Design team at [[email protected]](mailto:[email protected]).

We would appreciate your feedback and would like to know how MatterGen has been beneficial to your research.
Please share your experiences with us at [[email protected]](mailto:[email protected]).
20 changes: 11 additions & 9 deletions mattergen/common/gemnet/gemnet.py
@@ -34,7 +34,7 @@
lattice_params_to_matrix_torch,
radius_graph_pbc,
)
from mattergen.common.utils.globals import MODELS_PROJECT_ROOT
from mattergen.common.utils.globals import MODELS_PROJECT_ROOT, get_device, get_pyg_device
from mattergen.common.utils.lattice_score import edge_score_to_lattice_score_frac_symmetric


@@ -381,17 +381,19 @@ def get_triplets(

value = torch.arange(idx_s.size(0), device=idx_s.device, dtype=idx_s.dtype)
# Possibly contains multiple copies of the same edge (for periodic interactions)
pyg_device = get_pyg_device()
torch_device = get_device()
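# torch-sparse's SparseTensor is among the PyG operations without MPS support yet, so the adjacency is built on the PyG-compatible device and the results are moved back to the model device below.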
adj = SparseTensor(
row=idx_t,
col=idx_s,
value=value,
sparse_sizes=(num_atoms, num_atoms),
row=idx_t.to(pyg_device),
col=idx_s.to(pyg_device),
value=value.to(pyg_device),
sparse_sizes=(num_atoms.to(pyg_device), num_atoms.to(pyg_device)),
)
adj_edges = adj[idx_t]
adj_edges = adj[idx_t.to(pyg_device)].to(torch_device)

# Edge indices (b->a, c->a) for triplets.
id3_ba = adj_edges.storage.value()
id3_ca = adj_edges.storage.row()
id3_ba = adj_edges.storage.value().to(torch_device)
id3_ca = adj_edges.storage.row().to(torch_device)

# Remove self-loop triplets
# Compare edge indices, not atom indices to correctly handle periodic interactions
@@ -773,4 +775,4 @@ def forward(

@property
def num_params(self):
return sum(p.numel() for p in self.parameters())
14 changes: 1 addition & 13 deletions mattergen/common/utils/data_utils.py
@@ -361,19 +361,7 @@ def compute_lattice_polar_decomposition(lattice_matrix: torch.Tensor) -> torch.T
# lattice_matrix: [batch_size, 3, 3]
# Computes the (unique) symmetric lattice matrix that is equivalent (up to rotation) to the input lattice.

if lattice_matrix.device.type == "cuda":
# there is an issue running torch.linalg.svd on cuda tensors with driver version 450.*

try:
W, S, V_transp = torch.linalg.svd(lattice_matrix)
except torch._C._LinAlgError:
# move to cpu and try again
W, S, V_transp = torch.linalg.svd(lattice_matrix.to("cpu"))
W = W.to(lattice_matrix.device.type)
S = S.to(lattice_matrix.device.type)
V_transp = V_transp.to(lattice_matrix.device.type)
else:
W, S, V_transp = torch.linalg.svd(lattice_matrix)
W, S, V_transp = torch.linalg.svd(lattice_matrix)
S_square = torch.diag_embed(S)
V = V_transp.transpose(1, 2)
U = W @ V_transp
3 changes: 2 additions & 1 deletion mattergen/common/utils/eval_utils.py
@@ -20,6 +20,7 @@
GENERATED_CRYSTALS_ZIP_FILE_NAME,
)
from mattergen.common.utils.data_classes import MatterGenCheckpointInfo
from mattergen.common.utils.globals import get_device
from mattergen.diffusion.lightning_module import DiffusionLightningModule

# logging
@@ -54,7 +55,7 @@ def load_model_diffusion(
try:
model, incompatible_keys = DiffusionLightningModule.load_from_checkpoint_and_config(
ckpt,
map_location=torch.device("cpu") if not torch.cuda.is_available() else None,
map_location=get_device(),
config=cfg.lightning_module,
strict=args.strict_checkpoint_loading,
)
24 changes: 24 additions & 0 deletions mattergen/common/utils/globals.py
@@ -6,10 +6,32 @@
2. It registers a new resolver for OmegaConf, `eval`, which allows us to use `eval` in our config files.
"""
import os
from functools import lru_cache
from pathlib import Path

import torch
from omegaconf import OmegaConf


@lru_cache
def get_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda")
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")


@lru_cache
def get_pyg_device() -> torch.device:
"""
Some operations of pyg don't work on MPS, so fall back to CPU.
"""
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")


MODELS_PROJECT_ROOT = Path(__file__).resolve().parents[2]
print(f"MODELS_PROJECT_ROOT: {MODELS_PROJECT_ROOT}")

@@ -109,6 +131,8 @@
83,
]
MAX_ATOMIC_NUM = 100


# Set `eval` resolver
def try_eval(s):
"""This is a custom resolver for OmegaConf that allows us to use `eval` in our config files
13 changes: 11 additions & 2 deletions mattergen/common/utils/ocp_graph_utils.py
@@ -12,6 +12,8 @@
import torch
from torch_scatter import segment_coo, segment_csr

from mattergen.common.utils.globals import get_pyg_device


def get_pbc_distances(
pos: torch.Tensor,
@@ -272,14 +274,21 @@ def get_max_neighbors_mask(
# Get number of neighbors
# segment_coo assumes sorted index
ones = index.new_ones(1).expand_as(index)
num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
# required because PyG does not support MPS for the segment_coo operation yet.
pyg_device = get_pyg_device()
device_before = ones.device
num_neighbors = segment_coo(ones.to(pyg_device), index.to(pyg_device), dim_size=num_atoms).to(
device_before
)
max_num_neighbors = num_neighbors.max()
num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold)

# Get number of (thresholded) neighbors per image
image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
image_indptr[1:] = torch.cumsum(natoms, dim=0)
num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)
num_neighbors_image = segment_csr(
num_neighbors_thresholded.to(pyg_device), image_indptr.to(pyg_device)
).to(device_before)

# If max_num_neighbors is below the threshold, return early
if max_num_neighbors <= max_num_neighbors_threshold or max_num_neighbors_threshold <= 0:
22 changes: 13 additions & 9 deletions mattergen/evaluation/evaluate.py
@@ -3,6 +3,7 @@

from pymatgen.core.structure import Structure

from mattergen.common.utils.globals import get_device
from mattergen.evaluation.metrics.evaluator import MetricsEvaluator
from mattergen.evaluation.reference.reference_dataset import ReferenceDataset
from mattergen.evaluation.utils.relaxation import relax_structures
@@ -18,11 +19,12 @@ def evaluate(
relax: bool = True,
energies: list[float] | None = None,
reference: ReferenceDataset | None = None,
structure_matcher: OrderedStructureMatcher
| DisorderedStructureMatcher = DefaultDisorderedStructureMatcher(),
structure_matcher: (
OrderedStructureMatcher | DisorderedStructureMatcher
) = DefaultDisorderedStructureMatcher(),
save_as: str | None = None,
potential_load_path: str | None = None,
device: str = "cuda",
device: str = str(get_device()),
) -> dict[str, float | int]:
"""Evaluate the structures against a reference dataset.

@@ -41,18 +43,20 @@
if relax and energies is not None:
raise ValueError("Cannot accept energies if relax is True.")
if relax:
relaxed_structures, energies = relax_structures(structures, device=device, load_path=potential_load_path)
relaxed_structures, energies = relax_structures(
structures, device=device, load_path=potential_load_path
)
else:
relaxed_structures = structures
evaluator = MetricsEvaluator.from_structures_and_energies(
structures=relaxed_structures,
structures=relaxed_structures,
energies=energies,
original_structures=structures,
reference=reference,
structure_matcher=structure_matcher
structure_matcher=structure_matcher,
)
return evaluator.compute_metrics(
metrics = evaluator.available_metrics,
save_as = save_as,
pretty_print = True,
metrics=evaluator.available_metrics,
save_as=save_as,
pretty_print=True,
)
20 changes: 16 additions & 4 deletions mattergen/evaluation/utils/relaxation.py
@@ -9,19 +9,31 @@
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor

from mattergen.common.utils.globals import get_device

logger = get_logger()
logger.level("ERROR")

def relax_atoms(atoms: list[Atoms], device: str = "cuda", load_path: str = None, **kwargs) -> tuple[list[Atoms], np.ndarray]:
potential = Potential.from_checkpoint(device=device, load_path=load_path, load_training_state=False)

def relax_atoms(
atoms: list[Atoms], device: str = str(get_device()), load_path: str = None, **kwargs
) -> tuple[list[Atoms], np.ndarray]:
potential = Potential.from_checkpoint(
device=device, load_path=load_path, load_training_state=False
)
batch_relaxer = BatchRelaxer(potential=potential, filter="EXPCELLFILTER", **kwargs)
relaxation_trajectories = batch_relaxer.relax(atoms)
relaxed_atoms = [t[-1] for t in relaxation_trajectories.values()]
total_energies = np.array([a.info['total_energy'] for a in relaxed_atoms])
total_energies = np.array([a.info["total_energy"] for a in relaxed_atoms])
return relaxed_atoms, total_energies


def relax_structures(structures: Structure | list[Structure], device: str = "cuda", load_path: str = None, **kwargs) -> tuple[list[Structure], np.ndarray]:
def relax_structures(
structures: Structure | list[Structure],
device: str = str(get_device()),
load_path: str = None,
**kwargs
) -> tuple[list[Structure], np.ndarray]:
if isinstance(structures, Structure):
structures = [structures]
atoms = [AseAtomsAdaptor.get_atoms(s) for s in structures]
4 changes: 2 additions & 2 deletions mattergen/generator.py
@@ -29,7 +29,7 @@
make_structure,
save_structures,
)
from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH
from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH, get_device
from mattergen.diffusion.lightning_module import DiffusionLightningModule
from mattergen.diffusion.sampling.pc_sampler import PredictorCorrector

@@ -332,7 +332,7 @@ def prepare(self) -> None:
if self._model is not None:
return
model = load_model_diffusion(self.checkpoint_info)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(get_device())
self._model = model
self._cfg = self.checkpoint_info.config

24 changes: 7 additions & 17 deletions pyproject.toml
Expand Up @@ -52,9 +52,9 @@ dependencies = [
"setuptools",
"SMACT",
"sympy>=1.11.1",
"torch==2.2.1+cu118",
"torchvision==0.17.1+cu118",
"torchaudio==2.2.1+cu118",
"torch==2.4.1",
"torchvision==0.19.1",
"torchaudio==2.4.1",
"torch_cluster",
"torch_geometric>=2.5",
"torch_scatter",
@@ -67,17 +67,7 @@
include = ["mattergen*"]

[tool.uv.sources]
torch = { index = "pytorch" }
torchvision = { index = "pytorch" }
torchaudio = { index = "pytorch" }
pyg-lib = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/pyg_lib-0.4.0%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }
torch_cluster = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_cluster-1.6.3%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }
torch_scatter = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_scatter-2.1.2%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }
torch_sparse = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_sparse-0.6.18%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" }

[[tool.uv.index]]
name = "pytorch"
url = "https://download.pytorch.org/whl/cu118"
explicit = true


pyg-lib = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/pyg_lib-0.4.0%2Bpt24-cp310-cp310-macosx_14_0_universal2.whl" }
torch_cluster = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_cluster-1.6.3-cp310-cp310-macosx_10_9_universal2.whl" }
torch_sparse = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_sparse-0.6.18-cp310-cp310-macosx_11_0_universal2.whl" }
torch_scatter = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_scatter-2.1.2-cp310-cp310-macosx_10_9_universal2.whl" }