Commit dd10588

cleanup
1 parent 8140cb2 commit dd10588

3 files changed (+2, -514 lines)

fastvideo/v1/models/loader/utils.py (-167 lines)
@@ -1,20 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 """Utilities for selecting and loading models."""
 import contextlib
-from dataclasses import dataclass, field
-from typing import Dict, List, Optional, Tuple, Type
-
 import torch
-import transformers
-from torch import nn
-from transformers.dynamic_module_utils import get_class_from_dynamic_module
 
-from vllm.config import ModelConfig, ModelImpl
 from fastvideo.v1.logger import init_logger
-from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.models.adapters import (as_classification_model,
-                                                 as_embedding_model,
-                                                 as_reward_model)
 
 logger = init_logger(__name__)
 
@@ -26,159 +15,3 @@ def set_default_torch_dtype(dtype: torch.dtype):
     torch.set_default_dtype(dtype)
     yield
     torch.set_default_dtype(old_dtype)
-
-
-def load_hf_config_from_subdir(model_path: str, component_name: str, trust_remote_code: bool = False, revision: Optional[str] = None):
-    """
-    Load a HuggingFace config from a component subdirectory.
-
-    Args:
-        model_path: Path to the model directory
-        component_name: Name of the component subdirectory (e.g., "text_encoder", "vae")
-        trust_remote_code: Whether to trust remote code when loading the config
-        revision: Optional revision to use when loading from HuggingFace Hub
-
-    Returns:
-        The loaded HuggingFace config
-    """
-    import os
-    from transformers import AutoConfig
-
-    component_path = os.path.join(model_path, component_name)
-
-    if not os.path.exists(component_path):
-        raise ValueError(f"Component directory {component_path} does not exist")
-
-    config_path = os.path.join(component_path, "config.json")
-    if not os.path.exists(config_path):
-        raise ValueError(f"Config file {config_path} does not exist")
-
-    logger.info(f"Loading config from {config_path}")
-    config = AutoConfig.from_pretrained(
-        component_path,
-        trust_remote_code=trust_remote_code,
-        revision=revision
-    )
-
-    return config
-
-
-def is_transformers_impl_compatible(
-        arch: str,
-        module: Optional[transformers.PreTrainedModel] = None) -> bool:
-    mod = module or getattr(transformers, arch, None)
-    if mod is None:
-        return False
-    if hasattr(mod, "supports_backend"):
-        return mod.is_backend_compatible()
-    else:
-        return mod._supports_flex_attn
-
-
-def resolve_transformers_fallback(model_config: ModelConfig,
-                                  architectures: list[str]):
-    for i, arch in enumerate(architectures):
-        if arch == "TransformersModel":
-            continue
-        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
-                                           None) or dict()
-        # Make sure that config class is always initialized before model class,
-        # otherwise the model class won't be able to access the config class,
-        # the expected auto_map should have correct order like:
-        # "auto_map": {
-        #     "AutoConfig": "<your-repo-name>--<config-name>",
-        #     "AutoModel": "<your-repo-name>--<config-name>",
-        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
-        # },
-        auto_modules = {
-            name: get_class_from_dynamic_module(module, model_config.model)
-            for name, module in sorted(auto_map.items(), key=lambda x: x[0])
-        }
-        custom_model_module = auto_modules.get("AutoModel")
-        # TODO(Isotr0py): Further clean up these raises.
-        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
-        if model_config.model_impl == ModelImpl.TRANSFORMERS:
-            if not is_transformers_impl_compatible(arch, custom_model_module):
-                raise ValueError(
-                    f"The Transformers implementation of {arch} is not "
-                    "compatible with vLLM.")
-            architectures[i] = "TransformersModel"
-        if model_config.model_impl == ModelImpl.AUTO:
-            if not is_transformers_impl_compatible(arch, custom_model_module):
-                raise ValueError(
-                    f"{arch} has no vLLM implementation and the Transformers "
-                    "implementation is not compatible with vLLM.")
-            logger.warning(
-                "%s has no vLLM implementation, falling back to Transformers "
-                "implementation. Some features may not be supported and "
-                "performance may not be optimal.", arch)
-            architectures[i] = "TransformersModel"
-    return architectures
-
-
-def get_model_architecture(
-        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
-    architectures = getattr(model_config.hf_config, "architectures", [])
-
-    # Special handling for quantized Mixtral.
-    # FIXME(woosuk): This is a temporary hack.
-    mixtral_supported = [
-        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
-    ]
-
-    if (model_config.quantization is not None
-            and model_config.quantization not in mixtral_supported
-            and "MixtralForCausalLM" in architectures):
-        architectures = ["QuantMixtralForCausalLM"]
-
-    vllm_supported_archs = ModelRegistry.get_supported_archs()
-    is_vllm_supported = any(arch in vllm_supported_archs
-                            for arch in architectures)
-    if (not is_vllm_supported
-            or model_config.model_impl == ModelImpl.TRANSFORMERS):
-        architectures = resolve_transformers_fallback(model_config,
-                                                      architectures)
-
-    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
-    if model_config.task == "embed":
-        model_cls = as_embedding_model(model_cls)
-    elif model_config.task == "classify":
-        model_cls = as_classification_model(model_cls)
-    elif model_config.task == "reward":
-        model_cls = as_reward_model(model_cls)
-
-    return model_cls, arch
-
-
-def get_architecture_class_name(model_config: ModelConfig) -> str:
-    return get_model_architecture(model_config)[1]
-
-
-@dataclass
-class ParamMapping:
-    """
-    A class to handle parameter mapping for model weight loading.
-    It creates a bidirectional mapping between packed parameters and their
-    constituent parts.
-    """
-    packed_mapping: Dict[str, List[str]]
-    inverse_packed_mapping: Dict[str, Tuple[str,
-                                            int]] = field(default_factory=dict)
-
-    def __post_init__(self):
-        for packed_name, sub_params in self.packed_mapping.items():
-            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
-            if len(sub_params) == 1 and sub_params[0] == packed_name:
-                continue
-            for index, param_name in enumerate(sub_params):
-                self.inverse_packed_mapping[param_name] = (
-                    packed_name,
-                    index,
-                )
-
-    def get_sub_modules(self,
-                        module_name: str) -> Optional[Tuple[str, List[str]]]:
-        for key, value in self.packed_mapping.items():
-            if module_name.endswith(key):
-                return key, value
-        return None
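For context on what the cleanup drops, the sketch below reproduces the removed ParamMapping dataclass and shows how its bidirectional packed-parameter mapping behaved. The class body is copied from the diff above; the QKV-style mapping and the module names in the usage lines are hypothetical examples, not taken from this repository.

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class ParamMapping:
    """Bidirectional mapping between packed parameters and their constituent parts."""
    packed_mapping: Dict[str, List[str]]
    inverse_packed_mapping: Dict[str, Tuple[str, int]] = field(default_factory=dict)

    def __post_init__(self):
        # Build the inverse map: constituent name -> (packed name, position).
        for packed_name, sub_params in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]}).
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (packed_name, index)

    def get_sub_modules(self, module_name: str) -> Optional[Tuple[str, List[str]]]:
        # Match on the suffix so fully qualified module paths resolve too.
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None


# Hypothetical usage: a fused QKV projection packed from three checkpoint weights.
mapping = ParamMapping(packed_mapping={"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
print(mapping.inverse_packed_mapping["k_proj"])           # ('qkv_proj', 1)
print(mapping.get_sub_modules("layers.0.attn.qkv_proj"))  # ('qkv_proj', ['q_proj', 'k_proj', 'v_proj'])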

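After this cleanup the module essentially keeps only the set_default_torch_dtype context manager, whose body appears as context in the second hunk above. The following is a minimal sketch of how such a context manager is typically used; the old_dtype = torch.get_default_dtype() line is assumed (it sits above the shown context and is not part of this diff).

import contextlib

import torch


@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily switch torch's default floating-point dtype."""
    old_dtype = torch.get_default_dtype()  # assumed setup line, not shown in the hunk
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)


# Illustrative usage: parameters created inside the block pick up the requested
# default dtype (model loaders typically pass fp16/bf16; float64 is used here
# because it is accepted by every PyTorch release).
with set_default_torch_dtype(torch.float64):
    layer = torch.nn.Linear(8, 8)
assert layer.weight.dtype == torch.float64
assert torch.get_default_dtype() == torch.float32  # restored on exit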