
Commit c407f65

jwfromm authored and facebook-github-bot committed
BF16I4 Preshuffled Grouped Gemm (#3917)
Summary:
X-link: facebookresearch/FBGEMM#1006
Pull Request resolved: #3917

This diff adds a preshuffled variant of BF16I4 Grouped Gemm. Notably, cutlass does not currently support zero points for grouped gemm, so this kernel must be used without them. That said, the accuracy of the kernel appears reasonable and the performance is very compelling.

{F1976716898}

Reviewed By: jiawenliu64

Differential Revision: D72337760

fbshipit-source-id: a2cf9e913d095da42f1cf88a5c08dbbe1f2794c9
1 parent 8cbb32c commit c407f65

File tree: 4 files changed (+573, -3 lines)


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (+59, -1)
@@ -1527,7 +1527,6 @@ def preprocess(self, x, w):
         # Convert m_values into offsets into grouped tensor.
         m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
         # Quantize weights.
-        # TODO Only rowwise scaling is currently supported. This needs to be fixed.
         wq, scales = zip(*[quantize_int4_preshuffle(i) for i in w])
         group_scale, row_scale = zip(*scales)
         # Group weights as single tensor.
@@ -1573,6 +1572,65 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16I4ShuffledGroupedGemm(QuantizeOpBase):
+    """
+    BF16 x Int4 mixed dtype grouped gemm with preshuffling.
+    """
+
+    def preprocess(self, x, w):
+        assert isinstance(x, list) and isinstance(
+            w, list
+        ), "Only supported for grouped inputs."
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
+        # Quantize weights.
+        wq, scales = zip(
+            *[quantize_int4_preshuffle(i, dtype="bf16", use_zp=False) for i in w]
+        )
+        # Group weights as single tensor.
+        group_scale, group_zero = zip(*scales)
+        wq = torch.stack(wq, dim=0).contiguous()
+        group_scale = torch.stack(group_scale, dim=0).contiguous()
+        group_zero = torch.stack(group_zero, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, wq, group_scale, group_zero, m_sizes
+
+    def quantize(self, x, wq, group_scale, group_zero, m_sizes):
+        return x, wq, group_scale, group_zero, m_sizes
+
+    def compute(self, x, wq, group_scale, group_zero, m_sizes):
+        # TODO: Zero points aren't currently supported in grouped gemm.
+        # We leave them as inputs for future compatibility, but they are ignored.
+        return torch.ops.fbgemm.bf16i4bf16_shuffled_grouped(
+            x, wq, group_scale, group_zero, m_sizes
+        )
+
+    def quantize_and_compute(self, x, wq, group_scale, group_zero, m_sizes):
+        x, wq, group_scale, group_zero, m_sizes = self.quantize(
+            x, wq, group_scale, group_zero, m_sizes
+        )
+        return self.compute(x, wq, group_scale, group_zero, m_sizes)
+
+    @property
+    def name(self) -> str:
+        if torch.version.cuda:
+            return "cutlass_bf16i4_grouped_preshuffle"
+        else:
+            return "ck_bf16i4_grouped_preshuffle"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16GroupedStacked(QuantizeOpBase):
     """

fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py (+7, -2)
@@ -91,7 +91,7 @@ def int4_row_quantize(


 def quantize_int4_preshuffle(
-    w: torch.Tensor, group_size: int = 128, dtype: str = "fp8"
+    w: torch.Tensor, group_size: int = 128, dtype: str = "fp8", use_zp: bool = True
 ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     Quantizes an input weight tensor to int4 using preshuffling and scale packing.
@@ -102,6 +102,7 @@ def quantize_int4_preshuffle(
         w (Tensor): [N, K] Higher precision weight tensor to quantize. May optionally have a batch dimension.
         group_size (int): Number of elements to calculate group scale for, must be at least 128.
         dtype (torch.dtype): Type of corresponding activations. Must be fp8 or bf16.
+        use_zp (bool): If true, uses zero points during weight quantization. Only relevant for bf16 currently.
     Returns:
         wq (Tensor): [N, K // 2] Quantized int4 weight tensor packed into int8 elements.
         scales (Tuple[Tensor]): Scale tensors for the specified activation type. When FP8 is used,
@@ -128,7 +129,11 @@ def _quantize(
             return wq, (group_scale, row_scale)

         elif dtype == "bf16":
-            wq, group_scale, group_zero = int4_row_quantize_zp(w, group_size)
+            if use_zp:
+                wq, group_scale, group_zero = int4_row_quantize_zp(w, group_size)
+            else:
+                wq, group_scale = int4_row_quantize(w, group_size)
+                group_zero = torch.zeros_like(group_scale)
             # Set scales to activation type.
             group_scale = group_scale.to(torch.bfloat16)
             group_zero = group_zero.to(torch.bfloat16)
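
As a quick illustration of what the new flag changes (a sketch only; the import path is assumed from the file layout above and the shapes are arbitrary): with use_zp=False the bf16 branch falls back to symmetric int4_row_quantize and returns zero points that are all zeros, which is the form the grouped kernel expects.

# Sketch of the use_zp flag's effect; assumes a CUDA build and this import path.
import torch
from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle

w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Default path: asymmetric row quantization with real zero points (int4_row_quantize_zp).
wq_zp, (scale_zp, zero_zp) = quantize_int4_preshuffle(w, dtype="bf16")

# New path: symmetric row quantization (int4_row_quantize); zero points are filled
# with zeros so the returned tuple keeps the same structure.
wq_sym, (scale_sym, zero_sym) = quantize_int4_preshuffle(w, dtype="bf16", use_zp=False)
assert torch.count_nonzero(zero_sym) == 0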
