@@ -8,6 +8,7 @@
 from packaging.version import Version
 from torch.distributed import ReduceOp
 
+from .deep_gemm_utils import deepgemm_fp8_gemm, per_block_cast_to_fp8, per_token_cast_to_fp8
 from .fp8_config import dynamic_kernel
 
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
@@ -699,17 +700,11 @@ def all_gather_fp8_lagacy(
     ret = cur_buffer[SCALE_BYTES:].view(fp8_type)
     ret, scale = cast_to_fp8(input_.view(-1), fp8_format=fp8_format, out=ret)
     cur_buffer[:SCALE_BYTES].view(torch.float)[0] = scale
-    # cur_buffer[:SCALE_BYTES] = scale.unsqueeze(0).view(torch.uint8)
     dist.all_gather(combined_buffers, cur_buffer, group=group, async_op=async_op)
     for out, buf in zip(output_list, combined_buffers):
         scale = buf[:SCALE_BYTES].clone().view(scale.dtype)
         output = buf[SCALE_BYTES:].view(fp8_type)
         cast_from_fp8(output.view(shape), scale, input_type, out=out)
-    # output = combined_buffer.view(world_size, -1)[:, SCALE_BYTES:].view(fp8_type)
-    # scales = combined_buffer.view(world_size, -1)[:, :SCALE_BYTES].view(torch.float)
-    # output = output.float() * scales
-    # for i, out in enumerate(output_list):
-    #     out.copy_(output[i].view(shape))
 
 
 @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False, disable=cuda_arch < 89)
@@ -834,6 +829,50 @@ def backward(ctx: Any, out_grad) -> Any:
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
+class _LinearFp8DeepGemm(torch.autograd.Function):
+    """
+    Behave similar to torch.nn.functional.linear
+    """
+
+    def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+        if not (x.dim() == 2 and w.dim() == 2):
+            raise ValueError("Batched fp8 deep_gemm is not supported")
+        # x: (m, k), w: (n, k)
+        # x @ w_t -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
+        (m, k), (n, _) = x.shape, w.shape
+        x_per_tok, w_per_blk = per_token_cast_to_fp8(x), per_block_cast_to_fp8(w)
+
+        out = torch.empty((m, n), dtype=torch.bfloat16, device=x.device)  # NOTE: DeepGemm only supports bf16 output
+        deepgemm_fp8_gemm(x_per_tok, w_per_blk, out)
+
+        ctx.w_t_per_plk = per_block_cast_to_fp8(w.t())
+        ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
+        ctx.mnk = (m, n, k)
+        return out
+
+    def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        # o_grad: (m, n)
+        # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
+        # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
+        m, n, k = ctx.mnk
+        o_per_tok = per_token_cast_to_fp8(o_grad)
+
+        x_grad = torch.empty((m, k), dtype=torch.bfloat16, device=o_grad.device)
+        deepgemm_fp8_gemm(o_per_tok, ctx.w_t_per_plk, x_grad)
+
+        o_grad_t_per_tok = per_token_cast_to_fp8(o_grad.t())
+        w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
+        deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)
+
+        return x_grad, w_grad
+
+
+def linear_fp8_deep_gemm(input: torch.Tensor, weight: torch.Tensor, bias: None = None) -> torch.Tensor:
+    if bias is not None:
+        raise ValueError("bias is not supported in deep_gemm")
+    return _LinearFp8DeepGemm.apply(input, weight)
+
+
 @torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)
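The added _LinearFp8DeepGemm caches the per-block fp8 casts of w.t() and x.t() in ctx during forward, so backward only has to cast the incoming gradient per-token before running its two fp8 GEMMs. Below is a minimal usage sketch of the new linear_fp8_deep_gemm entry point; it is not part of this commit, the import path is an assumption, and it presumes a GPU and driver supported by the deep_gemm kernels wrapped by .deep_gemm_utils.

# Usage sketch (not part of the diff). Assumptions: the import path below and
# hardware supported by the deep_gemm kernels; dimensions are kept at multiples
# of 128 to stay friendly to the per-block casting helper.
import torch

from colossalai.quantization.fp8 import linear_fp8_deep_gemm  # assumed import path

m, k, n = 128, 512, 256
x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(n, k, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = linear_fp8_deep_gemm(x, w)  # (m, n) in bf16; x is cast per-token, w per-block to fp8
out.sum().backward()              # backward runs two more fp8 GEMMs for x.grad and w.grad

assert out.shape == (m, n)
assert x.grad.shape == (m, k) and w.grad.shape == (n, k)

Note that bias is rejected explicitly, so callers that need a bias term must add it outside the fp8 GEMM path.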