@@ -762,7 +762,7 @@ def _get_weight_files(
762
762
model_name_or_path : str ,
763
763
allowed_patterns : List [str ],
764
764
revision : Optional [str ] = None ,
765
- ) -> Tuple [List [str ], str ]:
765
+ ) -> Tuple [str , List [str ], str ]:
766
766
"""Retrieve weight files. Download the files if necessary.
767
767
768
768
Return the folder containing the weights, the weight files and the file pattern."""
@@ -773,7 +773,7 @@ def _get_weight_files(
773
773
weight_files = glob .glob (
774
774
os .path .join (model_name_or_path , pattern ))
775
775
if weight_files :
776
- return weight_files , pattern
776
+ return model_name_or_path , weight_files , pattern
777
777
else :
778
778
hf_api = HfApi ()
779
779
repo_files = hf_api .list_repo_files (repo_id = model_name_or_path )
@@ -787,7 +787,8 @@ def _get_weight_files(
787
787
revision ,
788
788
ignore_patterns = self .load_config .ignore_patterns ,
789
789
)
790
- return glob .glob (os .path .join (hf_folder , pattern )), pattern
790
+ return hf_folder , glob .glob (
791
+ os .path .join (hf_folder , pattern )), pattern
791
792
792
793
raise RuntimeError (
793
794
f"No model weights found in: `{ model_name_or_path } `" )
@@ -798,18 +799,36 @@ def _prepare_weights(self, model_name_or_path: str,
798
799
799
800
allowed_patterns = ["*.safetensors" , "*.bin" , "*.pt" ]
800
801
801
- hf_weights_files , matched_pattern = self ._get_weight_files (
802
+ hf_folder , hf_weights_files , matched_pattern = self ._get_weight_files (
802
803
model_name_or_path , allowed_patterns , revision )
803
804
804
- if matched_pattern != "*.safetensors" :
805
+ use_safetensors = matched_pattern == "*.safetensors"
806
+ is_local = os .path .isdir (model_name_or_path )
807
+ index_file = SAFE_WEIGHTS_INDEX_NAME
808
+ if use_safetensors :
809
+ # For models like Mistral-7B-Instruct-v0.3
810
+ # there are both sharded safetensors files and a consolidated
811
+ # safetensors file. Using both breaks.
812
+ # Here, we download the `model.safetensors.index.json` and filter
813
+ # any files not found in the index.
814
+ if not is_local :
815
+ download_safetensors_index_file_from_hf (
816
+ model_name_or_path ,
817
+ index_file ,
818
+ self .load_config .download_dir ,
819
+ revision ,
820
+ )
821
+ hf_weights_files = filter_duplicate_safetensors_files (
822
+ hf_weights_files , hf_folder , index_file )
823
+ else :
805
824
hf_weights_files = filter_files_not_needed_for_inference (
806
825
hf_weights_files )
807
826
808
827
if len (hf_weights_files ) == 0 :
809
828
raise RuntimeError (
810
829
f"Cannot find any model weights with `{ model_name_or_path } `" )
811
830
812
- return hf_weights_files , matched_pattern == "*.safetensors"
831
+ return hf_weights_files , use_safetensors
813
832
814
833
def _hf_weight_iter (self , hf_weights_files , use_safetensors : bool ):
815
834
if use_safetensors :
0 commit comments