Commit db4ce78

aorenste authored and pytorchmergebot committed
PEP585: More UP006 fixes (pytorch#146392)
This should be the final PR before we can enable RUFF UP006.

Pull Request resolved: pytorch#146392
Approved by: https://github.com/justinchuby, https://github.com/albanD, https://github.com/Skylion007
1 parent 76ad19a commit db4ce78
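
For readers unfamiliar with the rule: UP006 is Ruff's "non-PEP 585 annotation" check. PEP 585 (Python 3.9) made the builtin containers (list, dict, set, tuple, type, ...) directly subscriptable in annotations, deprecating the typing.List/typing.Dict aliases. A minimal before/after sketch of the rewrite this commit applies across 81 files (hypothetical snippet, not taken from the diff):

# Before: PEP 484 aliases imported from typing.
from typing import Dict, List  # noqa: UP035

def word_counts(lines: List[str]) -> Dict[str, int]:
    counts: Dict[str, int] = {}
    for line in lines:
        for word in line.split():
            counts[word] = counts.get(word, 0) + 1
    return counts

# After: PEP 585 builtin generics; the typing imports for containers go away.
def word_counts_585(lines: list[str]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for line in lines:
        for word in line.split():
            counts[word] = counts.get(word, 0) + 1
    return counts

print(word_counts_585(["a b a"]))  # {'a': 2, 'b': 1}

Names with no builtin equivalent (Optional, Union, Callable, TYPE_CHECKING) stay in typing, which is why most hunks below trim typing imports rather than delete them.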

81 files changed: +283 -329 lines changed


test/distributed/_composable/fsdp/test_fully_shard_ignore_params.py

+1 -2

@@ -1,7 +1,6 @@
 # Owner(s): ["oncall: distributed"]
 
 import sys
-from typing import List
 
 import torch
 import torch.distributed as dist
@@ -119,7 +118,7 @@ def _find_name_param_mappings(module: torch.nn.Module, prefix: str):
 
 
 def _discover_ddp_ignored_params(module: torch.nn.Module, prefix: str):
-    ddp_ignore_parameters: List[str] = []
+    ddp_ignore_parameters: list[str] = []
     if isinstance(module, FSDP2):
         ddp_ignore_parameters = [name for name, _ in module.named_parameters(prefix)]
     else:

test/distributed/test_c10d_functional_native.py

+7 -7

@@ -3,7 +3,7 @@
 import threading
 import unittest
 from datetime import timedelta
-from typing import List, Optional
+from typing import Optional
 
 import torch
 import torch.distributed as dist
@@ -576,24 +576,24 @@ def __init__(self) -> None:
         self.waits = 0
         self.dels = 0
 
-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> dist.Work:
+    def broadcast(self, tensor_list: list[torch.Tensor], opts: object) -> dist.Work:
         return _DummyWork(self)
 
     def allgather_into_tensor_coalesced(
         self,
-        output_lists: List[torch.Tensor],
-        input_list: List[torch.Tensor],
+        output_lists: list[torch.Tensor],
+        input_list: list[torch.Tensor],
         opts: object,
     ) -> dist.Work:
         return _DummyWork(self)
 
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> dist.Work:
+    def allreduce(self, tensors: list[torch.Tensor], opts: object) -> dist.Work:
         return _DummyWork(self)
 
     def reduce_scatter_tensor_coalesced(
         self,
-        outputTensors: List[torch.Tensor],
-        inputTensors: List[torch.Tensor],
+        outputTensors: list[torch.Tensor],
+        inputTensors: list[torch.Tensor],
         opts: object,
     ) -> dist.Work:
         return _DummyWork(self)

test/functorch/test_ac_logging.py

+12 -13

@@ -1,5 +1,4 @@
 # Owner(s): ["module: functorch"]
-from typing import Dict, List, Tuple
 from unittest.mock import MagicMock, patch
 
 from torch._functorch._activation_checkpointing.ac_logging_utils import (
@@ -33,17 +32,17 @@ def setUp(self) -> None:
 
         self.graph.nodes = [self.node1, self.node2]
 
-        self.all_recomputable_banned_nodes: List[Node] = [self.node1]
-        self.saved_node_idxs: List[int] = [0]
-        self.recomputable_node_idxs: List[int] = []
+        self.all_recomputable_banned_nodes: list[Node] = [self.node1]
+        self.saved_node_idxs: list[int] = [0]
+        self.recomputable_node_idxs: list[int] = []
         self.expected_runtime: int = 100
-        self.memories_banned_nodes: List[int] = [50]
-        self.runtimes_banned_nodes: List[int] = [10]
-        self.min_cut_saved_values: List[Node] = [self.node1]
+        self.memories_banned_nodes: list[int] = [50]
+        self.runtimes_banned_nodes: list[int] = [10]
+        self.min_cut_saved_values: list[Node] = [self.node1]
 
     def test_create_joint_graph_node_information(self) -> None:
-        recomputable_node_info: Dict[str, int] = {"node1": 0}
-        expected_output: Dict[str, Dict] = {
+        recomputable_node_info: dict[str, int] = {"node1": 0}
+        expected_output: dict[str, dict] = {
             "node1": {
                 "index": 0,
                 "name": "node1",
@@ -68,12 +67,12 @@ def test_create_joint_graph_node_information(self) -> None:
         self.assertEqual(result, expected_output)
 
     def test_create_joint_graph_edges(self) -> None:
-        expected_edges: List[Tuple[str, str]] = [("node1", "node2")]
+        expected_edges: list[tuple[str, str]] = [("node1", "node2")]
         result = create_joint_graph_edges(self.graph)
         self.assertEqual(result, expected_edges)
 
     def test_create_activation_checkpointing_logging_structure_payload(self) -> None:
-        input_joint_graph_node_information: Dict[str, Dict] = {
+        input_joint_graph_node_information: dict[str, dict] = {
             "node1": {
                 "index": 0,
                 "name": "node1",
@@ -85,8 +84,8 @@ def test_create_activation_checkpointing_logging_structure_payload(self) -> None:
                 "recomputable_candidate_info": {"recomputable_node_idx": 0},
             }
         }
-        joint_graph_edges: List[Tuple[str, str]] = [("node1", "node2")]
-        expected_payload: Dict[str, any] = {
+        joint_graph_edges: list[tuple[str, str]] = [("node1", "node2")]
+        expected_payload: dict[str, any] = {
            "Joint Graph Size": 2,
            "Joint Graph Edges": {"Total": 1, "Edges": joint_graph_edges},
            "Joint Graph Node Information": input_joint_graph_node_information,

test/onnx/test_onnxscript_runtime.py

+1 -1

@@ -2,7 +2,7 @@
 
 """Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime."""
 
-from typing import Sequence
+from typing import Sequence  # noqa: UP035
 
 import onnx_test_common
 import onnxscript
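
The # noqa: UP035 above is a deliberate exception: UP035 flags `from typing import Sequence` as deprecated in favor of collections.abc.Sequence, but this test presumably keeps the typing alias because onnxscript inspects annotations at runtime. A standalone illustration (not from the commit) of why the two names are not interchangeable objects:

import collections.abc
import typing
from typing import Sequence  # noqa: UP035  -- kept deliberately, as in the test

# The typing alias wraps, but is not, the ABC itself.
assert typing.Sequence is not collections.abc.Sequence
assert typing.get_origin(typing.Sequence[int]) is collections.abc.Sequence

def head(xs: Sequence[int]) -> int:
    # A library reading __annotations__ at runtime sees the typing alias here.
    return xs[0]

print(head([1, 2, 3]))  # 1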

test/onnx/torchlib/error_reproduction.py

+5 -1

@@ -9,7 +9,7 @@
 import sys
 import time
 import traceback
-from typing import Any, Mapping
+from typing import Any, TYPE_CHECKING
 
 import numpy as np
 
@@ -20,6 +20,10 @@
 import torch
 
 
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+
 _REPRODUCTION_TEMPLATE = '''\
 import google.protobuf.text_format
 import numpy as np
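
The hunk above shows the companion pattern for ABCs used only in annotations: import them from collections.abc inside a TYPE_CHECKING block, so type checkers see them while the module pays no runtime import. A minimal standalone sketch of the same pattern (hypothetical function, not from the file); it relies on `from __future__ import annotations` so the annotation is never evaluated at runtime:

from __future__ import annotations  # annotations stay strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to mypy/pyright only; never imported when the module runs.
    from collections.abc import Mapping

def format_env(env: Mapping[str, str]) -> str:
    return ";".join(f"{key}={value}" for key, value in sorted(env.items()))

print(format_env({"HOME": "/root", "SHELL": "/bin/sh"}))  # HOME=/root;SHELL=/bin/sh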

torch/_C/_monitor.pyi

+2 -2

@@ -3,7 +3,7 @@
 import datetime
 from enum import Enum
 from types import TracebackType
-from typing import Callable, Optional, Type
+from typing import Callable, Optional
 
 class Aggregation(Enum):
     VALUE = ...
@@ -48,7 +48,7 @@ class _WaitCounterTracker:
     def __enter__(self) -> None: ...
     def __exit__(
         self,
-        exec_type: Optional[Type[BaseException]] = None,
+        exec_type: Optional[type[BaseException]] = None,
         exec_value: Optional[BaseException] = None,
         traceback: Optional[TracebackType] = None,
     ) -> None: ...
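
PEP 585 also covers the builtin `type`, which is why the stub's __exit__ drops typing.Type. A runnable sketch of a context manager with the same signature (hypothetical Timer class, not part of torch):

import time
from types import TracebackType
from typing import Optional

class Timer:
    def __enter__(self) -> "Timer":
        self._start = time.perf_counter()
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]] = None,  # builtin `type`, per PEP 585
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        self.elapsed = time.perf_counter() - self._start

with Timer() as t:
    sum(range(10_000))
print(f"elapsed: {t.elapsed:.6f}s")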

torch/_decomp/__init__.py

+1 -11

@@ -4,17 +4,7 @@
 from collections.abc import Sequence
 from functools import lru_cache, partial, wraps
 from itertools import chain
-from typing import (
-    Callable,
-    Dict,
-    FrozenSet,
-    List,
-    Optional,
-    Set,
-    TYPE_CHECKING,
-    TypeVar,
-    Union,
-)
+from typing import Callable, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import ParamSpec
 
 
torch/_dynamo/polyfills/__init__.py

+1 -1

@@ -11,7 +11,7 @@
 import types
 from collections.abc import MutableMapping, Sequence
 from itertools import repeat as _repeat
-from typing import Any, Callable, List, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
 
 import torch
 

torch/_dynamo/utils.py

+1 -2

@@ -59,7 +59,6 @@
     Generic,
     Optional,
     overload,
-    Set,
     TypeVar,
     Union,
 )
@@ -1393,7 +1392,7 @@ def default(self, o):
             except Exception:
                 return "Value is not JSON serializable"
 
-    keys_to_scrub: Set[Any] = set()
+    keys_to_scrub: set[Any] = set()
     inductor_conf_str = None
     inductor_config_copy = (
         torch._inductor.config.get_config_copy() if torch._inductor.config else None

torch/_dynamo/variables/base.py

+2 -1

@@ -16,8 +16,9 @@
 """
 
 import collections
+from collections.abc import Sequence
 from enum import Enum
-from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING
 
 from .. import variables
 from ..current_scope_id import current_scope_id

torch/_dynamo/variables/builtin.py

+2 -2

@@ -11,8 +11,8 @@
 import types
 import typing
 from collections import defaultdict, OrderedDict
-from collections.abc import KeysView
-from typing import Callable, Sequence, TYPE_CHECKING, Union
+from collections.abc import KeysView, Sequence
+from typing import Callable, TYPE_CHECKING, Union
 
 import torch
 from torch import sym_float, sym_int

torch/_dynamo/variables/functions.py

+4 -4

@@ -30,7 +30,7 @@
 import sys
 import types
 from collections.abc import Sequence
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar
+from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
 from typing_extensions import Never
 from unittest.mock import patch
 
@@ -517,7 +517,7 @@ def has_unpack_var_sequence(self, tx):
     def has_force_unpack_var_sequence(self, tx) -> builtins.bool:
         return True
 
-    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
+    def force_unpack_var_sequence(self, tx) -> list[VariableTracker]:
         result = []
         while True:
             try:
@@ -547,8 +547,8 @@ def call_method(
         self,
         tx: "InstructionTranslator",
         name: str,
-        args: "List[VariableTracker]",
-        kwargs: "Dict[str, VariableTracker]",
+        args: "list[VariableTracker]",
+        kwargs: "dict[str, VariableTracker]",
     ) -> "VariableTracker":
         if name == "__next__":
             return self.next_variable(tx)

torch/_export/__init__.py

+1 -1

@@ -16,7 +16,7 @@
 from contextlib import contextmanager
 from functools import lru_cache
 
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 from unittest.mock import patch
 
 import torch

torch/_functorch/_activation_checkpointing/ac_logging_utils.py

+24 -24

@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Dict, List, Tuple
+from typing import Any
 
 from torch._logging import trace_structured
 from torch.fx import Graph, Node
@@ -11,9 +11,9 @@
 
 def create_joint_graph_node_information(
     joint_graph: Graph,
-    recomputable_node_info: Dict[str, int],
-) -> Dict[str, Any]:
-    joint_graph_node_information: Dict[str, Any] = {}
+    recomputable_node_info: dict[str, int],
+) -> dict[str, Any]:
+    joint_graph_node_information: dict[str, Any] = {}
 
     for i, joint_graph_node in enumerate(joint_graph.nodes):
         is_recomputable_candidate: bool = (
@@ -22,7 +22,7 @@ def create_joint_graph_node_information(
         tensor_meta = joint_graph_node.meta.get("tensor_meta")
         shape = getattr(tensor_meta, "shape", []) if tensor_meta else []
 
-        node_info: Dict[str, Any] = {
+        node_info: dict[str, Any] = {
             "index": i,
             "name": joint_graph_node.name,
             "is_recomputable_candidate": is_recomputable_candidate,
@@ -43,8 +43,8 @@ def create_joint_graph_node_information(
     return joint_graph_node_information
 
 
-def create_joint_graph_edges(joint_graph: Graph) -> List[Tuple[str, str]]:
-    joint_graph_edges: List[Tuple[str, str]] = [
+def create_joint_graph_edges(joint_graph: Graph) -> list[tuple[str, str]]:
+    joint_graph_edges: list[tuple[str, str]] = [
         (inp.name, node.name)
         for node in joint_graph.nodes
         for inp in node.all_input_nodes
@@ -54,17 +54,17 @@ def create_joint_graph_edges(joint_graph: Graph) -> List[Tuple[str, str]]:
 
 def create_activation_checkpointing_logging_structure_payload(
     joint_graph: Graph,
-    joint_graph_node_information: Dict[str, Any],
-    joint_graph_edges: List[Tuple[str, str]],
-    all_recomputable_banned_nodes: List[Node],
+    joint_graph_node_information: dict[str, Any],
+    joint_graph_edges: list[tuple[str, str]],
+    all_recomputable_banned_nodes: list[Node],
     expected_runtime: float,
-    saved_node_idxs: List[int],
-    recomputable_node_idxs: List[int],
-    memories_banned_nodes: List[float],
-    runtimes_banned_nodes: List[float],
-    min_cut_saved_values: List[Node],
-) -> Dict[str, Any]:
-    activation_checkpointing_logging_structure_payload: Dict[str, Any] = {
+    saved_node_idxs: list[int],
+    recomputable_node_idxs: list[int],
+    memories_banned_nodes: list[float],
+    runtimes_banned_nodes: list[float],
+    min_cut_saved_values: list[Node],
+) -> dict[str, Any]:
+    activation_checkpointing_logging_structure_payload: dict[str, Any] = {
         "Joint Graph Size": len(joint_graph.nodes),
         "Joint Graph Edges": {
             "Total": len(joint_graph_edges),
@@ -86,15 +86,15 @@ def create_activation_checkpointing_logging_structure_payload(
 
 def create_structured_trace_for_min_cut_info(
     joint_graph: Graph,
-    all_recomputable_banned_nodes: List[Node],
-    saved_node_idxs: List[int],
-    recomputable_node_idxs: List[int],
+    all_recomputable_banned_nodes: list[Node],
+    saved_node_idxs: list[int],
+    recomputable_node_idxs: list[int],
     expected_runtime: float,
-    memories_banned_nodes: List[float],
-    runtimes_banned_nodes: List[float],
-    min_cut_saved_values: List[Node],
+    memories_banned_nodes: list[float],
+    runtimes_banned_nodes: list[float],
+    min_cut_saved_values: list[Node],
 ) -> None:
-    recomputable_node_info: Dict[str, int] = {
+    recomputable_node_info: dict[str, int] = {
         node.name: idx for idx, node in enumerate(all_recomputable_banned_nodes)
     }
     joint_graph_node_information = create_joint_graph_node_information(

torch/_functorch/_aot_autograd/subclass_parametrization.py

+2 -1

@@ -1,6 +1,7 @@
 import dataclasses
 import itertools
-from typing import Any, Iterable, Union
+from collections.abc import Iterable
+from typing import Any, Union
 
 import torch
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass

0 commit comments
