download checkpoints from HF hub, cli entry points (#39)
Co-authored-by: Daniel Zuegner <[email protected]>
danielzuegner and Daniel Zuegner authored Jan 30, 2025
1 parent baae75e commit 1abe727
Showing 9 changed files with 121 additions and 43 deletions.
45 changes: 20 additions & 25 deletions README.md
@@ -82,7 +82,7 @@ We provide checkpoints of an unconditional base version of MatterGen as well as
* `dft_mag_density_hhi_score`: fine-tuned model jointly conditioned on magnetic density from DFT and HHI score
* `chemical_system_energy_above_hull`: fine-tuned model jointly conditioned on chemical system and energy above hull from DFT

-The checkpoints are located at `checkpoints/<model_name>`.
+The checkpoints are located at `checkpoints/<model_name>` and are also available on [Hugging Face](https://huggingface.co/microsoft/mattergen).

> [!NOTE]
> The checkpoints provided were re-trained using this repository, i.e., are not identical to the ones used in the paper. Hence, results may slightly deviate from those in the publication.
@@ -91,30 +91,30 @@ The checkpoints are located at `checkpoints/<model_name>`.
### Unconditional generation
To sample from the pre-trained base model, run the following command.
```bash
-export MODEL_PATH=checkpoints/mattergen_base # Or provide your own model
+export MODEL_NAME=mattergen_base
export RESULTS_PATH=results/ # Samples will be written to this directory
-git lfs pull -I $MODEL_PATH --exclude="" # first download the checkpoint file from Git LFS

# generate batch_size * num_batches samples
-python scripts/generate.py $RESULTS_PATH $MODEL_PATH --batch_size=16 --num_batches 1
+mattergen-generate $RESULTS_PATH --pretrained-name=$MODEL_NAME --batch_size=16 --num_batches 1
```
This script will write the following files into `$RESULTS_PATH`:
* `generated_crystals_cif.zip`: a ZIP file containing a single `.cif` file per generated structure.
* `generated_crystals.extxyz`, a single file containing the individual generated structures as frames.
* If `--record-trajectories == True` (default): `generated_trajectories.zip`: a ZIP file containing a `.extxyz` file per generated structure, which contains the full denoising trajectory for each individual structure.
> [!TIP]
> For best efficiency, increase the batch size to the largest your GPU can sustain without running out of memory.
+> [!NOTE]
+> To sample from a model you've trained yourself, replace `--pretrained-name=$MODEL_NAME` with `--model_path=$MODEL_PATH`, filling in your model's location for `$MODEL_PATH`.
### Property-conditioned generation
With a fine-tuned model, you can generate materials conditioned on a target property.
For example, to sample from the model trained on magnetic density, you can run the following command.
```bash
export MODEL_NAME=dft_mag_density
-export MODEL_PATH="checkpoints/$MODEL_NAME" # Or provide your own model
export RESULTS_PATH="results/$MODEL_NAME/" # Samples will be written to this directory, e.g., `results/dft_mag_density`
-git lfs pull -I $MODEL_PATH --exclude="" # first download the checkpoint file from Git LFS

# Generate conditional samples with a target magnetic density of 0.15
-python scripts/generate.py $RESULTS_PATH $MODEL_PATH --batch_size=16 --checkpoint_epoch=last --properties_to_condition_on="{'dft_mag_density': 0.15}" --diffusion_guidance_factor=2.0
+mattergen-generate $RESULTS_PATH --pretrained-name=$MODEL_NAME --batch_size=16 --properties_to_condition_on="{'dft_mag_density': 0.15}" --diffusion_guidance_factor=2.0
```
> [!TIP]
> The argument `--diffusion-guidance-factor` corresponds to the $\gamma$ parameter in [classifier-free diffusion guidance](https://sander.ai/2022/05/26/guidance.html). Setting it to zero corresponds to unconditional generation, and increasing it further tends to produce samples which adhere more to the input property values, though at the expense of diversity and realism of samples.
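
The effect of the guidance factor can be illustrated with a toy calculation. This is a minimal sketch that assumes the usual classifier-free guidance interpolation between an unconditional and a conditional score, chosen so that $\gamma = 0$ returns the unconditional score as the tip states. `guided_score` is an illustrative helper and is not taken from MatterGen's code.

```python
def guided_score(uncond, cond, gamma):
    """Blend unconditional and conditional scores element-wise.

    gamma = 0 returns the unconditional score, gamma = 1 the conditional one,
    and gamma > 1 extrapolates beyond the conditional score.
    """
    return [u + gamma * (c - u) for u, c in zip(uncond, cond)]


uncond = [1.0, 2.0]  # toy unconditional score (made-up numbers)
cond = [3.0, 6.0]    # toy conditional score (made-up numbers)
for gamma in (0.0, 1.0, 2.0):
    print(gamma, guided_score(uncond, cond, gamma))
```

With `gamma=2.0`, as in the command above, the result moves past the conditional score, which is one way to read the trade-off between property adherence and sample diversity.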
@@ -124,17 +124,15 @@ You can also generate materials conditioned on more than one property. For insta
Adapt the following command to your specific needs:
```bash
export MODEL_NAME=chemical_system_energy_above_hull
-export MODEL_PATH="checkpoints/$MODEL_NAME" # Or provide your own model
export RESULTS_PATH="results/$MODEL_NAME/" # Samples will be written to this directory, e.g., `results/dft_mag_density`
-git lfs pull -I $MODEL_PATH --exclude="" # first download the checkpoint file from Git LFS
-python scripts/generate.py $RESULTS_PATH $MODEL_PATH --batch_size=16 --checkpoint_epoch=last --properties_to_condition_on="{'energy_above_hull': 0.05, 'chemical_system': 'Li-O'}" --diffusion_guidance_factor=2.0
+mattergen-generate $RESULTS_PATH --pretrained-name=$MODEL_NAME --batch_size=16 --properties_to_condition_on="{'energy_above_hull': 0.05, 'chemical_system': 'Li-O'}" --diffusion_guidance_factor=2.0
```
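
The value passed to `--properties_to_condition_on` is written as a Python dict literal inside a shell string. Before launching a long sampling run, it can help to check that the string evaluates to the intended dictionary. The sketch below uses `ast.literal_eval` for that check; `parse_properties` is a hypothetical helper, and the CLI's own argument parsing may differ in detail.

```python
import ast


def parse_properties(arg: str) -> dict:
    """Evaluate a dict literal such as "{'dft_mag_density': 0.15}" safely."""
    value = ast.literal_eval(arg)  # accepts literals only, never arbitrary code
    if not isinstance(value, dict):
        raise ValueError(f"expected a dict literal, got {type(value).__name__}")
    return value


print(parse_properties("{'energy_above_hull': 0.05, 'chemical_system': 'Li-O'}"))
```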
## Evaluation

Once you have generated a list of structures contained in `$RESULTS_PATH` (either using MatterGen or another method), you can relax the structures using the default MatterSim machine learning force field (see [repository](https://github.com/microsoft/mattersim)) and compute novelty, uniqueness, stability (using energy estimated by MatterSim), and other metrics via the following command:
```bash
git lfs pull -I data-release/alex-mp/reference_MP2020correction.gz --exclude="" # first download the reference dataset from Git LFS
-python scripts/evaluate.py --structures_path=$RESULTS_PATH --relax=True --structure_matcher='disordered' --save_as="$RESULTS_PATH/metrics.json"
+mattergen-evaluate --structures_path=$RESULTS_PATH --relax=True --structure_matcher='disordered' --save_as="$RESULTS_PATH/metrics.json"
```
This script will write `metrics.json` containing the metric results to `$RESULTS_PATH` and will print it to your console.
> [!IMPORTANT]
@@ -147,7 +145,7 @@ This script will write `metrics.json` containing the metric results to `$RESULTS
If, instead, you have relaxed the structures and obtained the relaxed total energies by other means (e.g., DFT), you can evaluate the metrics via:
```bash
git lfs pull -I data-release/alex-mp/reference_MP2020correction.gz --exclude="" # first download the reference dataset from Git LFS
-python scripts/evaluate.py --structures_path=$RESULTS_PATH --energies_path='energies.npy' --relax=False --structure_matcher='disordered' --save_as='metrics'
+mattergen-evaluate --structures_path=$RESULTS_PATH --energies_path='energies.npy' --relax=False --structure_matcher='disordered' --save_as='metrics'
```
This script will try to read structures from disk in the following precedence order:
* If `$RESULTS_PATH` points to a `.xyz` or `.extxyz` file, it will read it directly and assume each frame is a different structure.
@@ -165,7 +163,7 @@ You can run the following command for `mp_20`:
# Download file from LFS
git lfs pull -I data-release/mp-20/ --exclude=""
unzip data-release/mp-20/mp_20.zip -d datasets
-python scripts/csv_to_dataset.py --csv-folder datasets/mp_20/ --dataset-name mp_20 --cache-folder datasets/cache
+csv-to-dataset --csv-folder datasets/mp_20/ --dataset-name mp_20 --cache-folder datasets/cache
```
You will get preprocessed data files in `datasets/cache/mp_20`.

@@ -174,15 +172,15 @@ To preprocess our larger `alex_mp_20` dataset, run:
# Download file from LFS
git lfs pull -I data-release/alex-mp/alex_mp_20.zip --exclude=""
unzip data-release/alex-mp/alex_mp_20.zip -d datasets
-python scripts/csv_to_dataset.py --csv-folder datasets/alex_mp_20/ --dataset-name alex_mp_20 --cache-folder datasets/cache
+csv-to-dataset --csv-folder datasets/alex_mp_20/ --dataset-name alex_mp_20 --cache-folder datasets/cache
```
This will take some time (~1h). You will get preprocessed data files in `datasets/cache/alex_mp_20`.

### Training
You can train the MatterGen base model on `mp_20` using the following command.

```bash
-python scripts/run.py data_module=mp_20 ~trainer.logger
+mattergen-train data_module=mp_20 ~trainer.logger
```
> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.
@@ -196,7 +194,7 @@ The validation loss (`loss_val`) should reach 0.4 after 360 epochs (about 80k st
To train the MatterGen base model on `alex_mp_20`, use the following command:
```bash
-python scripts/run.py data_module=alex_mp_20 ~trainer.logger trainer.accumulate_grad_batches=4
+mattergen-train data_module=alex_mp_20 ~trainer.logger trainer.accumulate_grad_batches=4
```
> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.
@@ -212,15 +210,13 @@ To sample from this model, pass `--target_compositions=[{"<element1>": <number_o
An example composition could be `--target_compositions=[{"Na": 1, "Cl": 1}]`.
### Fine-tuning on property data

-Assume that you have a MatterGen base model at `$MODEL_PATH` (e.g., `checkpoints/mattergen_base`). You can fine-tune MatterGen using the following command.
+You can fine-tune the MatterGen base model using the following command.

```bash
export PROPERTY=dft_mag_density
-export MODEL_PATH=checkpoints/mattergen_base
-git lfs pull -I $MODEL_PATH --exclude="" # first download the checkpoint file from Git LFS
-python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.property_embeddings_adapt.$PROPERTY=$PROPERTY ~trainer.logger data_module.properties=["$PROPERTY"]
+mattergen-finetune adapter.pretrained_name=mattergen_base data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.property_embeddings_adapt.$PROPERTY=$PROPERTY ~trainer.logger data_module.properties=["$PROPERTY"]
```
-`dft_mag_density` denotes the target property for fine-tuning.
+`dft_mag_density` denotes the target property for fine-tuning. You can also fine-tune a model you've trained yourself by **replacing** `adapter.pretrained_name=mattergen_base` with `adapter.model_path=$MODEL_PATH`, filling in your model's location for `$MODEL_PATH`.
> [!NOTE]
> For Apple Silicon training, add `~trainer.strategy trainer.accelerator=mps` to the above command.
@@ -234,9 +230,8 @@ You can also fine-tune MatterGen on multiple properties. For instance, to fine-t
```bash
export PROPERTY1=dft_mag_density
export PROPERTY2=dft_band_gap
-export MODEL_PATH=checkpoints/mattergen_base
-git lfs pull -I $MODEL_PATH --exclude="" # first download the checkpoint file from Git LFS
-python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.property_embeddings_adapt.$PROPERTY1=$PROPERTY1 +lightning_module/diffusion_module/model/property_embeddings@adapter.property_embeddings_adapt.$PROPERTY2=$PROPERTY2 ~trainer.logger data_module.properties=["$PROPERTY1","$PROPERTY2"]
+export MODEL_NAME=mattergen_base
+mattergen-finetune adapter.pretrained_name=$MODEL_NAME data_module=mp_20 +lightning_module/diffusion_module/model/property_embeddings@adapter.property_embeddings_adapt.$PROPERTY1=$PROPERTY1 +lightning_module/diffusion_module/model/property_embeddings@adapter.property_embeddings_adapt.$PROPERTY2=$PROPERTY2 ~trainer.logger data_module.properties=["$PROPERTY1","$PROPERTY2"]
```
> [!TIP]
> Add more properties analogously by adding these overrides:
Expand All @@ -250,7 +245,7 @@ python scripts/finetune.py adapter.model_path=$MODEL_PATH data_module=mp_20 +lig
You may also fine-tune MatterGen on your own property data. Essentially what you need is a property value (typically `float`) for a subset of the data you want to train on (e.g., `alex_mp_20`). Proceed as follows:
1. Add the name of your property to the `PROPERTY_SOURCE_IDS` list inside [`mattergen/mattergen/common/utils/globals.py`](mattergen/mattergen/common/utils/globals.py).
2. Add a new column with this name to the dataset(s) you want to train on, e.g., `datasets/alex_mp_20/train.csv` and `datasets/alex_mp_20/val.csv` (requires you to have followed the [pre-processing steps](#pre-process-a-dataset-for-training)).
-3. Re-run the CSV to dataset script `python scripts/csv_to_dataset.py --csv-folder datasets/<MY_DATASET>/ --dataset-name <MY_DATASET> --cache-folder datasets/cache`, substituting your dataset name for `MY_DATASET`.
+3. Re-run the CSV to dataset script `csv-to-dataset --csv-folder datasets/<MY_DATASET>/ --dataset-name <MY_DATASET> --cache-folder datasets/cache`, substituting your dataset name for `MY_DATASET`.
4. Add a `<your_property>.yaml` config file to [`mattergen/conf/lightning_module/diffusion_module/model/property_embeddings`](mattergen/conf/lightning_module/diffusion_module/model/property_embeddings). If you are adding a float-valued property, you may copy an existing configuration, e.g., [`dft_mag_density.yaml`](mattergen/conf/lightning_module/diffusion_module/model/property_embeddings/dft_mag_density.yaml). More complicated properties will require you to create your own custom `PropertyEmbedding` subclass, e.g., see the [`space_group`](mattergen/conf/lightning_module/diffusion_module/model/property_embeddings/space_group.yaml) or [`chemical_system`](mattergen/conf/lightning_module/diffusion_module/model/property_embeddings/chemical_system.yaml) configs.
5. Follow the [instructions for fine-tuning](#fine-tuning-on-property-data) and reference your own property in the same way as we used the existing properties like `dft_mag_density`.
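
Step 2 of the list above (adding a property column to the CSV files) could look roughly like the following. This is a sketch under stated assumptions: the key column `material_id`, the `formula` column, and the property name `my_property` are hypothetical placeholders, and leaving the cell empty for rows without a value is an assumed convention rather than documented behaviour.

```python
import csv
import io


def add_property_column(csv_text, key_column, property_name, values):
    """Return CSV text with `property_name` appended as a new last column.

    Rows whose key is missing from `values` receive an empty cell.
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    fieldnames = list(reader.fieldnames) + [property_name]
    out = io.StringIO()
    writer = csv.DictWriter(out, fieldnames=fieldnames, lineterminator="\n")
    writer.writeheader()
    for row in reader:
        row[property_name] = values.get(row[key_column], "")
        writer.writerow(row)
    return out.getvalue()


# Hypothetical two-row dataset; only the first row has a property value.
example = "material_id,formula\nmp-1,NaCl\nmp-2,LiO\n"
print(add_property_column(example, "material_id", "my_property", {"mp-1": 0.15}))
```

For real files, the same function can be applied to the contents of `train.csv` and `val.csv` before re-running the conversion in step 3.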

35 changes: 35 additions & 0 deletions mattergen/common/utils/data_classes.py
@@ -9,9 +9,21 @@
from typing import Any, Literal

import numpy as np
+from huggingface_hub import hf_hub_download
 from hydra import compose, initialize_config_dir
 from omegaconf import DictConfig

+PRETRAINED_MODEL_NAME = Literal[
+    "mattergen_base",
+    "chemical_system",
+    "space_group",
+    "dft_mag_density",
+    "dft_band_gap",
+    "ml_bulk_modulus",
+    "dft_mag_density_hhi_score",
+    "chemical_system_energy_above_hull",
+]


def find_local_files(local_path: str, glob: str = "*", relative: bool = False) -> list[str]:
"""
@@ -41,6 +53,29 @@ class MatterGenCheckpointInfo:
     split: str = "val"
     strict_checkpoint_loading: bool = True

+    @classmethod
+    def from_hf_hub(
+        cls,
+        model_name: PRETRAINED_MODEL_NAME,
+        repository_name: str = "microsoft/mattergen",
+        config_overrides: list[str] = None,
+    ):
+        """
+        Instantiate a MatterGenCheckpointInfo object from a model hosted on the Hugging Face Hub.
+        """
+        hf_hub_download(
+            repo_id=repository_name, filename=f"checkpoints/{model_name}/checkpoints/last.ckpt"
+        )
+        config_path = hf_hub_download(
+            repo_id=repository_name, filename=f"checkpoints/{model_name}/config.yaml"
+        )
+        return cls(
+            model_path=Path(config_path).parent,
+            config_overrides=config_overrides or [],
+            load_epoch="last",
+        )
+
     def as_dict(self) -> dict[str, Any]:
         d = asdict(self)
         d["model_path"] = str(self.model_path)  # we cannot put Path object in mongo DB
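
The path handling in `from_hf_hub` can be sketched without network access. The example below mirrors the two repository-relative file names used in the hunk and derives the model directory from the location of `config.yaml`. It assumes that both downloads land under a common local root, so that the directory containing `config.yaml` also contains `checkpoints/last.ckpt`. `hub_filenames`, `resolve_model_dir`, `fake_download`, and the cache root `/tmp/hf-cache` are hypothetical, and the runtime name check is an addition of this sketch, not part of the commit.

```python
from pathlib import Path, PurePosixPath
from typing import Callable, Literal, get_args

PRETRAINED_MODEL_NAME = Literal[
    "mattergen_base",
    "chemical_system",
    "space_group",
    "dft_mag_density",
    "dft_band_gap",
    "ml_bulk_modulus",
    "dft_mag_density_hhi_score",
    "chemical_system_energy_above_hull",
]


def hub_filenames(model_name: str) -> tuple[str, str]:
    """Repo-relative paths of the checkpoint and its config, as in from_hf_hub."""
    if model_name not in get_args(PRETRAINED_MODEL_NAME):
        raise ValueError(f"unknown pretrained model: {model_name!r}")
    root = PurePosixPath("checkpoints") / model_name
    return str(root / "checkpoints" / "last.ckpt"), str(root / "config.yaml")


def resolve_model_dir(model_name: str, download: Callable[[str], str]) -> Path:
    """Fetch both files through `download`; return the directory of config.yaml."""
    ckpt_file, config_file = hub_filenames(model_name)
    download(ckpt_file)  # the returned path is not needed, only the local file
    return Path(download(config_file)).parent


def fake_download(filename: str) -> str:
    """Stand-in for a real downloader: maps a repo path into a made-up cache."""
    return str(Path("/tmp/hf-cache") / filename)


print(resolve_model_dir("dft_band_gap", fake_download))
```

Passing the downloader in as a callable keeps the sketch testable offline; in real use it would be replaced by a function that wraps `hf_hub_download` for a fixed repository.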
3 changes: 2 additions & 1 deletion mattergen/conf/adapter/default.yaml
@@ -1,4 +1,5 @@
-model_path: ${oc.env:MAP_INPUT_DIR}
+pretrained_name: mattergen_base
+model_path: null
load_epoch: last
full_finetuning: true

2 changes: 1 addition & 1 deletion mattergen/generator.py
@@ -205,7 +205,7 @@ def __post_init__(self) -> None:
f"but got {self.num_atoms_distribution}. To add your own distribution, "
"please add it to mattergen.common.data.num_atoms_distribution.NUM_ATOMS_DISTRIBUTIONS."
)
-if len(self.target_compositions_dict) > 0:
+if self.target_compositions_dict:
assert self.cfg.lightning_module.diffusion_module.loss_fn.weights.get(
"atomic_numbers", 0.0
) == 0.0 and "atomic_numbers" not in self.cfg.lightning_module.diffusion_module.corruption.get(
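
One practical difference between the two conditions in this hunk is how they treat `None`. The sketch below shows the behaviour of a plain truthiness check; whether `None` can actually reach this code path in MatterGen is not stated in the diff, and `has_target_compositions` is an illustrative helper.

```python
def has_target_compositions(target_compositions_dict) -> bool:
    """Truthiness check: False for None and for empty containers."""
    return bool(target_compositions_dict)


for value in (None, [], [{"Na": 1, "Cl": 1}]):
    print(value, has_target_compositions(value))

# The older form, len(x) > 0, raises for None instead of returning False.
try:
    len(None)
except TypeError as err:
    print("len(None) fails:", err)
```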
9 changes: 9 additions & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ dependencies = [
"contextlib2",
"emmet-core>=0.84.2", # keep up-to-date together with pymatgen, atomate2
"fire", # see https://github.com/google/python-fire
+"huggingface-hub",
"hydra-core==1.3.1",
"hydra-joblib-launcher==1.1.5",
"jupyterlab>=4.2.5",
@@ -79,3 +80,11 @@ torch_sparse = { url = "https://data.pyg.org/whl/torch-2.2.0%2Bcu118/torch_spars
name = "pytorch"
url = "https://download.pytorch.org/whl/cu118"
explicit = true


+[project.scripts]
+mattergen-generate = "scripts.generate:_main"
+mattergen-train = "scripts.run:mattergen_main"
+mattergen-finetune = "scripts.finetune:mattergen_finetune"
+mattergen-evaluate = "scripts.evaluate:_main"
+csv-to-dataset = "scripts.csv_to_dataset:main"
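
Each `[project.scripts]` value has the form `module:object`. An installed console script imports the module and calls the named object with no arguments, which is why the scripts in this commit gain zero-argument wrapper functions. The following is a simplified sketch of that lookup using only the standard library. `load_entry_point` is an illustrative helper, not the mechanism installers actually generate, and it is demonstrated on standard-library modules rather than on the MatterGen scripts.

```python
import importlib


def load_entry_point(spec: str):
    """Resolve a "module:object.attr" string to the object it names."""
    module_name, _, attr_path = spec.partition(":")
    obj = importlib.import_module(module_name)
    if attr_path:
        for attr in attr_path.split("."):
            obj = getattr(obj, attr)
    return obj


# Resolve and call a standard-library function through its entry-point string.
dumps = load_entry_point("json:dumps")
print(dumps({"entry_point": "json:dumps"}))
```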
7 changes: 6 additions & 1 deletion scripts/csv_to_dataset.py
@@ -7,7 +7,8 @@
from mattergen.common.data.dataset import CrystalDataset
from mattergen.common.globals import PROJECT_ROOT

-if __name__ == "__main__":
+
+def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--csv-folder",
@@ -36,3 +37,7 @@
csv_path=f"{args.csv_folder}/{file}",
cache_path=f"{args.cache_folder}/{args.dataset_name}/{file.split('.')[0]}",
)


+if __name__ == "__main__":
+    main()
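
Moving the script body into `main()` is what allows `csv-to-dataset` to be registered as an entry point. A reduced sketch of the pattern follows. The three flags and the cache path expression come from this diff and the README commands. Everything else is an assumption made for illustration: `required=True`, the `argv` parameter, the helper names, and the fixed file list. The real script's defaults and dataset construction are not visible in the hunk.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Flag names as used in the README; required=True is an assumption.
    parser.add_argument("--csv-folder", type=str, required=True)
    parser.add_argument("--dataset-name", type=str, required=True)
    parser.add_argument("--cache-folder", type=str, required=True)
    return parser


def cache_path_for(args: argparse.Namespace, file: str) -> str:
    """Same expression as in the hunk: cache folder / dataset name / file stem."""
    return f"{args.cache_folder}/{args.dataset_name}/{file.split('.')[0]}"


def main(argv=None) -> None:
    """Zero-argument callable for an entry point; argv eases testing."""
    args = build_parser().parse_args(argv)
    for file in ("train.csv", "val.csv"):  # file names mentioned in the README
        print(cache_path_for(args, file))


main(["--csv-folder", "datasets/mp_20/", "--dataset-name", "mp_20",
      "--cache-folder", "datasets/cache"])
```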
6 changes: 5 additions & 1 deletion scripts/evaluate.py
@@ -47,5 +47,9 @@ def main(
print(json.dumps(metrics, indent=2))


-if __name__ == "__main__":
+def _main():
     fire.Fire(main)
+
+
+if __name__ == "__main__":
+    _main()
19 changes: 14 additions & 5 deletions scripts/finetune.py
@@ -25,15 +25,24 @@
 def init_adapter_lightningmodule_from_pretrained(
     adapter_cfg: DictConfig, lightning_module_cfg: DictConfig
 ) -> Tuple[pl.LightningModule, DictConfig]:
-    assert adapter_cfg.model_path is not None, "model_path must be provided."

-    model_path = Path(hydra.utils.to_absolute_path(adapter_cfg.model_path))
-    ckpt_info = MatterGenCheckpointInfo(model_path, adapter_cfg.load_epoch)
+    if adapter_cfg.model_path is not None:
+        if adapter_cfg.pretrained_name is not None:
+            logger.warning(
+                "pretrained_name is provided, but will be ignored since model_path is also provided."
+            )
+        model_path = Path(hydra.utils.to_absolute_path(adapter_cfg.model_path))
+        ckpt_info = MatterGenCheckpointInfo(model_path, adapter_cfg.load_epoch)
+    elif adapter_cfg.pretrained_name is not None:
+        assert (
+            adapter_cfg.model_path is None
+        ), "model_path must be None when pretrained_name is provided."
+        ckpt_info = MatterGenCheckpointInfo.from_hf_hub(adapter_cfg.pretrained_name)

     ckpt_path = ckpt_info.checkpoint_path

-    version_root_path = Path(ckpt_path).relative_to(model_path).parents[1]
-    config_path = model_path / version_root_path
+    version_root_path = Path(ckpt_path).relative_to(ckpt_info.model_path).parents[1]
+    config_path = ckpt_info.model_path / version_root_path

     # load pretrained model config.
     if (config_path / "config.yaml").exists():
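
The precedence implemented in this hunk is: an explicit `model_path` wins, and `pretrained_name` is used only when `model_path` is unset. The sketch below isolates that decision as a pure function. The explicit error when both are unset is an addition of this sketch: in the hunk as shown, neither branch assigns `ckpt_info` in that case. `resolve_checkpoint_source` and its return format are hypothetical.

```python
from typing import Optional, Tuple


def resolve_checkpoint_source(
    model_path: Optional[str], pretrained_name: Optional[str]
) -> Tuple[str, str]:
    """Return ("local", path) or ("hub", name), preferring a local path."""
    if model_path is not None:
        # A pretrained_name given alongside model_path is ignored, as in the hunk.
        return ("local", model_path)
    if pretrained_name is not None:
        return ("hub", pretrained_name)
    raise ValueError("either model_path or pretrained_name must be set")


print(resolve_checkpoint_source(None, "mattergen_base"))
print(resolve_checkpoint_source("checkpoints/my_model", "mattergen_base"))
```

With the new defaults in `mattergen/conf/adapter/default.yaml` (`pretrained_name: mattergen_base`, `model_path: null`), the second branch is the one taken unless the user overrides `adapter.model_path`.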