diff --git a/src/anomaly_detector/multivariate/model.py b/src/anomaly_detector/multivariate/model.py index d43d785..2f1609c 100644 --- a/src/anomaly_detector/multivariate/model.py +++ b/src/anomaly_detector/multivariate/model.py @@ -151,7 +151,7 @@ def save_checkpoint(self): self.model.to(self.config.device) def load_checkpoint(self, model_path): - ckpt = torch.load(model_path) + ckpt = torch.load(model_path, weights_only=True) self.config = ckpt["config"] self.model = MultivariateGraphAttnDetector(self.config) self.model.load_state_dict(ckpt["state_dict"])