@@ -800,7 +800,11 @@ def forward(
             scale_a=inv_scale_x,
             scale_b=inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(out, tuple):
+            out = out[0]
+
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])

     @staticmethod
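The hunk above drops the unconditional `[0]` index on the result of `torch._scaled_mm` and unwraps only when a tuple actually comes back; older PyTorch builds returned an `(output, amax)` tuple from `torch._scaled_mm`, while newer ones return the output tensor directly, so the `isinstance` check keeps the wrapper working on both. Below is a minimal sketch of that compatibility pattern; `scaled_mm_compat` and its parameter names are illustrative only (not part of this patch), and actually running it requires fp8 tensors on a GPU that `torch._scaled_mm` supports.

```python
import torch


def scaled_mm_compat(x_fp8, w_fp8_t, inv_scale_x, inv_scale_w, out_dtype=torch.bfloat16):
    """Illustrative wrapper: call torch._scaled_mm and normalize its return value."""
    out = torch._scaled_mm(
        x_fp8,
        w_fp8_t,
        out_dtype=out_dtype,
        scale_a=inv_scale_x,
        scale_b=inv_scale_w,
        use_fast_accum=True,
    )
    # Older releases return (output, amax); newer ones return just the output tensor.
    if isinstance(out, tuple):
        out = out[0]
    return out
```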
@@ -814,15 +818,23 @@ def backward(ctx: Any, out_grad) -> Any:
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(x_grad, tuple):
+            x_grad = x_grad[0]
+
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
             ctx.x_fp8.t().contiguous().t(),
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(w_grad, tuple):
+            w_grad = w_grad[0]
+
         bias_grad = None
         if ctx.has_bias:
             bias_grad = out_grad.sum(0)
@@ -835,8 +847,14 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
     """

     def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
-        if not (x.dim() == 2 and w.dim() == 2):
-            raise ValueError("Batched fp8 deep_gemm is not supported")
+        has_batch_dim = False
+        if x.dim() == 3:
+            has_batch_dim = True
+            if x.size(1) != 1:
+                raise ValueError(f"Batched fp8 deep_gemm is not supported, found x shape: {x.shape}")
+            x = x.squeeze(1)
+        ctx.has_batch_dim = has_batch_dim
+
         # x: (m, k), w: (n, k)
         # x @ w_t -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
         (m, k), (n, _) = x.shape, w.shape
@@ -848,12 +866,17 @@ def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
         ctx.w_t_per_plk = per_block_cast_to_fp8(w.t())
         ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
         ctx.mnk = (m, n, k)
+        if has_batch_dim:
+            out = out.unsqueeze(1)
         return out

     def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         # o_grad: (m, n)
         # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
         # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
+        if ctx.has_batch_dim:
+            o_grad = o_grad.squeeze(1)
+
         m, n, k = ctx.mnk
         o_per_tok = per_token_cast_to_fp8(o_grad)
@@ -864,6 +887,9 @@ def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor
         w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
         deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)

+        if ctx.has_batch_dim:
+            x_grad = x_grad.unsqueeze(1)
+
         return x_grad, w_grad
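Taken together, the `_LinearFp8DeepGemm` hunks implement one pattern: `forward` accepts an optional singleton middle dimension, squeezes it away before the 2-D GEMM, records `ctx.has_batch_dim`, and unsqueezes the output; `backward` mirrors this by squeezing the incoming `o_grad` and unsqueezing `x_grad`, while the weight gradient stays 2-D. Below is a minimal, self-contained sketch of that squeeze/unsqueeze bookkeeping using a plain matmul in place of the fp8 deep_gemm kernels, so it runs without fp8 hardware; the class name and the plain-matmul math are illustrative, not the patched implementation.

```python
import torch
from typing import Any, Tuple


class _LinearWithOptionalBatchDim(torch.autograd.Function):
    """Sketch of the singleton-batch-dim handling from the diff, with plain matmul."""

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        has_batch_dim = False
        if x.dim() == 3:
            has_batch_dim = True
            if x.size(1) != 1:
                raise ValueError(f"only a singleton middle dim is supported, got {x.shape}")
            x = x.squeeze(1)              # (m, 1, k) -> (m, k)
        ctx.has_batch_dim = has_batch_dim
        ctx.save_for_backward(x, w)
        out = x @ w.t()                   # (m, k) @ (k, n) -> (m, n)
        if has_batch_dim:
            out = out.unsqueeze(1)        # restore (m, 1, n)
        return out

    @staticmethod
    def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x, w = ctx.saved_tensors
        if ctx.has_batch_dim:
            o_grad = o_grad.squeeze(1)    # (m, 1, n) -> (m, n)
        x_grad = o_grad @ w               # (m, n) @ (n, k) -> (m, k)
        w_grad = o_grad.t() @ x           # (n, m) @ (m, k) -> (n, k)
        if ctx.has_batch_dim:
            x_grad = x_grad.unsqueeze(1)  # match the original 3-D x shape
        return x_grad, w_grad


# Usage check: gradients keep the shapes of the 3-D input and 2-D weight.
x = torch.randn(4, 1, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
_LinearWithOptionalBatchDim.apply(x, w).sum().backward()
assert x.grad.shape == x.shape and w.grad.shape == w.shape
```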