1
1
import json
2
2
import logging
3
- from typing import Any , Dict , List , Tuple
3
+ from typing import Any
4
4
5
5
from torch ._logging import trace_structured
6
6
from torch .fx import Graph , Node
11
11
12
12
def create_joint_graph_node_information (
13
13
joint_graph : Graph ,
14
- recomputable_node_info : Dict [str , int ],
15
- ) -> Dict [str , Any ]:
16
- joint_graph_node_information : Dict [str , Any ] = {}
14
+ recomputable_node_info : dict [str , int ],
15
+ ) -> dict [str , Any ]:
16
+ joint_graph_node_information : dict [str , Any ] = {}
17
17
18
18
for i , joint_graph_node in enumerate (joint_graph .nodes ):
19
19
is_recomputable_candidate : bool = (
@@ -22,7 +22,7 @@ def create_joint_graph_node_information(
22
22
tensor_meta = joint_graph_node .meta .get ("tensor_meta" )
23
23
shape = getattr (tensor_meta , "shape" , []) if tensor_meta else []
24
24
25
- node_info : Dict [str , Any ] = {
25
+ node_info : dict [str , Any ] = {
26
26
"index" : i ,
27
27
"name" : joint_graph_node .name ,
28
28
"is_recomputable_candidate" : is_recomputable_candidate ,
@@ -43,8 +43,8 @@ def create_joint_graph_node_information(
43
43
return joint_graph_node_information
44
44
45
45
46
- def create_joint_graph_edges (joint_graph : Graph ) -> List [ Tuple [str , str ]]:
47
- joint_graph_edges : List [ Tuple [str , str ]] = [
46
+ def create_joint_graph_edges (joint_graph : Graph ) -> list [ tuple [str , str ]]:
47
+ joint_graph_edges : list [ tuple [str , str ]] = [
48
48
(inp .name , node .name )
49
49
for node in joint_graph .nodes
50
50
for inp in node .all_input_nodes
@@ -54,17 +54,17 @@ def create_joint_graph_edges(joint_graph: Graph) -> List[Tuple[str, str]]:
54
54
55
55
def create_activation_checkpointing_logging_structure_payload (
56
56
joint_graph : Graph ,
57
- joint_graph_node_information : Dict [str , Any ],
58
- joint_graph_edges : List [ Tuple [str , str ]],
59
- all_recomputable_banned_nodes : List [Node ],
57
+ joint_graph_node_information : dict [str , Any ],
58
+ joint_graph_edges : list [ tuple [str , str ]],
59
+ all_recomputable_banned_nodes : list [Node ],
60
60
expected_runtime : float ,
61
- saved_node_idxs : List [int ],
62
- recomputable_node_idxs : List [int ],
63
- memories_banned_nodes : List [float ],
64
- runtimes_banned_nodes : List [float ],
65
- min_cut_saved_values : List [Node ],
66
- ) -> Dict [str , Any ]:
67
- activation_checkpointing_logging_structure_payload : Dict [str , Any ] = {
61
+ saved_node_idxs : list [int ],
62
+ recomputable_node_idxs : list [int ],
63
+ memories_banned_nodes : list [float ],
64
+ runtimes_banned_nodes : list [float ],
65
+ min_cut_saved_values : list [Node ],
66
+ ) -> dict [str , Any ]:
67
+ activation_checkpointing_logging_structure_payload : dict [str , Any ] = {
68
68
"Joint Graph Size" : len (joint_graph .nodes ),
69
69
"Joint Graph Edges" : {
70
70
"Total" : len (joint_graph_edges ),
@@ -86,15 +86,15 @@ def create_activation_checkpointing_logging_structure_payload(
86
86
87
87
def create_structured_trace_for_min_cut_info (
88
88
joint_graph : Graph ,
89
- all_recomputable_banned_nodes : List [Node ],
90
- saved_node_idxs : List [int ],
91
- recomputable_node_idxs : List [int ],
89
+ all_recomputable_banned_nodes : list [Node ],
90
+ saved_node_idxs : list [int ],
91
+ recomputable_node_idxs : list [int ],
92
92
expected_runtime : float ,
93
- memories_banned_nodes : List [float ],
94
- runtimes_banned_nodes : List [float ],
95
- min_cut_saved_values : List [Node ],
93
+ memories_banned_nodes : list [float ],
94
+ runtimes_banned_nodes : list [float ],
95
+ min_cut_saved_values : list [Node ],
96
96
) -> None :
97
- recomputable_node_info : Dict [str , int ] = {
97
+ recomputable_node_info : dict [str , int ] = {
98
98
node .name : idx for idx , node in enumerate (all_recomputable_banned_nodes )
99
99
}
100
100
joint_graph_node_information = create_joint_graph_node_information (
0 commit comments