"""
Triton kernels for scaling high precision tensors to float8.
"""
+ from enum import Enum

import torch
-
import triton
import triton.language as tl
}


+ class KernelAlgorithm(Enum):
+     """Enum for FP8 conversion strategy."""
+
+     # use atomic max to compute global amax between blocks
+     ATOMIC_MAX = "atomic_max"
+
+     # reduce shared buffer containing local block amaxes to find global amax
+     REDUCTION = "reduction"
+
+
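+ # candidate launch configs swept by triton.autotune (keyed on num_elements)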
+ kernel_configs = [
+     triton.Config({"BLOCK_SIZE": 128}, num_warps=1),
+     triton.Config({"BLOCK_SIZE": 256}, num_warps=2),
+     triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
+ ]
+
+
+ # --- atomic max version of kernel ---
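+ # each block computes its local amax and atomically folds it into a single global amax scalar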
+ @triton.autotune(configs=kernel_configs, key=["num_elements"])
+ @triton.jit
+ def _block_amax_atomic(
+     input_ptr,
+     amax_ptr,
+     num_elements,
+     input_dtype: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     EPS: tl.constexpr,
+ ):
+     # compute local amax for each block
+     block_id = tl.program_id(axis=0)
+     block_start = block_id * BLOCK_SIZE
+     block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+     block_mask = block_offs < num_elements
+     vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
+     block_amax = tl.max(tl.abs(vals))
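+     # atomically update the global amax with this block's local amax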
+     tl.atomic_max(amax_ptr, block_amax)
+
+
+ @triton.jit
+ def _fp8_scale_atomic(
+     amax_ptr,
+     scale_out_ptr,
+     fp8_dtype_max,
+     EPS: tl.constexpr,
+ ):
+     # load previously computed global amax
+     global_amax = tl.load(amax_ptr)
+
+     # compute scale, must be fp32
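+     # (amax is clamped to at least EPS to avoid division by zero)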
+     scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
+         tl.float32
+     )
+
+     # store scale for use in Float8Tensor constructor
+     scale_off = tl.arange(0, 1)
+     tl.store(scale_out_ptr + scale_off, scale)
+
+
+ @triton.autotune(configs=kernel_configs, key=["num_elements"])
@triton.jit
- def _block_amax(
+ def _to_fp8_atomic(
+     input_ptr,
+     scale_ptr,
+     amax_ptr,
+     out_ptr,
+     num_elements,
+     fp8_dtype_min,
+     fp8_dtype_max,
+     input_dtype: tl.constexpr,
+     output_dtype: tl.constexpr,
+     BLOCK_SIZE: tl.constexpr,
+     EPS: tl.constexpr,
+ ):
+     block_id = tl.program_id(axis=0)
+
+     # load scale
+     scale = tl.load(scale_ptr)
+
+     # load block of input tensor
+     block_start = block_id * BLOCK_SIZE
+     block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+     mask = block_offs < num_elements
+     vals = tl.load(input_ptr + block_offs, mask=mask).to(input_dtype)
+
+     # perform conversion
+     vals = vals * scale
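+     # clamp to the representable fp8 range so the cast below cannot overflow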
+     fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
+     tl.store(out_ptr + block_offs, fp8_vals, mask=mask)
+
+
+ # --- reduction version of kernel ---
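+ # each block writes its local amax into a shared buffer, which a second kernel reduces to the global amax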
+ @triton.jit
+ def _block_amax_reduction(
    input_ptr,
    block_amaxes_ptr,
    num_elements,
@@ -46,12 +137,12 @@ def _block_amax(
    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
    block_mask = block_offs < num_elements
    vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
-     block_amax = tl.max(tl.abs(vals), axis=0)
+     block_amax = tl.max(tl.abs(vals))
    tl.store(block_amaxes_ptr + block_id, block_amax)


@triton.jit
- def _fp8_scale(
+ def _fp8_scale_reduction(
    block_amaxes_ptr,
    scale_out_ptr,
    num_elements,
@@ -75,7 +166,7 @@ def _fp8_scale(


@triton.jit
- def _to_fp8(
+ def _to_fp8_reduction(
    input_ptr,
    scale_ptr,
    out_ptr,
@@ -108,12 +199,10 @@ def triton_hp_tensor_to_float8_dynamic(
    fp8_dtype: torch.dtype,
    linear_mm_config: LinearMMConfig,
    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+     algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
) -> Float8Tensor:
-
    assert hp_tensor.is_contiguous(), "tensor must be contiguous"

-     BLOCK_SIZE = 8  # TODO(danielvegamyhre): tune this for perf
-
    num_elements = hp_tensor.numel()
    orig_shape = hp_tensor.shape
    flattened_input = hp_tensor.flatten()
@@ -126,47 +215,86 @@ def triton_hp_tensor_to_float8_dynamic(

    # allocate memory for computed scale, local block maxes, and output fp8 tensor
    scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
-     block_amaxes = torch.zeros(
-         (num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
-     )
+
    fp8_output = torch.empty_like(
        flattened_input, dtype=fp8_dtype, device=hp_tensor.device
    )

-     # compute local amax for each block
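+     # launch one program per BLOCK_SIZE-element chunk of the flattened input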
    grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
-     _block_amax[grid](
-         flattened_input,
-         block_amaxes,
-         num_elements,
-         input_dtype=tl_input_dtype,
-         BLOCK_SIZE=BLOCK_SIZE,
-         EPS=EPS,
-     )

-     # calculate global amax across all blocks and use it to compute scale
-     _fp8_scale[(1, 1, 1)](
-         block_amaxes,
-         scale_out,
-         num_elements,
-         fp8_dtype_max,
-         BLOCK_SIZE=BLOCK_SIZE,
-         EPS=EPS,
-     )
+     if algo == KernelAlgorithm.ATOMIC_MAX:
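+         # a single fp32 scalar in device memory accumulates the global amax via atomic max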
+         global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device)
+         # compute global amax to be used for scaling
+         _block_amax_atomic[grid](
+             flattened_input,
+             global_amax,
+             num_elements,
+             input_dtype=tl_input_dtype,
+             EPS=EPS,
+         )

-     # perform conversion
-     _to_fp8[grid](
-         flattened_input,
-         scale_out,
-         fp8_output,
-         num_elements,
-         fp8_dtype_min,
-         fp8_dtype_max,
-         input_dtype=tl_input_dtype,
-         output_dtype=tl_output_dtype,
-         BLOCK_SIZE=BLOCK_SIZE,
-         EPS=EPS,
-     )
+         # compute scale for fp8 conversion
+         _fp8_scale_atomic[(1, 1, 1)](
+             global_amax,
+             scale_out,
+             fp8_dtype_max,
+             EPS=EPS,
+         )
+
+         # perform conversion and store scale for use in Float8Tensor
+         _to_fp8_atomic[grid](
+             flattened_input,
+             scale_out,
+             global_amax,
+             fp8_output,
+             num_elements,
+             fp8_dtype_min,
+             fp8_dtype_max,
+             input_dtype=tl_input_dtype,
+             output_dtype=tl_output_dtype,
+             EPS=EPS,
+         )
+     elif algo == KernelAlgorithm.REDUCTION:
+         max_block_size = 512
+         BLOCK_SIZE = min(max_block_size, num_elements)
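+         # one fp32 slot per launched block to hold that block's local amax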
+         block_amaxes = torch.zeros(
+             (triton.cdiv(num_elements, BLOCK_SIZE),), dtype=torch.float32, device=hp_tensor.device
+         )
+         # compute local amax for each block
+         _block_amax_reduction[grid](
+             flattened_input,
+             block_amaxes,
+             num_elements,
+             input_dtype=tl_input_dtype,
+             BLOCK_SIZE=BLOCK_SIZE,
+             EPS=EPS,
+         )
+
+         # calculate global amax across all blocks and use it to compute scale
+         _fp8_scale_reduction[(1, 1, 1)](
+             block_amaxes,
+             scale_out,
+             num_elements,
+             fp8_dtype_max,
+             BLOCK_SIZE=BLOCK_SIZE,
+             EPS=EPS,
+         )
+
+         # perform conversion
+         _to_fp8_reduction[grid](
+             flattened_input,
+             scale_out,
+             fp8_output,
+             num_elements,
+             fp8_dtype_min,
+             fp8_dtype_max,
+             input_dtype=tl_input_dtype,
+             output_dtype=tl_output_dtype,
+             BLOCK_SIZE=BLOCK_SIZE,
+             EPS=EPS,
+         )
+     else:
+         raise ValueError(f"Unsupported kernel algorithm: {algo}")

    return Float8Tensor(
        fp8_output.reshape(orig_shape),
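A minimal usage sketch of the new entry point, calling the converter with either algorithm. The module import path and the default-constructed LinearMMConfig are assumptions, since neither is shown in this diff:

import torch
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig

# assumed location of the kernels added in this diff
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    KernelAlgorithm,
    triton_hp_tensor_to_float8_dynamic,
)

hp_tensor = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
fp8_tensor = triton_hp_tensor_to_float8_dynamic(
    hp_tensor,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    gemm_input_role=GemmInputRole.INPUT,
    algo=KernelAlgorithm.ATOMIC_MAX,  # or KernelAlgorithm.REDUCTION
)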