Commit b424728

Allow to compute when bsz == 1

1 parent 16e46ef

1 file changed: colossalai/quantization/fp8.py (+31 -5)

@@ -800,7 +800,11 @@ def forward(
             scale_a=inv_scale_x,
             scale_b=inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(out, tuple):
+            out = out[0]
+
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])

     @staticmethod
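
The `)[0]` → `)` change plus the `isinstance` check lets the call work with either return convention of `torch._scaled_mm`: builds that return an `(output, amax)` tuple and builds that return the output tensor alone. A minimal sketch of the same normalization, factored into a helper (the helper and operand names below are ours, not from the file):

    # Normalize torch._scaled_mm's return value, which is a tuple on some
    # PyTorch builds and a bare tensor on others.
    import torch

    def unwrap_scaled_mm(result):
        return result[0] if isinstance(result, tuple) else result

    # Usage mirroring the call above (fp8 operands and FP8-capable hardware assumed):
    # out = unwrap_scaled_mm(
    #     torch._scaled_mm(x_fp8, w_fp8.t(),  # hypothetical operand names
    #                      scale_a=inv_scale_x, scale_b=inv_scale_w,
    #                      use_fast_accum=True)
    # )
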
@@ -814,15 +818,23 @@ def backward(ctx: Any, out_grad) -> Any:
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(x_grad, tuple):
+            x_grad = x_grad[0]
+
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
             ctx.x_fp8.t().contiguous().t(),
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(w_grad, tuple):
+            w_grad = w_grad[0]
+
         bias_grad = None
         if ctx.has_bias:
             bias_grad = out_grad.sum(0)
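
The backward applies the same unwrapping to both gradient GEMMs. Assuming the usual linear convention y = x @ w.t() that the surrounding code suggests, the two products are dL/dx = dL/dy @ w and dL/dw = dL/dy.t() @ x; a small full-precision reference (not from the commit) confirming the shapes:

    # Full-precision reference (illustrative only) for the gradient shapes the
    # two _scaled_mm calls above must produce, given y = x @ w.t().
    import torch

    m, n, k = 4, 8, 16
    x = torch.randn(m, k)         # input  (m, k)
    w = torch.randn(n, k)         # weight (n, k)
    out_grad = torch.randn(m, n)  # dL/dy  (m, n)

    x_grad = out_grad @ w         # (m, n) @ (n, k) -> (m, k)
    w_grad = out_grad.t() @ x     # (n, m) @ (m, k) -> (n, k)
    assert x_grad.shape == (m, k) and w_grad.shape == (n, k)
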
@@ -835,8 +847,14 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
     """

     def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
-        if not (x.dim() == 2 and w.dim() == 2):
-            raise ValueError("Batched fp8 deep_gemm is not supported")
+        has_batch_dim = False
+        if x.dim() == 3:
+            has_batch_dim = True
+            if x.size(1) != 1:
+                raise ValueError(f"Batched fp8 deep_gemm is not supported, found x shape: {x.shape}")
+            x = x.squeeze(1)
+        ctx.has_batch_dim = has_batch_dim
+
         # x: (m, k), w: (n, k)
         # x @ w_t -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
         (m, k), (n, _) = x.shape, w.shape
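
Instead of rejecting every non-2-D input, `forward` now accepts a 3-D tensor whose middle dimension is 1, squeezes that dimension away for the 2-D GEMM, and remembers the fact on `ctx`. A standalone sketch of the shape handling (the function name is ours; a plain matmul stands in for the fp8 GEMM):

    # Minimal sketch (not from the commit) of the bsz == 1 shape handling.
    import torch

    def squeeze_bsz1(x: torch.Tensor):
        if x.dim() == 3:
            if x.size(1) != 1:
                raise ValueError(f"Batched fp8 deep_gemm is not supported, found x shape: {x.shape}")
            return x.squeeze(1), True
        return x, False

    x = torch.randn(4, 1, 16)
    x2d, had_batch_dim = squeeze_bsz1(x)                  # (4, 16), True
    out2d = x2d @ torch.randn(16, 8)                      # stand-in for the fp8 GEMM
    out = out2d.unsqueeze(1) if had_batch_dim else out2d  # (4, 1, 8)
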
@@ -848,12 +866,17 @@ def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
         ctx.w_t_per_plk = per_block_cast_to_fp8(w.t())
         ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
         ctx.mnk = (m, n, k)
+        if has_batch_dim:
+            out = out.unsqueeze(1)
         return out

     def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         # o_grad: (m, n)
         # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
         # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
+        if ctx.has_batch_dim:
+            o_grad = o_grad.squeeze(1)
+
         m, n, k = ctx.mnk
         o_per_tok = per_token_cast_to_fp8(o_grad)

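
Because `forward` unsqueezes the output whenever the extra dimension was present, autograd delivers `o_grad` with that same shape, so `backward` squeezes it before the gradient GEMMs and (in the next hunk) unsqueezes `x_grad` on the way out. The round trip in shapes, for a size-1 middle dimension:

    forward:  x (m, 1, k) -> squeeze(1) -> (m, k) -> GEMM  -> out (m, n)    -> unsqueeze(1) -> (m, 1, n)
    backward: o_grad (m, 1, n) -> squeeze(1) -> (m, n) -> GEMMs -> x_grad (m, k) -> unsqueeze(1) -> (m, 1, k)
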
@@ -864,6 +887,9 @@ def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
         w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
         deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)

+        if ctx.has_batch_dim:
+            x_grad = x_grad.unsqueeze(1)
+
         return x_grad, w_grad

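
Taken together, the DeepGEMM path now round-trips a bsz == 1 input through the unchanged 2-D kernels. A toy end-to-end check of the pattern (not from the commit): a plain matmul stands in for the fp8 DeepGEMM kernels, since running the real path needs supported hardware, and the class name is ours.

    # Toy autograd.Function exercising the same squeeze/unsqueeze round trip.
    from typing import Any, Tuple
    import torch

    class _ToyLinearBsz1(torch.autograd.Function):
        @staticmethod
        def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
            has_batch_dim = x.dim() == 3
            if has_batch_dim:
                assert x.size(1) == 1, "only a size-1 middle dim is supported"
                x = x.squeeze(1)
            ctx.save_for_backward(x, w)
            ctx.has_batch_dim = has_batch_dim
            out = x @ w.t()                       # (m, k) @ (k, n) -> (m, n)
            return out.unsqueeze(1) if has_batch_dim else out

        @staticmethod
        def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            x, w = ctx.saved_tensors
            if ctx.has_batch_dim:
                o_grad = o_grad.squeeze(1)        # (m, 1, n) -> (m, n)
            x_grad = o_grad @ w                   # (m, k)
            w_grad = o_grad.t() @ x               # (n, k)
            if ctx.has_batch_dim:
                x_grad = x_grad.unsqueeze(1)      # back to (m, 1, k)
            return x_grad, w_grad

    x = torch.randn(4, 1, 16, requires_grad=True)
    w = torch.randn(8, 16, requires_grad=True)
    _ToyLinearBsz1.apply(x, w).sum().backward()
    assert x.grad.shape == x.shape and w.grad.shape == w.shape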