From 826d109c57a9544242d32237a0146c0848d3edb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Z=C3=BCgner?= Date: Wed, 22 Jan 2025 13:14:40 +0100 Subject: [PATCH] Add experimental Apple Silicon support (#35) Adds experimental Apple Silicon support. --------- Co-authored-by: Daniel Zuegner --- README.md | 40 +++++++++++-- mattergen/common/gemnet/gemnet.py | 20 ++++--- mattergen/common/utils/data_utils.py | 14 +---- mattergen/common/utils/eval_utils.py | 3 +- mattergen/common/utils/globals.py | 24 ++++++++ mattergen/common/utils/ocp_graph_utils.py | 13 +++- mattergen/evaluation/evaluate.py | 22 ++++--- mattergen/evaluation/utils/relaxation.py | 20 +++++-- mattergen/generator.py | 4 +- pyproject.toml | 24 +++----- pyproject_apple_silicon.toml | 73 +++++++++++++++++++++++ scripts/evaluate.py | 3 +- scripts/finetune.py | 4 +- 13 files changed, 199 insertions(+), 65 deletions(-) create mode 100644 pyproject_apple_silicon.toml diff --git a/README.md b/README.md index 7541104..35d4402 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ MatterGen is a generative model for inorganic materials design across the periodic table that can be fine-tuned to steer the generation towards a wide range of property constraints. + ## Table of Contents - [Installation](#installation) - [Get started with a pre-trained model](#get-started-with-a-pre-trained-model) @@ -28,9 +29,10 @@ MatterGen is a generative model for inorganic materials design across the period ## Installation + The easiest way to install prerequisites is via [uv](https://docs.astral.sh/uv/), a fast Python package and project manager. -The MatterGen environment can be installed via the following command: +The MatterGen environment can be installed via the following command (assumes you are running Linux and have a CUDA GPU): ```bash pip install uv uv venv .venv --python 3.10 @@ -44,6 +46,24 @@ git lfs --version ``` If this prints some version like `git-lfs/3.0.2 (GitHub; linux amd64; go 1.18.1)`, you can skip the following step. +### Apple Silicon +> [!WARNING] +> Running MatterGen on Apple Silicon is **experimental**. Use at your own risk. +> Further, you need to run `export PYTORCH_ENABLE_MPS_FALLBACK=1` before any training or generation run. + +To install the environment for Apple Silicon, run these commands: +```bash +cp pyproject.toml pyproject.linux.toml +mv pyproject_apple_silicon.toml pyproject.toml +pip install uv +uv venv .venv --python 3.10 +source .venv/bin/activate +uv pip install -e . +export PYTORCH_ENABLE_MPS_FALLBACK=1 # required to run MatterGen on Apple Silicon +``` + + + ### Install Git LFS If Git LFS was not installed before you cloned this repo, you can install it and download the missing files via: ```bash @@ -156,6 +176,8 @@ You can train the MatterGen base model on `mp_20` using the following command. ```bash python scripts/run.py data_module=mp_20 ~trainer.logger ``` +> [!NOTE] +> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command. The validation loss (`loss_val`) should reach 0.4 after 360 epochs (about 80k steps). The output checkpoints can be found at `outputs/singlerun/${now:%Y-%m-%d}/${now:%H-%M-%S}`. We call this folder `$MODEL_PATH` for future reference. > [!NOTE] @@ -168,6 +190,9 @@ To train the MatterGen base model on `alex_mp_20`, use the following command: ```bash python scripts/run.py data_module=alex_mp_20 ~trainer.logger trainer.accumulate_grad_batches=4 ``` +> [!NOTE] +> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command. + > [!TIP] > Note that a single GPU's memory usually is not enough for the batch size of 512, hence we accumulate gradients over 4 batches. If you still run out of memory, increase this further. @@ -186,8 +211,10 @@ export PROPERTY=dft_mag_density export MODEL_PATH=checkpoints/mattergen_base python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY=$PROPERTY ~trainer.logger data_module.properties=["$PROPERTY"] ``` - `dft_mag_density` denotes the target property for fine-tuning. +> [!NOTE] +> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command. + > [!TIP] > You can select any property that is available in the dataset. See [`mattergen/conf/data_module/mp_20.yaml`](mattergen/conf/data_module/mp_20.yaml) or [`mattergen/conf/data_module/alex_mp_20.yaml`](mattergen/conf/data_module/alex_mp_20.yaml) for the list of supported properties. You can also add your own custom property data. See [below](#fine-tune-on-your-own-property-data) for instructions. @@ -199,12 +226,15 @@ You can also fine-tune MatterGen on multiple properties. For instance, to fine-t export PROPERTY1=dft_mag_density export PROPERTY2=dft_band_gap export MODEL_PATH=checkpoints/mattergen_base -python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY1=$PROPERTY1 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY2=$PROPERTY2 ~trainer.logger data_module.properties=["$PROPERTY1", "$PROPERTY2"] +python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY1=$PROPERTY1 +lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.$PROPERTY2=$PROPERTY2 ~trainer.logger data_module.properties=["$PROPERTY1","$PROPERTY2"] ``` > [!TIP] > Add more properties analogously by adding these overrides: > 1. `+lightning_module/diffusion_module/model/property_embeddings@adapter.adapter.property_embeddings_adapt.=` -> 2. Add `` to the `data_module.properties=["$PROPERTY1", "$PROPERTY2", ..., ]` override. +> 2. Add `` to the `data_module.properties=["$PROPERTY1","$PROPERTY2",...,]` override. + +> [!NOTE] +> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command. #### Fine-tune on your own property data You may also fine-tune MatterGen on your own property data. Essentially what you need is a property value (typically `float`) for a subset of the data you want to train on (e.g., `alex_mp_20`). Proceed as follows: @@ -261,4 +291,4 @@ If you have any questions not covered here, please create an issue or contact th Materials Design team at [ai4s-materials@microsoft.com](mailto:ai4s-materials@microsoft.com). We would appreciate your feedback and would like to know how MatterGen has been beneficial to your research. -Please share your experiences with us at [ai4s-materials@microsoft.com](mailto:ai4s-materials@microsoft.com). +Please share your experiences with us at [ai4s-materials@microsoft.com](mailto:ai4s-materials@microsoft.com). \ No newline at end of file diff --git a/mattergen/common/gemnet/gemnet.py b/mattergen/common/gemnet/gemnet.py index 7ee9660..eefb3b4 100644 --- a/mattergen/common/gemnet/gemnet.py +++ b/mattergen/common/gemnet/gemnet.py @@ -34,7 +34,7 @@ lattice_params_to_matrix_torch, radius_graph_pbc, ) -from mattergen.common.utils.globals import MODELS_PROJECT_ROOT +from mattergen.common.utils.globals import MODELS_PROJECT_ROOT, get_device, get_pyg_device from mattergen.common.utils.lattice_score import edge_score_to_lattice_score_frac_symmetric @@ -381,17 +381,19 @@ def get_triplets( value = torch.arange(idx_s.size(0), device=idx_s.device, dtype=idx_s.dtype) # Possibly contains multiple copies of the same edge (for periodic interactions) + pyg_device = get_pyg_device() + torch_device = get_device() adj = SparseTensor( - row=idx_t, - col=idx_s, - value=value, - sparse_sizes=(num_atoms, num_atoms), + row=idx_t.to(pyg_device), + col=idx_s.to(pyg_device), + value=value.to(pyg_device), + sparse_sizes=(num_atoms.to(pyg_device), num_atoms.to(pyg_device)), ) - adj_edges = adj[idx_t] + adj_edges = adj[idx_t.to(pyg_device)].to(torch_device) # Edge indices (b->a, c->a) for triplets. - id3_ba = adj_edges.storage.value() - id3_ca = adj_edges.storage.row() + id3_ba = adj_edges.storage.value().to(torch_device) + id3_ca = adj_edges.storage.row().to(torch_device) # Remove self-loop triplets # Compare edge indices, not atom indices to correctly handle periodic interactions @@ -773,4 +775,4 @@ def forward( @property def num_params(self): - return sum(p.numel() for p in self.parameters()) + return sum(p.numel() for p in self.parameters()) \ No newline at end of file diff --git a/mattergen/common/utils/data_utils.py b/mattergen/common/utils/data_utils.py index 51adc26..4873a2d 100644 --- a/mattergen/common/utils/data_utils.py +++ b/mattergen/common/utils/data_utils.py @@ -361,19 +361,7 @@ def compute_lattice_polar_decomposition(lattice_matrix: torch.Tensor) -> torch.T # lattice_matrix: [batch_size, 3, 3] # Computes the (unique) symmetric lattice matrix that is equivalent (up to rotation) to the input lattice. - if lattice_matrix.device.type == "cuda": - # there is an issue running torch.linalg.svd on cuda tensors with driver version 450.* - - try: - W, S, V_transp = torch.linalg.svd(lattice_matrix) - except torch._C._LinAlgError: - # move to cpu and try again - W, S, V_transp = torch.linalg.svd(lattice_matrix.to("cpu")) - W = W.to(lattice_matrix.device.type) - S = S.to(lattice_matrix.device.type) - V_transp = V_transp.to(lattice_matrix.device.type) - else: - W, S, V_transp = torch.linalg.svd(lattice_matrix) + W, S, V_transp = torch.linalg.svd(lattice_matrix) S_square = torch.diag_embed(S) V = V_transp.transpose(1, 2) U = W @ V_transp diff --git a/mattergen/common/utils/eval_utils.py b/mattergen/common/utils/eval_utils.py index 8de0366..79e33b5 100644 --- a/mattergen/common/utils/eval_utils.py +++ b/mattergen/common/utils/eval_utils.py @@ -20,6 +20,7 @@ GENERATED_CRYSTALS_ZIP_FILE_NAME, ) from mattergen.common.utils.data_classes import MatterGenCheckpointInfo +from mattergen.common.utils.globals import get_device from mattergen.diffusion.lightning_module import DiffusionLightningModule # logging @@ -54,7 +55,7 @@ def load_model_diffusion( try: model, incompatible_keys = DiffusionLightningModule.load_from_checkpoint_and_config( ckpt, - map_location=torch.device("cpu") if not torch.cuda.is_available() else None, + map_location=get_device(), config=cfg.lightning_module, strict=args.strict_checkpoint_loading, ) diff --git a/mattergen/common/utils/globals.py b/mattergen/common/utils/globals.py index d6b252b..e35ec6a 100644 --- a/mattergen/common/utils/globals.py +++ b/mattergen/common/utils/globals.py @@ -6,10 +6,32 @@ 2. It registers a new resolver for OmegaConf, `eval`, which allows us to use `eval` in our config files. """ import os +from functools import lru_cache from pathlib import Path +import torch from omegaconf import OmegaConf + +@lru_cache +def get_device() -> torch.device: + if torch.cuda.is_available(): + return torch.device("cuda") + if torch.backends.mps.is_available(): + return torch.device("mps") + return torch.device("cpu") + + +@lru_cache +def get_pyg_device() -> torch.device: + """ + Some operations of pyg don't work on MPS, so fall back to CPU. + """ + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + MODELS_PROJECT_ROOT = Path(__file__).resolve().parents[2] print(f"MODELS_PROJECT_ROOT: {MODELS_PROJECT_ROOT}") @@ -109,6 +131,8 @@ 83, ] MAX_ATOMIC_NUM = 100 + + # Set `eval` resolver def try_eval(s): """This is a custom resolver for OmegaConf that allows us to use `eval` in our config files diff --git a/mattergen/common/utils/ocp_graph_utils.py b/mattergen/common/utils/ocp_graph_utils.py index f816e1d..3916933 100644 --- a/mattergen/common/utils/ocp_graph_utils.py +++ b/mattergen/common/utils/ocp_graph_utils.py @@ -12,6 +12,8 @@ import torch from torch_scatter import segment_coo, segment_csr +from mattergen.common.utils.globals import get_pyg_device + def get_pbc_distances( pos: torch.Tensor, @@ -272,14 +274,21 @@ def get_max_neighbors_mask( # Get number of neighbors # segment_coo assumes sorted index ones = index.new_ones(1).expand_as(index) - num_neighbors = segment_coo(ones, index, dim_size=num_atoms) + # required because PyG does not support MPS for the segment_coo operation yet. + pyg_device = get_pyg_device() + device_before = ones.device + num_neighbors = segment_coo(ones.to(pyg_device), index.to(pyg_device), dim_size=num_atoms).to( + device_before + ) max_num_neighbors = num_neighbors.max() num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold) # Get number of (thresholded) neighbors per image image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long) image_indptr[1:] = torch.cumsum(natoms, dim=0) - num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr) + num_neighbors_image = segment_csr( + num_neighbors_thresholded.to(pyg_device), image_indptr.to(pyg_device) + ).to(device_before) # If max_num_neighbors is below the threshold, return early if max_num_neighbors <= max_num_neighbors_threshold or max_num_neighbors_threshold <= 0: diff --git a/mattergen/evaluation/evaluate.py b/mattergen/evaluation/evaluate.py index c0f476d..39e6a62 100644 --- a/mattergen/evaluation/evaluate.py +++ b/mattergen/evaluation/evaluate.py @@ -3,6 +3,7 @@ from pymatgen.core.structure import Structure +from mattergen.common.utils.globals import get_device from mattergen.evaluation.metrics.evaluator import MetricsEvaluator from mattergen.evaluation.reference.reference_dataset import ReferenceDataset from mattergen.evaluation.utils.relaxation import relax_structures @@ -18,11 +19,12 @@ def evaluate( relax: bool = True, energies: list[float] | None = None, reference: ReferenceDataset | None = None, - structure_matcher: OrderedStructureMatcher - | DisorderedStructureMatcher = DefaultDisorderedStructureMatcher(), + structure_matcher: ( + OrderedStructureMatcher | DisorderedStructureMatcher + ) = DefaultDisorderedStructureMatcher(), save_as: str | None = None, potential_load_path: str | None = None, - device: str = "cuda", + device: str = str(get_device()), ) -> dict[str, float | int]: """Evaluate the structures against a reference dataset. @@ -41,18 +43,20 @@ def evaluate( if relax and energies is not None: raise ValueError("Cannot accept energies if relax is True.") if relax: - relaxed_structures, energies = relax_structures(structures, device=device, load_path=potential_load_path) + relaxed_structures, energies = relax_structures( + structures, device=device, load_path=potential_load_path + ) else: relaxed_structures = structures evaluator = MetricsEvaluator.from_structures_and_energies( - structures=relaxed_structures, + structures=relaxed_structures, energies=energies, original_structures=structures, reference=reference, - structure_matcher=structure_matcher + structure_matcher=structure_matcher, ) return evaluator.compute_metrics( - metrics = evaluator.available_metrics, - save_as = save_as, - pretty_print = True, + metrics=evaluator.available_metrics, + save_as=save_as, + pretty_print=True, ) diff --git a/mattergen/evaluation/utils/relaxation.py b/mattergen/evaluation/utils/relaxation.py index ede4c39..c15c262 100644 --- a/mattergen/evaluation/utils/relaxation.py +++ b/mattergen/evaluation/utils/relaxation.py @@ -9,19 +9,31 @@ from pymatgen.core import Structure from pymatgen.io.ase import AseAtomsAdaptor +from mattergen.common.utils.globals import get_device + logger = get_logger() logger.level("ERROR") -def relax_atoms(atoms: list[Atoms], device: str = "cuda", load_path: str = None, **kwargs) -> tuple[list[Atoms], np.ndarray]: - potential = Potential.from_checkpoint(device=device, load_path=load_path, load_training_state=False) + +def relax_atoms( + atoms: list[Atoms], device: str = str(get_device()), load_path: str = None, **kwargs +) -> tuple[list[Atoms], np.ndarray]: + potential = Potential.from_checkpoint( + device=device, load_path=load_path, load_training_state=False + ) batch_relaxer = BatchRelaxer(potential=potential, filter="EXPCELLFILTER", **kwargs) relaxation_trajectories = batch_relaxer.relax(atoms) relaxed_atoms = [t[-1] for t in relaxation_trajectories.values()] - total_energies = np.array([a.info['total_energy'] for a in relaxed_atoms]) + total_energies = np.array([a.info["total_energy"] for a in relaxed_atoms]) return relaxed_atoms, total_energies -def relax_structures(structures: Structure | list[Structure], device: str = "cuda", load_path: str = None, **kwargs) -> tuple[list[Structure], np.ndarray]: +def relax_structures( + structures: Structure | list[Structure], + device: str = str(get_device()), + load_path: str = None, + **kwargs +) -> tuple[list[Structure], np.ndarray]: if isinstance(structures, Structure): structures = [structures] atoms = [AseAtomsAdaptor.get_atoms(s) for s in structures] diff --git a/mattergen/generator.py b/mattergen/generator.py index 596c70f..7c95081 100644 --- a/mattergen/generator.py +++ b/mattergen/generator.py @@ -29,7 +29,7 @@ make_structure, save_structures, ) -from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH +from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH, get_device from mattergen.diffusion.lightning_module import DiffusionLightningModule from mattergen.diffusion.sampling.pc_sampler import PredictorCorrector @@ -332,7 +332,7 @@ def prepare(self) -> None: if self._model is not None: return model = load_model_diffusion(self.checkpoint_info) - model = model.to("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(get_device()) self._model = model self._cfg = self.checkpoint_info.config diff --git a/pyproject.toml b/pyproject.toml index 4c0f38f..8f44785 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,9 +52,9 @@ dependencies = [ "setuptools", "SMACT", "sympy>=1.11.1", -"torch==2.2.1+cu118", -"torchvision==0.17.1+cu118", -"torchaudio==2.2.1+cu118", +"torch==2.4.1", +"torchvision==0.19.1", +"torchaudio==2.4.1", "torch_cluster", "torch_geometric>=2.5", "torch_scatter", @@ -67,17 +67,7 @@ dependencies = [ include = ["mattergen*"] [tool.uv.sources] -torch = { index = "pytorch" } -torchvision = { index = "pytorch" } -torchaudio = { index = "pytorch" } -pyg-lib = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/pyg_lib-0.4.0%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" } -torch_cluster = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_cluster-1.6.3%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" } -torch_scatter = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_scatter-2.1.2%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" } -torch_sparse = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_sparse-0.6.18%2Bpt22cu118-cp310-cp310-linux_x86_64.whl" } - -[[tool.uv.index]] -name = "pytorch" -url = "https://download.pytorch.org/whl/cu118" -explicit = true - - +pyg-lib = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/pyg_lib-0.4.0%2Bpt24-cp310-cp310-macosx_14_0_universal2.whl" } +torch_cluster = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_cluster-1.6.3-cp310-cp310-macosx_10_9_universal2.whl" } +torch_sparse = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_sparse-0.6.18-cp310-cp310-macosx_11_0_universal2.whl" } +torch_scatter = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_scatter-2.1.2-cp310-cp310-macosx_10_9_universal2.whl" } \ No newline at end of file diff --git a/pyproject_apple_silicon.toml b/pyproject_apple_silicon.toml new file mode 100644 index 0000000..8f44785 --- /dev/null +++ b/pyproject_apple_silicon.toml @@ -0,0 +1,73 @@ +[tool.black] +line-length = 100 +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist +)/ +''' + +[tool.isort] +profile = "black" +line_length = 100 +known_first_party = [ + "mattergen", +] + +[project] +name = "mattergen" +version = "1.0" +requires-python = ">= 3.10" + +dependencies = [ +"ase>=3.22.1", +"autopep8", +"cachetools", +"contextlib2", +"emmet-core>=0.84.2", # keep up-to-date together with pymatgen, atomate2 +"fire", # see https://github.com/google/python-fire +"hydra-core==1.3.1", +"hydra-joblib-launcher==1.1.5", +"jupyterlab>=4.2.5", +"lmdb", +"matplotlib==3.8.4", +"matscipy>=0.7.0", +"mattersim>=1.1", +"monty==2024.7.30 ", # keep up-to-date together with pymatgen, atomate2 +"notebook>=7.2.2", +"numpy<2.0", # pin numpy before breaking changes in 2.0 +"omegaconf==2.3.0", +"pymatgen>=2024.6.4", +"pylint", +"pytest", +"pytorch-lightning==2.0.6", +"setuptools", +"SMACT", +"sympy>=1.11.1", +"torch==2.4.1", +"torchvision==0.19.1", +"torchaudio==2.4.1", +"torch_cluster", +"torch_geometric>=2.5", +"torch_scatter", +"torch_sparse", +"tqdm", +"wandb>=0.10.33", +] + +[tool.setuptools.packages.find] +include = ["mattergen*"] + +[tool.uv.sources] +pyg-lib = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/pyg_lib-0.4.0%2Bpt24-cp310-cp310-macosx_14_0_universal2.whl" } +torch_cluster = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_cluster-1.6.3-cp310-cp310-macosx_10_9_universal2.whl" } +torch_sparse = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_sparse-0.6.18-cp310-cp310-macosx_11_0_universal2.whl" } +torch_scatter = { url = "https://data.pyg.org/whl/torch-2.4.0%2Bcpu/torch_scatter-2.1.2-cp310-cp310-macosx_10_9_universal2.whl" } \ No newline at end of file diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 0672f5a..f0c0e17 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -9,6 +9,7 @@ import numpy as np from mattergen.common.utils.eval_utils import load_structures +from mattergen.common.utils.globals import get_device from mattergen.evaluation.evaluate import evaluate from mattergen.evaluation.utils.structure_matcher import ( DefaultDisorderedStructureMatcher, @@ -25,7 +26,7 @@ def main( potential_load_path: ( Literal["MatterSim-v1.0.0-1M.pth", "MatterSim-v1.0.0-5M.pth"] | None ) = None, - device: str = "cuda", + device: str = str(get_device()), ): structures = load_structures(Path(structures_path)) energies = np.load(energies_path) if energies_path else None diff --git a/scripts/finetune.py b/scripts/finetune.py index eb8db77..f428124 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -16,7 +16,7 @@ from pytorch_lightning.cli import SaveConfigCallback from mattergen.common.utils.data_classes import MatterGenCheckpointInfo -from mattergen.common.utils.globals import MODELS_PROJECT_ROOT +from mattergen.common.utils.globals import MODELS_PROJECT_ROOT, get_device from mattergen.diffusion.run import AddConfigCallback, SimpleParser, maybe_instantiate logger = logging.getLogger(__name__) @@ -81,7 +81,7 @@ def init_adapter_lightningmodule_from_pretrained( lightning_module = hydra.utils.instantiate(lightning_module_cfg) - ckpt: dict = torch.load(ckpt_path) + ckpt: dict = torch.load(ckpt_path, map_location=get_device()) pretrained_dict: OrderedDict = ckpt["state_dict"] scratch_dict: OrderedDict = lightning_module.state_dict() scratch_dict.update(