Skip to content

Commit cd588b3

Browse files
committed
Use pyupgrade --py39-plus to improve code
1 parent 6cc9c8d commit cd588b3

20 files changed

+507
-524
lines changed

src/transformers/cache_utils.py

+70-70
Large diffs are not rendered by default.

src/transformers/dynamic_module_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def get_relative_imports(module_file: Union[str, os.PathLike]) -> list[str]:
9090
module_file (`str` or `os.PathLike`): The module file to inspect.
9191
9292
Returns:
93-
`List[str]`: The list of relative imports in the module.
93+
`list[str]`: The list of relative imports in the module.
9494
"""
9595
with open(module_file, encoding="utf-8") as f:
9696
content = f.read()
@@ -112,7 +112,7 @@ def get_relative_import_files(module_file: Union[str, os.PathLike]) -> list[str]
112112
module_file (`str` or `os.PathLike`): The module file to inspect.
113113
114114
Returns:
115-
`List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
115+
`list[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
116116
of module files a given module needs.
117117
"""
118118
no_change = False
@@ -144,7 +144,7 @@ def get_imports(filename: Union[str, os.PathLike]) -> list[str]:
144144
filename (`str` or `os.PathLike`): The module file to inspect.
145145
146146
Returns:
147-
`List[str]`: The list of all packages required to use the input module.
147+
`list[str]`: The list of all packages required to use the input module.
148148
"""
149149
with open(filename, encoding="utf-8") as f:
150150
content = f.read()
@@ -175,7 +175,7 @@ def check_imports(filename: Union[str, os.PathLike]) -> list[str]:
175175
filename (`str` or `os.PathLike`): The module file to check.
176176
177177
Returns:
178-
`List[str]`: The list of relative imports in the file.
178+
`list[str]`: The list of relative imports in the file.
179179
"""
180180
imports = get_imports(filename)
181181
missing_packages = []

src/transformers/model_debugging_utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# coding=utf-8
21
# Copyright 2025 The HuggingFace Inc. team.
32
# All rights reserved.
43
#

src/transformers/modeling_outputs.py

+145-145
Large diffs are not rendered by default.

src/transformers/modeling_utils.py

+75-76
Large diffs are not rendered by default.

src/transformers/tokenization_utils_base.py

+86-87
Large diffs are not rendered by default.

src/transformers/utils/__init__.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# coding=utf-8
32

43
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
54
#
@@ -16,7 +15,6 @@
1615
# limitations under the License.
1716

1817
from functools import lru_cache
19-
from typing import FrozenSet
2018

2119
from huggingface_hub import get_full_repo_name # for backward compatibility
2220
from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility
@@ -299,8 +297,8 @@ def check_min_version(min_version):
299297
)
300298

301299

302-
@lru_cache()
303-
def get_available_devices() -> FrozenSet[str]:
300+
@lru_cache
301+
def get_available_devices() -> frozenset[str]:
304302
"""
305303
Returns a frozenset of devices available for the current PyTorch installation.
306304
"""

src/transformers/utils/attention_visualizer.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# coding=utf-8
21
# Copyright 2025 The HuggingFace Inc. team.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");

src/transformers/utils/backbone_utils.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# coding=utf-8
21
# Copyright 2023 The HuggingFace Inc. team.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +16,8 @@
1716

1817
import enum
1918
import inspect
20-
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
19+
from collections.abc import Iterable
20+
from typing import TYPE_CHECKING, Optional, Union
2121

2222

2323
if TYPE_CHECKING:
@@ -75,9 +75,9 @@ def verify_out_features_out_indices(
7575

7676

7777
def _align_output_features_output_indices(
78-
out_features: Optional[List[str]],
79-
out_indices: Optional[Union[List[int], Tuple[int]]],
80-
stage_names: List[str],
78+
out_features: Optional[list[str]],
79+
out_indices: Optional[Union[list[int], tuple[int]]],
80+
stage_names: list[str],
8181
):
8282
"""
8383
Finds the corresponding `out_features` and `out_indices` for the given `stage_names`.
@@ -106,10 +106,10 @@ def _align_output_features_output_indices(
106106

107107

108108
def get_aligned_output_features_output_indices(
109-
out_features: Optional[List[str]],
110-
out_indices: Optional[Union[List[int], Tuple[int]]],
111-
stage_names: List[str],
112-
) -> Tuple[List[str], List[int]]:
109+
out_features: Optional[list[str]],
110+
out_indices: Optional[Union[list[int], tuple[int]]],
111+
stage_names: list[str],
112+
) -> tuple[list[str], list[int]]:
113113
"""
114114
Get the `out_features` and `out_indices` so that they are aligned.
115115
@@ -198,7 +198,7 @@ def out_features(self):
198198
return self._out_features
199199

200200
@out_features.setter
201-
def out_features(self, out_features: List[str]):
201+
def out_features(self, out_features: list[str]):
202202
"""
203203
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
204204
"""
@@ -211,7 +211,7 @@ def out_indices(self):
211211
return self._out_indices
212212

213213
@out_indices.setter
214-
def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
214+
def out_indices(self, out_indices: Union[tuple[int], list[int]]):
215215
"""
216216
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
217217
"""
@@ -264,7 +264,7 @@ def out_features(self):
264264
return self._out_features
265265

266266
@out_features.setter
267-
def out_features(self, out_features: List[str]):
267+
def out_features(self, out_features: list[str]):
268268
"""
269269
Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
270270
"""
@@ -277,7 +277,7 @@ def out_indices(self):
277277
return self._out_indices
278278

279279
@out_indices.setter
280-
def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
280+
def out_indices(self, out_indices: Union[tuple[int], list[int]]):
281281
"""
282282
Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
283283
"""

src/transformers/utils/chat_template_utils.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from contextlib import contextmanager
2020
from datetime import datetime
2121
from functools import lru_cache
22-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints
22+
from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
2323

2424
from packaging import version
2525

@@ -71,7 +71,7 @@ class DocstringParsingException(Exception):
7171
pass
7272

7373

74-
def _get_json_schema_type(param_type: str) -> Dict[str, str]:
74+
def _get_json_schema_type(param_type: str) -> dict[str, str]:
7575
type_mapping = {
7676
int: {"type": "integer"},
7777
float: {"type": "number"},
@@ -87,7 +87,7 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]:
8787
return type_mapping.get(param_type, {"type": "object"})
8888

8989

90-
def _parse_type_hint(hint: str) -> Dict:
90+
def _parse_type_hint(hint: str) -> dict:
9191
origin = get_origin(hint)
9292
args = get_args(hint)
9393

@@ -152,7 +152,7 @@ def _parse_type_hint(hint: str) -> Dict:
152152
raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
153153

154154

155-
def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
155+
def _convert_type_hints_to_json_schema(func: Callable) -> dict:
156156
type_hints = get_type_hints(func)
157157
signature = inspect.signature(func)
158158
required = []
@@ -173,7 +173,7 @@ def _convert_type_hints_to_json_schema(func: Callable) -> Dict:
173173
return schema
174174

175175

176-
def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Optional[Dict], Optional[str]]:
176+
def parse_google_format_docstring(docstring: str) -> tuple[Optional[str], Optional[dict], Optional[str]]:
177177
"""
178178
Parses a Google-style docstring to extract the function description,
179179
argument descriptions, and return description.
@@ -206,7 +206,7 @@ def parse_google_format_docstring(docstring: str) -> Tuple[Optional[str], Option
206206
return description, args_dict, returns
207207

208208

209-
def get_json_schema(func: Callable) -> Dict:
209+
def get_json_schema(func: Callable) -> dict:
210210
"""
211211
This function generates a JSON schema for a given function, based on its docstring and type hints. This is
212212
mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
@@ -398,7 +398,7 @@ def is_active(self) -> bool:
398398
return self._rendered_blocks or self._generation_indices
399399

400400
@contextmanager
401-
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
401+
def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
402402
try:
403403
if self.is_active():
404404
raise ValueError("AssistantTracker should not be reused before closed")

src/transformers/utils/fx.py

+19-20
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# coding=utf-8
21
# Copyright 2021 The HuggingFace Team. All rights reserved.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +23,7 @@
2423
import random
2524
import sys
2625
import warnings
27-
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
26+
from typing import Any, Callable, Literal, Optional, Union
2827

2928
import torch
3029
import torch.utils._pytree as pytree
@@ -78,9 +77,9 @@
7877

7978

8079
def _generate_supported_model_class_names(
81-
model_name: Type[PretrainedConfig],
82-
supported_tasks: Optional[Union[str, List[str]]] = None,
83-
) -> List[str]:
80+
model_name: type[PretrainedConfig],
81+
supported_tasks: Optional[Union[str, list[str]]] = None,
82+
) -> list[str]:
8483
task_mapping = {
8584
"default": MODEL_MAPPING_NAMES,
8685
"pretraining": MODEL_FOR_PRETRAINING_MAPPING_NAMES,
@@ -588,7 +587,7 @@ def to_concrete(t):
588587
return operator.getitem(a, b)
589588

590589

591-
_MANUAL_META_OVERRIDES: Dict[Callable, Callable] = {
590+
_MANUAL_META_OVERRIDES: dict[Callable, Callable] = {
592591
torch.nn.Embedding: torch_nn_embedding,
593592
torch.nn.functional.embedding: torch_nn_functional_embedding,
594593
torch.nn.LayerNorm: torch_nn_layernorm,
@@ -714,7 +713,7 @@ class HFCacheProxy(HFProxy):
714713
Proxy that represents an instance of `transformers.cache_utils.Cache`.
715714
"""
716715

717-
def install_orig_cache_cls(self, orig_cache_cls: Type[Cache]):
716+
def install_orig_cache_cls(self, orig_cache_cls: type[Cache]):
718717
self._orig_cache_cls = orig_cache_cls
719718

720719
@property
@@ -768,8 +767,8 @@ class HFProxyableClassMeta(type):
768767
def __new__(
769768
cls,
770769
name: str,
771-
bases: Tuple[Type, ...],
772-
attrs: Dict[str, Any],
770+
bases: tuple[type, ...],
771+
attrs: dict[str, Any],
773772
proxy_factory_fn: Optional[Callable[[Node], Proxy]] = None,
774773
):
775774
cls = super().__new__(cls, name, bases, attrs)
@@ -792,7 +791,7 @@ def __new__(
792791
return cls
793792

794793

795-
def gen_constructor_wrapper(target: Callable) -> Tuple[Callable, Callable]:
794+
def gen_constructor_wrapper(target: Callable) -> tuple[Callable, Callable]:
796795
"""
797796
Wraps `target` to be proxyable. Used for tensor creators like `torch.ones`, `torch.arange` and so on.
798797
"""
@@ -811,7 +810,7 @@ def _proxies_to_metas(v):
811810
return v
812811

813812

814-
def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
813+
def create_cache_proxy_factory_fn(orig_cache_cls: type[Cache]) -> Callable[[Node], HFCacheProxy]:
815814
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
816815
global _CURRENT_TRACER
817816
if not isinstance(_CURRENT_TRACER, HFTracer):
@@ -847,7 +846,7 @@ def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
847846
)
848847

849848

850-
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
849+
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[list[int]] = None):
851850
if forbidden_values is None:
852851
forbidden_values = []
853852
value = random.randint(low, high)
@@ -897,8 +896,8 @@ def __init__(self, autowrap_modules=(math,), autowrap_functions=()):
897896
)
898897

899898
def _generate_dummy_input(
900-
self, model: "PreTrainedModel", input_name: str, shape: List[int], input_names: List[str]
901-
) -> Dict[str, torch.Tensor]:
899+
self, model: "PreTrainedModel", input_name: str, shape: list[int], input_names: list[str]
900+
) -> dict[str, torch.Tensor]:
902901
"""Generates dummy input for model inference recording."""
903902
# Retrieving the model class, either from the "class_for_deserialization" attribute if the model was restored
904903
# from pickle, or from the "__class__" attribute in the general case.
@@ -1179,7 +1178,7 @@ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cac
11791178
return attr_val
11801179

11811180
# Needed for PyTorch 1.13+
1182-
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]):
1181+
def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]):
11831182
return self._module_getattr(attr, attr_val, parameter_proxy_cache)
11841183

11851184
def call_module(self, m, forward, args, kwargs):
@@ -1231,8 +1230,8 @@ def patch_for_tracing(self, root: Union[torch.nn.Module, Callable[..., Any]]):
12311230
def trace(
12321231
self,
12331232
root: Union[torch.nn.Module, Callable[..., Any]],
1234-
concrete_args: Optional[Dict[str, Any]] = None,
1235-
dummy_inputs: Optional[Dict[str, Any]] = None,
1233+
concrete_args: Optional[dict[str, Any]] = None,
1234+
dummy_inputs: Optional[dict[str, Any]] = None,
12361235
complete_concrete_args_with_inputs_not_in_dummy_inputs: bool = True,
12371236
) -> Graph:
12381237
"""
@@ -1420,7 +1419,7 @@ def keys(self, obj: "Proxy") -> Any:
14201419
return attribute
14211420

14221421

1423-
def get_concrete_args(model: nn.Module, input_names: List[str]):
1422+
def get_concrete_args(model: nn.Module, input_names: list[str]):
14241423
sig = inspect.signature(model.forward)
14251424

14261425
if not (set(input_names) <= set(sig.parameters.keys())):
@@ -1448,9 +1447,9 @@ def check_if_model_is_supported(model: "PreTrainedModel"):
14481447

14491448
def symbolic_trace(
14501449
model: "PreTrainedModel",
1451-
input_names: Optional[List[str]] = None,
1450+
input_names: Optional[list[str]] = None,
14521451
disable_check: bool = False,
1453-
tracer_cls: Type[HFTracer] = HFTracer,
1452+
tracer_cls: type[HFTracer] = HFTracer,
14541453
) -> GraphModule:
14551454
"""
14561455
Performs symbolic tracing on the model.

0 commit comments

Comments
 (0)