1
- from collections import defaultdict
2
1
import functools
3
2
import lzma
4
3
import pathlib
5
4
import typing
6
5
7
6
7
+ def _host_memory_space (inst ):
8
+ return inst .shape .layout .memory_space == 5
9
+
10
+
8
11
class StackFrame (typing .NamedTuple ):
9
12
column : int
10
13
file : str
@@ -25,6 +28,35 @@ def __init__(self, wrapped_hlo_proto, proto):
25
28
# proto representing the actual collective, which will be different if the
26
29
# async launch is handled by an async-start op
27
30
# TODO: can any of copy-start, custom-call, recv, send represent communication?
31
+ # This also aims to identify, and (for now) flag as communication, kernels that
32
+ # implement device-to-host and host-to-device copies for memory offloading.
33
+ # For example, a device-to-host offload might look like
34
+ # computation {
35
+ # ...
36
+ # ROOT r1 = bf16[2,8,128,2048]{3,2,1,0:S(5)} dynamic-update-slice(...)
37
+ # }
38
+ # async_computation {
39
+ # ...
40
+ # ROOT r2 = bf16[2,8,128,2048]{3,2,1,0:S(5)} fusion(...), calls=computation
41
+ # }
42
+ # start = (...) async-start(...), calls=async_computation
43
+ # where the :S(5) annotation shows that a buffer is in host memory.
44
+ # A host-to-device load might look like
45
+ # computation {
46
+ # param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
47
+ # ...
48
+ # ROOT r1 = bf16[2,8,128,2048]{3,2,1,0} dynamic-slice(param_0, ...)
49
+ # }
50
+ # async_computation {
51
+ # param_0 = bf16[2,8,128,2048]{3,2,1,0:S(5)} parameter(0)
52
+ # ...
53
+ # ROOT r2 = bf16[2,8,128,2048]{3,2,1,0} fusion(param_0, ...), calls=computation
54
+ # }
55
+ # start = (...) async-start(...), calls=async_computation
56
+ # where the :S(5) memory space annotation is in a parameter instead of in the
57
+ # return value.
58
+ # For now, handling host-device kernels as single-device "collective"
59
+ # communication should be sufficient.
28
60
self ._comm_proto = None
29
61
comm_opcodes = {
30
62
"all-gather" ,
@@ -39,25 +71,50 @@ def __init__(self, wrapped_hlo_proto, proto):
39
71
"all-reduce-start" ,
40
72
"collective-permute-start" ,
41
73
}
74
+
75
+ def _is_offloading_instruction (inst ):
76
+ host_dest = _host_memory_space (inst )
77
+
78
+ def _host_operand (i ):
79
+ _ , op = wrapped_hlo_proto .find_instruction_by_id (inst .operand_ids [i ])
80
+ return _host_memory_space (op .proto ())
81
+
82
+ if inst .opcode == "dynamic-slice" and host_dest != _host_operand (0 ):
83
+ return True
84
+ elif (
85
+ inst .opcode == "dynamic-update-slice"
86
+ and host_dest == _host_operand (0 )
87
+ and host_dest != _host_operand (1 )
88
+ ):
89
+ return True
90
+ return False
91
+
42
92
if self ._proto .opcode in comm_opcodes | comm_start_opcodes :
43
93
self ._comm_proto = self ._proto
44
- elif self ._proto .opcode == "async-start" :
94
+ elif self ._proto .opcode in {"async-start" , "fusion" }:
95
+ # fusion example:
96
+ # computation {
97
+ # param_0 = f32[...]{...:S(5)} parameter(0)
98
+ # ...
99
+ # ROOT dus = f32[...]{...:S(5)} dynamic-update-slice(param_0, ...)
100
+ # }
101
+ # inst = f32[256,128,128]{2,1,0:S(5)} fusion(...), calls=computation
45
102
# This might be thinly wrapping an opcode in `comm_opcodes`
46
- other_opcodes = defaultdict (int )
47
- for called_id in self ._proto .called_computation_ids :
48
- for called_inst in wrapped_hlo_proto .find_computation (
49
- called_id
50
- ).instructions :
51
- if called_inst .opcode in comm_opcodes :
103
+ def _visit_computation (computation_id ):
104
+ computation = wrapped_hlo_proto .find_computation (computation_id )
105
+ for called_inst in computation .instructions :
106
+ for called_id in called_inst .called_computation_ids :
107
+ _visit_computation (called_id )
108
+ if called_inst .opcode in comm_opcodes or _is_offloading_instruction (
109
+ called_inst
110
+ ):
52
111
assert (
53
112
self ._comm_proto is None
54
113
), f"Found { called_inst .opcode } child having already found { self ._comm_proto .opcode } "
55
114
self ._comm_proto = called_inst
56
- else :
57
- other_opcodes [called_inst .opcode ] += 1
58
- assert (
59
- other_opcodes .keys () == {"parameter" }
60
- ), f"async-start op { self ._proto .name } wrapped too many opcode types ({ dict (other_opcodes )} ) in addition to { self ._comm_proto } "
115
+
116
+ for called_id in self ._proto .called_computation_ids :
117
+ _visit_computation (called_id )
61
118
62
119
def communication_proto (self ):
63
120
return self ._comm_proto
@@ -68,12 +125,7 @@ def is_communication(self) -> bool:
68
125
a little more complicated than you might hope, because async communications are
69
126
not handled uniformly.
70
127
"""
71
- if self ._comm_proto is None :
72
- return False
73
- assert (
74
- self ._comm_proto .channel_id != 0
75
- ), f"Got channel_id={ self ._comm_proto .channel_id } for { self ._comm_proto .name } "
76
- return True
128
+ return self ._comm_proto is not None
77
129
78
130
def proto (self ):
79
131
"""
0 commit comments