Skip to content

Commit c81e0f2

Browse files
author
Daniel Zuegner
committed
generator fix
1 parent 6ed34b9 commit c81e0f2

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

mattergen/generator.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -206,13 +206,10 @@ def __post_init__(self) -> None:
206206
"please add it to mattergen.common.data.num_atoms_distribution.NUM_ATOMS_DISTRIBUTIONS."
207207
)
208208
if len(self.target_compositions_dict) > 0:
209-
assert (
210-
self.cfg.lightning_module.diffusion_module.loss_fn.weights.get(
211-
"atomic_numbers", 0.0
212-
)
213-
== 0.0
214-
and "atomic_numbers"
215-
not in self.cfg.lightning_module.diffusion_module.corruption.discrete_corruptions
209+
assert self.cfg.lightning_module.diffusion_module.loss_fn.weights.get(
210+
"atomic_numbers", 0.0
211+
) == 0.0 and "atomic_numbers" not in self.cfg.lightning_module.diffusion_module.corruption.get(
212+
"discrete_corruptions", {}
216213
), "Input model appears to have been trained for crystal generation (i.e., with atom type denoising), not crystal structure prediction. Please use a model trained for crystal structure prediction instead."
217214
sampling_cfg = self._load_sampling_config(
218215
sampling_config_name=self.sampling_config_name,

0 commit comments

Comments
 (0)