# SPDX-License-Identifier: Apache-2.0
"""Utilities for selecting and loading models."""
import contextlib
- from dataclasses import dataclass, field
- from typing import Dict, List, Optional, Tuple, Type
-
import torch
- import transformers
- from torch import nn
- from transformers.dynamic_module_utils import get_class_from_dynamic_module

- from vllm.config import ModelConfig, ModelImpl
from fastvideo.v1.logger import init_logger
- from vllm.model_executor.models import ModelRegistry
- from vllm.model_executor.models.adapters import (as_classification_model,
-                                                  as_embedding_model,
-                                                  as_reward_model)

logger = init_logger(__name__)

@@ -26,159 +15,3 @@ def set_default_torch_dtype(dtype: torch.dtype):
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)
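
A minimal usage sketch of this context manager (illustrative only; the module class and config below are placeholders, not names from this file): parameters created inside the block default to the requested dtype, and the previous process-wide default is restored on exit.

    with set_default_torch_dtype(torch.bfloat16):
        text_encoder = MyTextEncoder(config)  # placeholder module; new params default to bf16
    # outside the block, the previous default dtype is active again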
-
-
- def load_hf_config_from_subdir(model_path: str,
-                                component_name: str,
-                                trust_remote_code: bool = False,
-                                revision: Optional[str] = None):
-     """
-     Load a HuggingFace config from a component subdirectory.
-
-     Args:
-         model_path: Path to the model directory
-         component_name: Name of the component subdirectory (e.g., "text_encoder", "vae")
-         trust_remote_code: Whether to trust remote code when loading the config
-         revision: Optional revision to use when loading from HuggingFace Hub
-
-     Returns:
-         The loaded HuggingFace config
-     """
-     import os
-     from transformers import AutoConfig
-
-     component_path = os.path.join(model_path, component_name)
-
-     if not os.path.exists(component_path):
-         raise ValueError(f"Component directory {component_path} does not exist")
-
-     config_path = os.path.join(component_path, "config.json")
-     if not os.path.exists(config_path):
-         raise ValueError(f"Config file {config_path} does not exist")
-
-     logger.info(f"Loading config from {config_path}")
-     config = AutoConfig.from_pretrained(
-         component_path,
-         trust_remote_code=trust_remote_code,
-         revision=revision
-     )
-
-     return config
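
A hedged usage sketch of the helper above (the path is a placeholder): for a diffusers-style layout such as <model_dir>/text_encoder/config.json, it returns the transformers config object for that component.

    config = load_hf_config_from_subdir("/path/to/pipeline", "text_encoder",
                                        trust_remote_code=False)
    print(type(config).__name__)  # the concrete config class depends on the component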
-
-
- def is_transformers_impl_compatible(
-         arch: str,
-         module: Optional[transformers.PreTrainedModel] = None) -> bool:
-     mod = module or getattr(transformers, arch, None)
-     if mod is None:
-         return False
-     if hasattr(mod, "supports_backend"):
-         return mod.is_backend_compatible()
-     else:
-         return mod._supports_flex_attn
-
-
- def resolve_transformers_fallback(model_config: ModelConfig,
-                                   architectures: list[str]):
-     for i, arch in enumerate(architectures):
-         if arch == "TransformersModel":
-             continue
-         auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
-                                            None) or dict()
-         # Make sure that config class is always initialized before model class,
-         # otherwise the model class won't be able to access the config class,
-         # the expected auto_map should have correct order like:
-         # "auto_map": {
-         #     "AutoConfig": "<your-repo-name>--<config-name>",
-         #     "AutoModel": "<your-repo-name>--<config-name>",
-         #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
-         # },
-         auto_modules = {
-             name: get_class_from_dynamic_module(module, model_config.model)
-             for name, module in sorted(auto_map.items(), key=lambda x: x[0])
-         }
-         custom_model_module = auto_modules.get("AutoModel")
-         # TODO(Isotr0py): Further clean up these raises.
-         # perhaps handled them in _ModelRegistry._raise_for_unsupported?
-         if model_config.model_impl == ModelImpl.TRANSFORMERS:
-             if not is_transformers_impl_compatible(arch, custom_model_module):
-                 raise ValueError(
-                     f"The Transformers implementation of {arch} is not "
-                     "compatible with vLLM.")
-             architectures[i] = "TransformersModel"
-         if model_config.model_impl == ModelImpl.AUTO:
-             if not is_transformers_impl_compatible(arch, custom_model_module):
-                 raise ValueError(
-                     f"{arch} has no vLLM implementation and the Transformers "
-                     "implementation is not compatible with vLLM.")
-             logger.warning(
-                 "%s has no vLLM implementation, falling back to Transformers "
-                 "implementation. Some features may not be supported and "
-                 "performance may not be optimal.", arch)
-             architectures[i] = "TransformersModel"
-     return architectures
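
A hedged illustration of the fallback above (the architecture name is hypothetical): when an architecture has no native implementation but passes the compatibility check, its entry is rewritten in place to "TransformersModel" and resolved through the generic Transformers backend.

    architectures = ["MyCustomForCausalLM"]  # hypothetical, not in the registry
    architectures = resolve_transformers_fallback(model_config, architectures)
    # architectures is now ["TransformersModel"]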
-
-
- def get_model_architecture(
-         model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
-     architectures = getattr(model_config.hf_config, "architectures", [])
-
-     # Special handling for quantized Mixtral.
-     # FIXME(woosuk): This is a temporary hack.
-     mixtral_supported = [
-         "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
-     ]
-
-     if (model_config.quantization is not None
-             and model_config.quantization not in mixtral_supported
-             and "MixtralForCausalLM" in architectures):
-         architectures = ["QuantMixtralForCausalLM"]
-
-     vllm_supported_archs = ModelRegistry.get_supported_archs()
-     is_vllm_supported = any(arch in vllm_supported_archs
-                             for arch in architectures)
-     if (not is_vllm_supported
-             or model_config.model_impl == ModelImpl.TRANSFORMERS):
-         architectures = resolve_transformers_fallback(model_config,
-                                                       architectures)
-
-     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
-     if model_config.task == "embed":
-         model_cls = as_embedding_model(model_cls)
-     elif model_config.task == "classify":
-         model_cls = as_classification_model(model_cls)
-     elif model_config.task == "reward":
-         model_cls = as_reward_model(model_cls)
-
-     return model_cls, arch
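
A hedged sketch of how this resolution step is typically consumed downstream (constructor arguments are schematic, and it assumes model_config.dtype holds the target torch.dtype):

    model_cls, arch = get_model_architecture(model_config)
    with set_default_torch_dtype(model_config.dtype):
        model = model_cls(...)  # actual arguments depend on the resolved architecture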
-
-
- def get_architecture_class_name(model_config: ModelConfig) -> str:
-     return get_model_architecture(model_config)[1]
-
-
- @dataclass
- class ParamMapping:
-     """
-     A class to handle parameter mapping for model weight loading.
-     It creates a bidirectional mapping between packed parameters and their
-     constituent parts.
-     """
-     packed_mapping: Dict[str, List[str]]
-     inverse_packed_mapping: Dict[str, Tuple[str,
-                                             int]] = field(default_factory=dict)
-
-     def __post_init__(self):
-         for packed_name, sub_params in self.packed_mapping.items():
-             # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
-             if len(sub_params) == 1 and sub_params[0] == packed_name:
-                 continue
-             for index, param_name in enumerate(sub_params):
-                 self.inverse_packed_mapping[param_name] = (
-                     packed_name,
-                     index,
-                 )
-
-     def get_sub_modules(self,
-                         module_name: str) -> Optional[Tuple[str, List[str]]]:
-         for key, value in self.packed_mapping.items():
-             if module_name.endswith(key):
-                 return key, value
-         return None
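
A hedged usage sketch of ParamMapping (the packed name "qkv_proj" and its parts are illustrative): the inverse map lets a weight loader recover which packed parameter a checkpoint tensor belongs to and at which index, while get_sub_modules matches by module-name suffix.

    mapping = ParamMapping(
        packed_mapping={"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
    mapping.inverse_packed_mapping["k_proj"]           # ("qkv_proj", 1)
    mapping.get_sub_modules("layers.0.attn.qkv_proj")  # ("qkv_proj", [...])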