diff --git a/olive/model/handler/pytorch.py b/olive/model/handler/pytorch.py index e87c9d39b..cb4b53c4b 100644 --- a/olive/model/handler/pytorch.py +++ b/olive/model/handler/pytorch.py @@ -151,7 +151,7 @@ def load_model(self, rank: int = None, cache_model: bool = True) -> "torch.nn.Mo elif self.model_file_format == ModelFileFormat.PYTORCH_TORCH_SCRIPT: model = torch.jit.load(self.model_path) elif self.model_file_format == ModelFileFormat.PYTORCH_ENTIRE_MODEL: - model = torch.load(self.model_path) + model = torch.load(self.model_path, weights_only=False) elif self.model_file_format == ModelFileFormat.PYTORCH_SLICE_GPT_MODEL: model = self._load_slicegpt_model() elif self.model_file_format == ModelFileFormat.PYTORCH_STATE_DICT: diff --git a/olive/olive_config.json b/olive/olive_config.json index 0027ef697..1bee9abfe 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -363,13 +363,13 @@ "extra_dependencies": { "auto-opt": [ "optimum" ], "azureml": [ "azure-ai-ml>=1.11.1", "azure-keyvault-secrets", "azure-identity", "azureml-fsspec" ], - "bnb": [ "bitsandbytes" ], + "bnb": [ "bitsandbytes", "triton" ], "capture-onnx-graph": [ "onnxruntime-genai", "optimum" ], "cpu": [ "onnxruntime" ], "directml": [ "onnxruntime-directml" ], "docker": [ "docker" ], "shared-cache": [ "azure-identity", "azure-storage-blob" ], - "finetune": [ "onnxruntime-genai", "optimum", "accelerate>=0.30.0", "peft", "scipy", "bitsandbytes" ], + "finetune": [ "onnxruntime-genai", "optimum", "accelerate>=0.30.0", "peft", "scipy", "bitsandbytes", "triton" ], "flash-attn": [ "flash_attn" ], "gpu": [ "onnxruntime-gpu" ], "inc": [ "neural-compressor" ], diff --git a/test/requirements-test-gpu.txt b/test/requirements-test-gpu.txt index e8b9592f6..65674e827 100644 --- a/test/requirements-test-gpu.txt +++ b/test/requirements-test-gpu.txt @@ -1,5 +1,5 @@ -r requirements-test.txt auto-gptq autoawq -# only available on Linux currently -bitsandbytes==0.43.3 +bitsandbytes +triton diff --git a/test/unit_test/model/test_pytorch_model.py b/test/unit_test/model/test_pytorch_model.py index ac26454e1..bee910886 100644 --- a/test/unit_test/model/test_pytorch_model.py +++ b/test/unit_test/model/test_pytorch_model.py @@ -58,7 +58,7 @@ def test_load_from_path(torch_load): model = PyTorchModelHandler(model_path="test_path") assert model.load_model() == "dummy_pytorch_model" - torch_load.assert_called_once_with("test_path") + torch_load.assert_called_once_with("test_path", weights_only=False) @patch("olive.model.handler.pytorch.UserModuleLoader")