Commit 567cb46
[float8nocompile] Add alternate Triton kernels for FP8 conversion which use atomic_max-based algo instead of reduction-based algo (#1455)
* refactor float8nocompile kernel so autotune is easily usable
* refactor to make kernel algo configurable; refactor unit tests to test both algos
* address comments
1 parent eab345c commit 567cb46
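
For reference, a minimal usage sketch of the new algo parameter introduced by this commit, following the call pattern in the updated unit test below. The tensor name and shape here are illustrative, and a CUDA device is assumed since the kernels run on GPU.

import torch

from torchao.float8.float8_tensor import LinearMMConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    KernelAlgorithm,
    triton_hp_tensor_to_float8_dynamic,
)

# illustrative high precision input; any contiguous CUDA tensor works
hp_tensor = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# new atomic_max-based kernel (the default value of algo)
fp8_atomic = triton_hp_tensor_to_float8_dynamic(
    hp_tensor,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    algo=KernelAlgorithm.ATOMIC_MAX,
)

# original reduction-based kernel, still selectable explicitly
fp8_reduction = triton_hp_tensor_to_float8_dynamic(
    hp_tensor,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    algo=KernelAlgorithm.REDUCTION,
)

Both calls return a Float8Tensor; the atomic_max path skips the per-block amax buffer and the separate reduction kernel launch that the reduction path still performs.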

File tree: 3 files changed (189 additions, 48 deletions)


torchao/prototype/float8nocompile/benchmark/benchmark.py (6 additions, 4 deletions)
@@ -59,7 +59,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 def get_configs() -> List[ExperimentConfig]:
     layer_sizes = [[4096, 4096]]
     input_shapes = [(2**4, 4096), (2**8, 4096), (2**12, 4096), (2**16, 4096)]
-    high_precision_dtypes = [torch.float32, torch.bfloat16]
+    high_precision_dtypes = [torch.bfloat16]
     configs = []
     for layer_size, input_shape, high_precision_dtype in itertools.product(
         layer_sizes, input_shapes, high_precision_dtypes
@@ -133,18 +133,20 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
 
 def print_results(experiments: List[Experiment]):
     headers = [
-        "input_size",
+        "input_shape",
         "high_precision_dtype",
         "eager_time",
         "compiled_time",
         "float8nocompile",
     ]
     rows = []
     for experiment in experiments:
-        input_size = experiment.config.input_shape[0] * experiment.config.input_shape[1]
+        input_shape = (
+            f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
+        )
         rows.append(
             [
-                f"{input_size:.2e}",
+                input_shape,
                 experiment.config.high_precision_dtype,
                 experiment.result.eager_time,
                 experiment.result.compiled_time,

torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py (170 additions, 42 deletions)
@@ -7,9 +7,9 @@
 """
 Triton kernels for scaling high precision tensors to float8.
 """
+from enum import Enum
 
 import torch
-
 import triton
 import triton.language as tl
 
@@ -31,8 +31,99 @@
 }
 
 
+class KernelAlgorithm(Enum):
+    """Enum for FP8 conversion strategy."""
+
+    # use atomic max to compute global amax between blocks
+    ATOMIC_MAX = "atomic_max"
+
+    # reduce shared buffer containing local block amaxes to find global amax
+    REDUCTION = "reduction"
+
+
+kernel_configs = [
+    triton.Config({"BLOCK_SIZE": 128}, num_warps=1),
+    triton.Config({"BLOCK_SIZE": 256}, num_warps=2),
+    triton.Config({"BLOCK_SIZE": 512}, num_warps=4),
+]
+
+
+# --- atomic max version of kernel ---
+@triton.autotune(configs=kernel_configs, key=["input_size"])
+@triton.jit
+def _block_amax_atomic(
+    input_ptr,
+    amax_ptr,
+    num_elements,
+    input_dtype: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    # compute local amax for each block
+    block_id = tl.program_id(axis=0)
+    block_start = block_id * BLOCK_SIZE
+    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+    block_mask = block_offs < num_elements
+    vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
+    block_amax = tl.max(tl.abs(vals))
+    tl.atomic_max(amax_ptr, block_amax)
+
+
+@triton.jit
+def _fp8_scale_atomic(
+    amax_ptr,
+    scale_out_ptr,
+    fp8_dtype_max,
+    EPS: tl.constexpr,
+):
+    # load previously computed global amax
+    global_amax = tl.load(amax_ptr)
+
+    # compute scale, must be fp32
+    scale = (fp8_dtype_max / tl.clamp(global_amax, min=EPS, max=float("inf"))).to(
+        tl.float32
+    )
+
+    # store scale for use in Float8Tensor constructor
+    scale_off = tl.arange(0, 1)
+    tl.store(scale_out_ptr + scale_off, scale)
+
+
+@triton.autotune(configs=kernel_configs, key=["input_size"])
 @triton.jit
-def _block_amax(
+def _to_fp8_atomic(
+    input_ptr,
+    scale_ptr,
+    amax_ptr,
+    out_ptr,
+    num_elements,
+    fp8_dtype_min,
+    fp8_dtype_max,
+    input_dtype: tl.constexpr,
+    output_dtype: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    EPS: tl.constexpr,
+):
+    block_id = tl.program_id(axis=0)
+
+    # load scale
+    scale = tl.load(scale_ptr)
+
+    # load block of input tensor
+    block_start = block_id * BLOCK_SIZE
+    block_offs = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = block_offs < num_elements
+    vals = tl.load(input_ptr + block_offs, mask=mask).to(input_dtype)
+
+    # perform conversion
+    vals = vals * scale
+    fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)
+    tl.store(out_ptr + block_offs, fp8_vals, mask=mask)
+
+
+# --- reduction version of kernel ---
+@triton.jit
+def _block_amax_reduction(
     input_ptr,
     block_amaxes_ptr,
     num_elements,
@@ -46,12 +137,12 @@ def _block_amax(
     block_offs = block_start + tl.arange(0, BLOCK_SIZE)
     block_mask = block_offs < num_elements
     vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
-    block_amax = tl.max(tl.abs(vals), axis=0)
+    block_amax = tl.max(tl.abs(vals))
     tl.store(block_amaxes_ptr + block_id, block_amax)
 
 
 @triton.jit
-def _fp8_scale(
+def _fp8_scale_reduction(
     block_amaxes_ptr,
     scale_out_ptr,
     num_elements,
@@ -75,7 +166,7 @@ def _fp8_scale(
 
 
 @triton.jit
-def _to_fp8(
+def _to_fp8_reduction(
     input_ptr,
     scale_ptr,
     out_ptr,
@@ -108,12 +199,10 @@ def triton_hp_tensor_to_float8_dynamic(
     fp8_dtype: torch.dtype,
     linear_mm_config: LinearMMConfig,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
 ) -> Float8Tensor:
-
     assert hp_tensor.is_contiguous(), "tensor must be contiguous"
 
-    BLOCK_SIZE = 8  # TODO(danielvegamyhre): tune this for perf
-
     num_elements = hp_tensor.numel()
     orig_shape = hp_tensor.shape
     flattened_input = hp_tensor.flatten()
@@ -126,47 +215,86 @@ def triton_hp_tensor_to_float8_dynamic(
 
     # allocate memory for computed scale, local block maxes, and output fp8 tensor
     scale_out = torch.empty((1,), dtype=torch.float32, device=hp_tensor.device)
-    block_amaxes = torch.zeros(
-        (num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
-    )
+
     fp8_output = torch.empty_like(
         flattened_input, dtype=fp8_dtype, device=hp_tensor.device
     )
 
-    # compute local amax for each block
     grid = lambda meta: (triton.cdiv(num_elements, meta["BLOCK_SIZE"]),)
-    _block_amax[grid](
-        flattened_input,
-        block_amaxes,
-        num_elements,
-        input_dtype=tl_input_dtype,
-        BLOCK_SIZE=BLOCK_SIZE,
-        EPS=EPS,
-    )
 
-    # calculate global amax across all blocks and use it to compute scale
-    _fp8_scale[(1, 1, 1)](
-        block_amaxes,
-        scale_out,
-        num_elements,
-        fp8_dtype_max,
-        BLOCK_SIZE=BLOCK_SIZE,
-        EPS=EPS,
-    )
+    if algo == KernelAlgorithm.ATOMIC_MAX:
+        global_amax = torch.zeros((1,), dtype=torch.float32, device=hp_tensor.device)
+        # compute global amax to be used for scaling
+        _block_amax_atomic[grid](
+            flattened_input,
+            global_amax,
+            num_elements,
+            input_dtype=tl_input_dtype,
+            EPS=EPS,
+        )
 
-    # perform conversion
-    _to_fp8[grid](
-        flattened_input,
-        scale_out,
-        fp8_output,
-        num_elements,
-        fp8_dtype_min,
-        fp8_dtype_max,
-        input_dtype=tl_input_dtype,
-        output_dtype=tl_output_dtype,
-        BLOCK_SIZE=BLOCK_SIZE,
-        EPS=EPS,
-    )
+        # compute scale for fp8 conversion
+        _fp8_scale_atomic[1, 1, 1](
+            global_amax,
+            scale_out,
+            fp8_dtype_max,
+            EPS=EPS,
+        )
+
+        # perform conversion and store scale for use in Float8Tensor
+        _to_fp8_atomic[grid](
+            flattened_input,
+            scale_out,
+            global_amax,
+            fp8_output,
+            num_elements,
+            fp8_dtype_min,
+            fp8_dtype_max,
+            input_dtype=tl_input_dtype,
+            output_dtype=tl_output_dtype,
+            EPS=EPS,
+        )
+    elif algo == KernelAlgorithm.REDUCTION:
+        max_block_size = 512
+        BLOCK_SIZE = min(max_block_size, num_elements)
+        block_amaxes = torch.zeros(
+            (num_elements // BLOCK_SIZE,), dtype=torch.float32, device=hp_tensor.device
+        )
+        # compute local amax for each block
+        _block_amax_reduction[grid](
+            flattened_input,
+            block_amaxes,
+            num_elements,
+            input_dtype=tl_input_dtype,
+            BLOCK_SIZE=BLOCK_SIZE,
+            EPS=EPS,
+        )
+
+        # calculate global amax across all blocks and use it to compute scale
+        _fp8_scale_reduction[(1, 1, 1)](
+            block_amaxes,
+            scale_out,
+            num_elements,
+            fp8_dtype_max,
+            BLOCK_SIZE=BLOCK_SIZE,
+            EPS=EPS,
+        )
+
+        # perform conversion
+        _to_fp8_reduction[grid](
+            flattened_input,
+            scale_out,
+            fp8_output,
+            num_elements,
+            fp8_dtype_min,
+            fp8_dtype_max,
+            input_dtype=tl_input_dtype,
+            output_dtype=tl_output_dtype,
+            BLOCK_SIZE=BLOCK_SIZE,
+            EPS=EPS,
+        )
+    else:
+        raise ValueError(f"Unsupported kernel algorithm: {algo}")
 
     return Float8Tensor(
         fp8_output.reshape(orig_shape),
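
To make the difference between the two strategies concrete, here is a small standalone sketch of the atomic_max idea (not code from this commit; the kernel and helper names are made up for illustration): each block reduces its own tile to a local amax and folds it into a single global value with tl.atomic_max, so no intermediate per-block amax buffer or second reduction kernel is needed.

import torch
import triton
import triton.language as tl


@triton.jit
def _amax_atomic_sketch(x_ptr, amax_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    vals = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    block_amax = tl.max(tl.abs(vals))  # local reduction within this block
    tl.atomic_max(amax_ptr, block_amax)  # fold into the single global amax


def global_amax_sketch(x: torch.Tensor, block_size: int = 512) -> torch.Tensor:
    # one fp32 slot shared by every block; atomics keep it up to date
    flat = x.contiguous().flatten()
    amax = torch.zeros(1, dtype=torch.float32, device=x.device)
    grid = (triton.cdiv(flat.numel(), block_size),)
    _amax_atomic_sketch[grid](flat, amax, flat.numel(), BLOCK_SIZE=block_size)
    return amax

The reduction-based path instead writes each block's local amax into a block_amaxes buffer and launches a second kernel to reduce that buffer to the global amax, which is the extra step the atomic variant avoids.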

torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py (13 additions, 2 deletions)
@@ -3,14 +3,24 @@
 from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
+    KernelAlgorithm,
     triton_hp_tensor_to_float8_dynamic,
 )
 
 
-def test_fp8_triton_hp_tensor_to_float8_dynamic():
+@pytest.mark.parametrize(
+    "algo", [KernelAlgorithm.ATOMIC_MAX, KernelAlgorithm.REDUCTION]
+)
+@pytest.mark.parametrize(
+    "input_shape",
+    [(32, 32), (512, 512), (4096, 4096)],
+)
+def test_fp8_triton_hp_tensor_to_float8_dynamic(
+    algo: KernelAlgorithm, input_shape: tuple[int, int]
+):
     assert torch.cuda.is_available()
     device = "cuda"
-    input_bf16 = torch.randn((4, 4), dtype=torch.bfloat16, device=device)
+    input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
     x_bf16 = input_bf16.clone().detach().to(device)
     y_bf16 = input_bf16.clone().detach().to(device)
 
@@ -26,6 +36,7 @@ def test_fp8_triton_hp_tensor_to_float8_dynamic():
         y_bf16,
         torch.float8_e4m3fn,
         LinearMMConfig(),
+        algo=algo,
     )
 
 def allclose_fp8(tensor1, tensor2, atol=1e-3, rtol=1e-3):
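
Assuming pytest is installed and a CUDA device is available (the test asserts torch.cuda.is_available()), the new parametrized cases can be run with something like:

pytest torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise_test.py -v

which exercises both KernelAlgorithm variants across the three input shapes.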
