Skip to content

Commit 5eeabc2

Browse files
[Bugfix] Fix bnb quantization for models with both HF-format and Mistral-format weights (vllm-project#14950)
1 parent 18551e8 commit 5eeabc2

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

tests/quantization/test_bitsandbytes.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
models_4bit_to_test = [
1717
("facebook/opt-125m", "quantize opt model inflight"),
18+
("mistralai/Mistral-7B-Instruct-v0.3",
19+
"quantize inflight model with both HF and Mistral format weights")
1820
]
1921

2022
models_pre_qaunt_4bit_to_test = [

vllm/model_executor/model_loader/loader.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,7 @@ def _get_weight_files(
762762
model_name_or_path: str,
763763
allowed_patterns: List[str],
764764
revision: Optional[str] = None,
765-
) -> Tuple[List[str], str]:
765+
) -> Tuple[str, List[str], str]:
766766
"""Retrieve weight files. Download the files if necessary.
767767
768768
Return the weight files and the file pattern."""
@@ -773,7 +773,7 @@ def _get_weight_files(
773773
weight_files = glob.glob(
774774
os.path.join(model_name_or_path, pattern))
775775
if weight_files:
776-
return weight_files, pattern
776+
return model_name_or_path, weight_files, pattern
777777
else:
778778
hf_api = HfApi()
779779
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
@@ -787,7 +787,8 @@ def _get_weight_files(
787787
revision,
788788
ignore_patterns=self.load_config.ignore_patterns,
789789
)
790-
return glob.glob(os.path.join(hf_folder, pattern)), pattern
790+
return hf_folder, glob.glob(
791+
os.path.join(hf_folder, pattern)), pattern
791792

792793
raise RuntimeError(
793794
f"No model weights found in: `{model_name_or_path}`")
@@ -798,18 +799,36 @@ def _prepare_weights(self, model_name_or_path: str,
798799

799800
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
800801

801-
hf_weights_files, matched_pattern = self._get_weight_files(
802+
hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
802803
model_name_or_path, allowed_patterns, revision)
803804

804-
if matched_pattern != "*.safetensors":
805+
use_safetensors = matched_pattern == "*.safetensors"
806+
is_local = os.path.isdir(model_name_or_path)
807+
index_file = SAFE_WEIGHTS_INDEX_NAME
808+
if use_safetensors:
809+
# For models like Mistral-7B-Instruct-v0.3
810+
# there are both sharded safetensors files and a consolidated
811+
# safetensors file. Using both breaks.
812+
# Here, we download the `model.safetensors.index.json` and filter
813+
# any files not found in the index.
814+
if not is_local:
815+
download_safetensors_index_file_from_hf(
816+
model_name_or_path,
817+
index_file,
818+
self.load_config.download_dir,
819+
revision,
820+
)
821+
hf_weights_files = filter_duplicate_safetensors_files(
822+
hf_weights_files, hf_folder, index_file)
823+
else:
805824
hf_weights_files = filter_files_not_needed_for_inference(
806825
hf_weights_files)
807826

808827
if len(hf_weights_files) == 0:
809828
raise RuntimeError(
810829
f"Cannot find any model weights with `{model_name_or_path}`")
811830

812-
return hf_weights_files, matched_pattern == "*.safetensors"
831+
return hf_weights_files, use_safetensors
813832

814833
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
815834
if use_safetensors:

0 commit comments

Comments (0)