Commit c832927

[deepgemm] adapt deepgemm
1 parent 44d4053 commit c832927

File tree

colossalai/quantization/deep_gemm_utils.py
colossalai/quantization/fp8.py
tests/test_fp8/test_fp8_deepgemm.py

3 files changed: +169 −6 lines
colossalai/quantization/deep_gemm_utils.py

+94

@@ -0,0 +1,94 @@
# This file was modified from https://github.com/deepseek-ai/DeepGEMM
# as the utils are not included in the library.
# Thanks for developing and open-sourcing the performant kernel.

# Original LICENSE:

# MIT License

# Copyright (c) 2025 DeepSeek

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import warnings
from typing import Tuple

import torch

__WARNING_MSG = "Couldn't find deep_gemm library, please install from https://github.com/deepseek-ai/DeepGEMM and run corresponding tests"
try:
    from deep_gemm import ceil_div, gemm_fp8_fp8_bf16_nt

    IS_DEEP_GEMM_AVAIL = True
except ImportError:
    IS_DEEP_GEMM_AVAIL = False
    warnings.warn(__WARNING_MSG)

    def ceil_div(*args, **kwargs):  # fallback stub so the module still imports (and lint stays quiet) without deep_gemm
        raise NotImplementedError(__WARNING_MSG)

    def gemm_fp8_fp8_bf16_nt(*args, **kwargs):
        raise NotImplementedError(__WARNING_MSG)


def deepgemm_fp8_gemm(
    lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor
) -> None:
    gemm_fp8_fp8_bf16_nt(lhs, rhs, out)


# TODO: There seems to be a better kernel implemented in Triton
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Cast the input tensor to float8_e4m3fn precision and compute the corresponding scale in a token-wise (1x128) manner.
    Args:
        x (`torch.Tensor`):
            Matmul x in x @ y.t(), where x.shape is (m, k)

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: x in float8_e4m3fn and its scale of shape (m, k // 128)
    """
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


# TODO: There seems to be a better kernel implemented in Triton
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
def per_block_cast_to_fp8(y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Cast the input tensor to float8_e4m3fn precision and compute the corresponding scale in a block-wise (128x128) manner.
    Args:
        y (`torch.Tensor`):
            Matmul y in x @ y.t(), where y.shape is (n, k)

    Returns:
        `Tuple[torch.Tensor, torch.Tensor]`: y in float8_e4m3fn and its scale of shape (ceil(n / 128), ceil(k / 128))
    """
    assert y.dim() == 2
    m, n = y.shape
    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=y.dtype, device=y.device)
    x_padded[:m, :n] = y
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
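Not part of the commit: a minimal sanity check of the token-wise scaling above, assuming a PyTorch build with float8_e4m3fn support. Each 1x128 slice is rescaled so that its absolute maximum maps to 448 (the largest finite float8_e4m3fn value), and multiplying the fp8 values back by the returned scale recovers the input up to fp8 rounding error.

import torch

x = torch.randn(4, 256)                                     # (m, k) with k % 128 == 0, as the assert above requires
x_view = x.view(4, -1, 128)
amax = x_view.abs().float().amax(dim=2, keepdim=True).clamp(1e-4)
x_fp8 = (x_view * (448.0 / amax)).to(torch.float8_e4m3fn)   # quantize each 1x128 slice
x_deq = x_fp8.float() * (amax / 448.0)                      # dequantize with the stored scale
print((x_deq - x_view).abs().max())                         # small, bounded by fp8 rounding error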

colossalai/quantization/fp8.py

+45 −6

@@ -8,6 +8,7 @@
 from packaging.version import Version
 from torch.distributed import ReduceOp

+from .deep_gemm_utils import deepgemm_fp8_gemm, per_block_cast_to_fp8, per_token_cast_to_fp8
 from .fp8_config import dynamic_kernel

 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
@@ -699,17 +700,11 @@ def all_gather_fp8_lagacy(
     ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
     ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
     cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
-    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
     dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
     for out, buf in zip(output_list, combined_buffers):
         scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
         output = buf[SCALE_BYTES:].view(fp8_type)
         cast_from_fp8(output.view(shape), scale, input_type, out=out)
-    # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
-    # scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
-    # output = output.float() * scales
-    # for i, out in enumerate(output_list):
-    #     out.copy_(output[i].view(shape))


 @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
@@ -834,6 +829,50 @@ def backward(ctx: Any, out_grad) -> Any:
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


+class _LinearFp8DeepGemm(torch.autograd.Function):
+    """
+    Behaves similarly to torch.nn.functional.linear (bias is not supported).
+    """
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        if not (x.dim() == 2 and w.dim() == 2):
+            raise ValueError("Batched fp8 deep_gemm is not supported")
+        # x: (m, k), w: (n, k)
+        # x @ w.t() -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
+        (m, k), (n, _) = x.shape, w.shape
+        x_per_tok, w_per_blk = per_token_cast_to_fp8(x), per_block_cast_to_fp8(w)
+
+        out = torch.empty((m, n), dtype=torch.bfloat16, device=x.device)  # NOTE: DeepGEMM only supports bf16 output
+        deepgemm_fp8_gemm(x_per_tok, w_per_blk, out)
+
+        ctx.w_t_per_blk = per_block_cast_to_fp8(w.t())
+        ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
+        ctx.mnk = (m, n, k)
+        return out
+
+    @staticmethod
+    def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        # o_grad: (m, n)
+        # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
+        # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
+        m, n, k = ctx.mnk
+        o_per_tok = per_token_cast_to_fp8(o_grad)
+
+        x_grad = torch.empty((m, k), dtype=torch.bfloat16, device=o_grad.device)
+        deepgemm_fp8_gemm(o_per_tok, ctx.w_t_per_blk, x_grad)
+
+        o_grad_t_per_tok = per_token_cast_to_fp8(o_grad.t())
+        w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
+        deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)
+
+        return x_grad, w_grad
+
+
+def linear_fp8_deep_gemm(input: torch.Tensor, weight: torch.Tensor, bias: None = None) -> torch.Tensor:
+    if bias is not None:
+        raise ValueError("bias is not supported in deep_gemm")
+    return _LinearFp8DeepGemm.apply(input, weight)
+
+
 @torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)
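For readers without DeepGEMM installed, the following dequantize-then-matmul reference (a sketch added for this page, not part of the commit, and only an assumption about the kernel's semantics) shows what deepgemm_fp8_gemm is expected to compute from the tuples returned by per_token_cast_to_fp8 and per_block_cast_to_fp8. It assumes m, n and k are all multiples of 128.

import torch

def reference_fp8_gemm(lhs, rhs):
    # lhs: (x_fp8 of shape (m, k), x_scale of shape (m, k // 128)) from per_token_cast_to_fp8
    # rhs: (w_fp8 of shape (n, k), w_scale of shape (n // 128, k // 128)) from per_block_cast_to_fp8
    x_fp8, x_scale = lhs
    w_fp8, w_scale = rhs
    (m, k), (n, _) = x_fp8.shape, w_fp8.shape
    # undo the 1x128 per-token scaling and the 128x128 per-block scaling
    x = x_fp8.float().view(m, k // 128, 128) * x_scale.unsqueeze(-1)
    w = w_fp8.float().view(n // 128, 128, k // 128, 128) * w_scale[:, None, :, None]
    # x @ w.t(), cast to the bf16 output dtype that DeepGEMM produces
    return (x.view(m, k) @ w.view(n, k).t()).to(torch.bfloat16)

In _LinearFp8DeepGemm.forward this corresponds to out ≈ reference_fp8_gemm(per_token_cast_to_fp8(x), per_block_cast_to_fp8(w)), i.e. a bf16 approximation of x @ w.t().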

tests/test_fp8/test_fp8_deepgemm.py

+30

@@ -0,0 +1,30 @@
import pytest
import torch
import torch.nn.functional as F
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8_deep_gemm
from colossalai.utils import get_current_device

m, k, n = 128, 384, 256
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_linear():
    # create tensors
    x = torch.rand((m, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
    w = torch.rand((n, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_w = w.clone().detach().requires_grad_()
    ref_x = x.clone().detach().requires_grad_()

    out = linear_fp8_deep_gemm(x, w)
    assert out.shape == x.shape[:-1] + (n,)
    out.sum().backward()
    ref_out = F.linear(ref_x, ref_w)
    ref_out.sum().backward()

    assert_close(out, ref_out, rtol=0.2, atol=0.1)
    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
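A usage note (not part of the commit): the test is skipped unless the device reports compute capability >= 9.0 (e.g. Hopper-class GPUs), and it needs deep_gemm to be importable; on a suitable machine it can be run with pytest tests/test_fp8/test_fp8_deepgemm.py.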
