Skip to content

Commit

Permalink
Fix issues with CI builds
Browse files Browse the repository at this point in the history
* Relax version of bitsandbytes
* Add triton to requirements
* Few fixes for using newer version of torch
  • Loading branch information
shaahji committed Jan 31, 2025
1 parent d98186d commit e1e2810
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion olive/model/handler/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def load_model(self, rank: int = None, cache_model: bool = True) -> "torch.nn.Mo
elif self.model_file_format == ModelFileFormat.PYTORCH_TORCH_SCRIPT:
model = torch.jit.load(self.model_path)
elif self.model_file_format == ModelFileFormat.PYTORCH_ENTIRE_MODEL:
model = torch.load(self.model_path)
model = torch.load(self.model_path, weights_only=False)
elif self.model_file_format == ModelFileFormat.PYTORCH_SLICE_GPT_MODEL:
model = self._load_slicegpt_model()
elif self.model_file_format == ModelFileFormat.PYTORCH_STATE_DICT:
Expand Down
4 changes: 2 additions & 2 deletions test/requirements-test-gpu.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-r requirements-test.txt
auto-gptq
autoawq
# only available on Linux currently
bitsandbytes==0.43.3
bitsandbytes
triton
2 changes: 1 addition & 1 deletion test/unit_test/model/test_pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_load_from_path(torch_load):
model = PyTorchModelHandler(model_path="test_path")

assert model.load_model() == "dummy_pytorch_model"
torch_load.assert_called_once_with("test_path")
torch_load.assert_called_once_with("test_path", weights_only=False)


@patch("olive.model.handler.pytorch.UserModuleLoader")
Expand Down

0 comments on commit e1e2810

Please sign in to comment.