diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index c549433c4..f69125aaf 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -3,15 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from . import research, utils
+
+from . import _ops, research, utils
from .autograd._functions import (
MatmulLtState,
- bmm_cublas,
matmul,
matmul_4bit,
- matmul_cublas,
- mm_cublas,
)
+from .backends.cpu import ops as cpu_ops
+from .backends.cuda import ops as cuda_ops  # TODO: guard this import so it only runs when CUDA is available
from .nn import modules
from .optim import adam
diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
new file mode 100644
index 000000000..c519194a6
--- /dev/null
+++ b/bitsandbytes/_ops.py
@@ -0,0 +1,316 @@
+from math import prod
+from typing import Optional, Sequence, Tuple
+
+import torch
+
+_IS_TORCH_GTE_24 = False
+
+if hasattr(torch.library, "register_fake"):
+ _IS_TORCH_GTE_24 = True
+ register_fake = torch.library.register_fake
+ register_kernel = torch.library.register_kernel
+else:
+ # PyTorch <= 2.3
+ register_fake = torch.library.impl_abstract
+ register_kernel = torch.library.impl
+
+
+# Higher level op: int8 matmul + dequant + bias
+torch.library.define(
+ "bitsandbytes::int8_scaled_mm",
+ "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::int8_scaled_mm")
+def _(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ row_stats: torch.Tensor,
+ col_stats: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ dtype=torch.float16,
+) -> torch.Tensor:
+ shapeC = (*A.shape[:-1], B.shape[0])
+ return torch.empty(shapeC, device=A.device, dtype=dtype)
+
+
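+# Default implementation, registered without a device restriction: it simply
+# composes the lower-level ops, so a backend only needs to provide
+# int8_linear_matmul and int8_mm_dequant to support int8_scaled_mm.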
+@register_kernel("bitsandbytes::int8_scaled_mm", None)
+def _(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ row_stats: torch.Tensor,
+ col_stats: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ dtype=torch.float16,
+) -> torch.Tensor:
+ out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
+ out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
+ return out
+
+
+torch.library.define(
+ "bitsandbytes::int8_linear_matmul",
+ "(Tensor A, Tensor B) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::int8_linear_matmul")
+def _(A: torch.Tensor, B: torch.Tensor):
+ torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+ torch._check(B.dtype == torch.int8, lambda: "B must be int8")
+ shapeC = (*A.shape[:-1], B.shape[0])
+ return torch.empty(shapeC, device=A.device, dtype=torch.int32)
+
+
+# More info on `out` overloads:
+# https://github.com/pytorch/pytorch/issues/125044
+torch.library.define(
+ "bitsandbytes::int8_linear_matmul.out",
+ "(Tensor A, Tensor B, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::int8_linear_matmul.out")
+def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+ shapeC = (*A.shape[:-1], B.shape[0])
+
+ torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+ torch._check(B.dtype == torch.int8, lambda: "B must be int8")
+ torch._check(out.shape == shapeC, lambda: f"Expected out.shape == {shapeC}, got {out.shape}")
+ torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+ torch._check(out.dtype == torch.int32, lambda: f"Expected out.dtype == int32, got {out.dtype}")
+
+
+torch.library.define(
+ "bitsandbytes::int8_vectorwise_quant",
+ "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor?)",
+)
+
+
+@register_fake("bitsandbytes::int8_vectorwise_quant")
+def _(A: torch.Tensor, threshold=0.0):
+ out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
+ row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
+
+ if threshold == 0.0:
+ return out_row, row_stats, None
+
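+ # The number of outlier columns is data-dependent, so the fake implementation
+ # reports it as an unbacked dynamic size.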
+ outlier_cols = torch.library.get_ctx().new_dynamic_size()
+
+ return out_row, row_stats, A.new_empty(outlier_cols, dtype=torch.int64)
+
+
+torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")
+
+
+@register_fake("bitsandbytes::int8_vectorwise_dequant")
+def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
+ torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+ return torch.empty_like(A, dtype=torch.float32)
+
+
+# Default PyTorch-native implementation
+@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
+def _(A: torch.Tensor, stats: torch.Tensor):
+ # To dequantize we divide by 127, or multiply by the reciprocal.
+ return A * stats.view(-1, 1) * 7.874015718698502e-3
+
+
+torch.library.define(
+ "bitsandbytes::int8_mm_dequant",
+ "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::int8_mm_dequant")
+def _(
+ A: torch.Tensor,
+ row_stats: torch.Tensor,
+ col_stats: torch.Tensor,
+ dtype=torch.float16,
+ bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ torch._check(A.dtype == torch.int32, lambda: "A must be int32")
+ return torch.empty_like(A, dtype=dtype)
+
+
+torch.library.define(
+ "bitsandbytes::int8_double_quant",
+ "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
+)
+
+
+@register_fake("bitsandbytes::int8_double_quant")
+def _(
+ A: torch.Tensor,
+ threshold=0.0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ out_row = torch.empty_like(A, dtype=torch.int8)
+ out_col = torch.empty_like(A, dtype=torch.int8)
+ row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
+ col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32)
+ outlier_n = torch.library.get_ctx().new_dynamic_size()
+ outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
+ return out_row, out_col, row_stats, col_stats, outlier_cols
+
+
+torch.library.define(
+ "bitsandbytes::dequantize_4bit",
+ "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::dequantize_4bit")
+def _(
+ A: torch.Tensor,
+ absmax: torch.Tensor,
+ blocksize: int,
+ quant_type: str,
+ shape: Sequence[int],
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ torch._check_is_size(blocksize)
+ return torch.empty(shape, dtype=dtype, device=A.device)
+
+
+torch.library.define(
+ "bitsandbytes::dequantize_4bit.out",
+ "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::dequantize_4bit.out")
+def _(
+ A: torch.Tensor,
+ absmax: torch.Tensor,
+ blocksize: int,
+ quant_type: str,
+ shape: Sequence[int],
+ dtype: torch.dtype,
+ out: torch.Tensor,
+) -> None:
+ torch._check_is_size(blocksize)
+ torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
+ torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+ torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+
+
+torch.library.define(
+ "bitsandbytes::quantize_4bit",
+ "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
+)
+
+
+@register_fake("bitsandbytes::quantize_4bit")
+def _(
+ A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ torch._check_is_size(blocksize)
+
+ n = A.numel()
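+ # Ceiling division: the number of blocks needed to cover all n elements.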
+ blocks = -(n // -blocksize)
+ absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
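+ # 4-bit values are packed two per byte, i.e. 2 * itemsize values per element of the storage dtype.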
+ out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
+ return out, absmax
+
+
+torch.library.define(
+ "bitsandbytes::dequantize_blockwise",
+ "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::dequantize_blockwise")
+def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
+ torch._check_is_size(blocksize)
+ torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+ return torch.empty_like(A, dtype=dtype)
+
+
+torch.library.define(
+ "bitsandbytes::dequantize_blockwise.out",
+ "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::dequantize_blockwise.out")
+def _(
+ A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+):
+ torch._check_is_size(blocksize)
+ torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+ torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
+ torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+ torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+
+
+torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")
+
+
+@register_fake("bitsandbytes::quantize_blockwise")
+def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ torch._check_is_size(blocksize)
+ n = A.numel()
+ blocks = -(n // -blocksize)
+ absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+ out = torch.empty_like(A, dtype=torch.uint8)
+ return out, absmax
+
+
+torch.library.define(
+ "bitsandbytes::gemv_4bit",
+ "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::gemv_4bit")
+def _(
+ A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
+) -> torch.Tensor:
+ torch._check_is_size(blocksize)
+ torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
+ torch._check(
+ A.dtype in [torch.float16, torch.bfloat16, torch.float32],
+ lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
+ )
+ torch._check(
+ B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
+ lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
+ )
+ shape = (*A.shape[:-1], shapeB[0])
+ return torch.empty(shape, device=A.device, dtype=A.dtype)
+
+
+torch.library.define(
+ "bitsandbytes::gemv_4bit.out",
+ "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::gemv_4bit.out")
+def _(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ shapeB: Sequence[int],
+ absmax: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
+ out: torch.Tensor,
+) -> None:
+ torch._check_is_size(blocksize)
+ torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
+ torch._check(
+ A.dtype in [torch.float16, torch.bfloat16, torch.float32],
+ lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
+ )
+ torch._check(
+ B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
+ lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
+ )
+ torch._check(
+ out.shape == (*A.shape[:-1], shapeB[0]),
+ lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
+ )
+ torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+ torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index f66cdf68d..9f14db754 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -49,6 +49,10 @@ def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)
+@deprecated(
+ "This function is deprecated and will be removed in a future release.",
+ category=FutureWarning,
+)
def get_inverse_transform_indices(
transform_tile: Callable[[torch.Tensor], torch.Tensor],
tile_size: Tuple[int, int],
@@ -80,6 +84,10 @@ def get_inverse_transform_indices(
return permuted_tile_indices
+@deprecated(
+ "This function is deprecated and will be removed in a future release.",
+ category=FutureWarning,
+)
def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
"""
Undo a tiled permutation such as turing or ampere layout
@@ -98,152 +106,9 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -
return outputs.reshape(rows, cols).contiguous()
-@deprecated(
- "MatMul8bit is deprecated and will be removed in a future release. Please use MatMul8bitLt instead.",
- category=FutureWarning,
-)
-class MatMul8bit(torch.autograd.Function):
- @staticmethod
- def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
- if precision is None:
- precision = [8, 8, 8]
- if precision[0] != 8:
- with torch.no_grad():
- output = torch.matmul(A, B)
- else:
- if len(B.shape) == 2:
- dim = 0
- else:
- dim = 1
- qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
- qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
- iout = F.igemm(qA, qB)
- output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)
-
- if A.requires_grad or B.requires_grad:
- ctx.save_for_backward(A, B)
-
- ctx.quant_type = quant_type
- ctx.precision = precision
-
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- A, B = ctx.saved_tensors
- quant_type = ctx.quant_type
- precision = ctx.precision
- grad_A = grad_B = None
-
- if B.requires_grad:
- if len(A.shape) == 3:
- dims = [0, 1]
- # bsi -> ibs
- permute_dim = [0, 2, 1]
- else:
- dims = [0]
- # bs -> sb
- permute_dim = [1, 0]
-
- if precision[1] != 8:
- with torch.no_grad():
- grad_B = torch.matmul(A.permute(permute_dim), grad_output)
- else:
- if len(B.shape) == 2 and len(A.shape) == 3:
- grad_output = grad_output.contiguous()
- if not grad_output.is_contiguous():
- grad_output.contiguous()
- qgrad_output, S1 = F.vectorwise_quant(
- grad_output.view(-1, grad_output.shape[2]),
- dim=0,
- quant_type=quant_type,
- )
- if not A.is_contiguous():
- A = A.contiguous()
- qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
- igrad_B = F.igemm(qA.t(), qgrad_output)
- grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
- else:
- qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
- qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
- igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
- grad_B = F.vectorwise_mm_dequant(
- igrad_B,
- S2.permute(permute_dim),
- S1,
- grad_output.dtype,
- quant_type,
- )
-
- if A.requires_grad:
- if len(grad_output.shape) == 3:
- dims = [2]
- else:
- dims = [1]
-
- if len(B.shape) == 3:
- # bio -> boi
- permute_dim = [0, 2, 1]
- dim_B = dims
- else:
- # io -> oi
- permute_dim = [1, 0]
- dim_B = [1]
-
- if precision[2] != 8:
- with torch.no_grad():
- grad_A = torch.matmul(grad_output, B.permute(permute_dim))
- else:
- qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
- qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
- igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
- grad_A = F.vectorwise_mm_dequant(
- igrad_A,
- S1,
- S3.permute(permute_dim),
- grad_output.dtype,
- quant_type,
- )
-
- return grad_A, grad_B, None, None, None
-
-
-mm_cublas = MatMul8bit.apply
-bmm_cublas = MatMul8bit.apply
-matmul_cublas = MatMul8bit.apply
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def supports_igemmlt(device: torch.device) -> bool:
- """check if this device supports the optimized int8 kernel"""
- if torch.cuda.get_device_capability(device=device) < (7, 5):
- return False
- device_name = torch.cuda.get_device_name(device=device)
- nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series
- if any(model_name in device_name for model_name in nvidia16_models):
- return False # these devices are technically cuda 7.5-capable, but they lack tensor cores
- return True
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def _get_tile_size(format):
- assert format in (
- "col_turing",
- "col_ampere",
- ), f"please find this assert and manually enter tile size for {format}"
- return (8, 32) if format == "col_turing" else (32, 32)
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def get_tile_inds(format, device):
- transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
- with torch.no_grad():
- return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
-
-
@dataclass
class MatmulLtState:
- _tile_indices: Optional[torch.Tensor] = None
+ _tile_indices: Optional[torch.Tensor] = None # TODO: remove
force_no_igemmlt: bool = False
@@ -279,9 +144,7 @@ def reset_grads(self):
@property
def tile_indices(self):
- if self._tile_indices is None:
- self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
- return self._tile_indices
+ raise ValueError("tile_indices is no longer supported.")
class MatMul8bitLt(torch.autograd.Function):
@@ -360,20 +223,12 @@ def forward(
# we want to divide by 127. It's however more performant to multiply
# by the reciprocal.
outliers = state.CB[:, state.idx]
- state.subB = (outliers.t() * state.SCB * 7.874015718698502e-3).to(A.dtype)
+ state.subB = F.int8_vectorwise_dequant(outliers, state.SCB).to(A.dtype).t()
else:
subA = None
- # 3. Int8 Matmul
- out32 = F.int8_linear_matmul(CA, state.CB)
-
- # Dequantize matmul result
- if bias is None or bias.dtype == torch.float16:
- # we apply the fused bias here
- output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
- else: # apply bias separately
- # TODO: Fused bias for fp32/bf16?
- output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias)
+ # 3. Int8 Matmul + Dequant + Bias
+ output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
# 4. Mixed-precision decomposition matmul
if subA is not None and state.subB is not None:
@@ -423,8 +278,14 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
if req_gradB:
Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))
- gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
- grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)
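+ # The int8 matmul and the dequantization of its result are handled by the composite int8_scaled_mm op.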
+ grad_B = torch.ops.bitsandbytes.int8_scaled_mm.default(
+ Cgrad.t().contiguous(),
+ CAt.t(),
+ SCgradt,
+ SCAt,
+ dtype=torch.float16,
+ )
+
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/bitsandbytes/backends/cpu/__init__.py b/bitsandbytes/backends/cpu/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
new file mode 100644
index 000000000..07e3cafd4
--- /dev/null
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -0,0 +1,94 @@
+import ctypes as ct
+from typing import Optional, Tuple
+
+import torch
+
+from bitsandbytes.functional import get_ptr
+
+from ..._ops import register_kernel
+from ...cextension import lib
+
+
+@register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
+def _(A: torch.Tensor, B: torch.Tensor):
+ return _int8_linear_matmul_impl(A, B)
+
+
+@register_kernel("bitsandbytes::int8_linear_matmul.out", "cpu")
+def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+ torch._check(out.dtype == torch.int32)
+ _int8_linear_matmul_impl(A, B, out)
+
+
+def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None):
+ # Naive implementation: perform matmul in fp32
+ result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
+ if out is not None:
+ result = out.copy_(result)
+ return result
+
+
+@register_kernel("bitsandbytes::int8_mm_dequant", "cpu")
+def _(
+ A: torch.Tensor,
+ row_stats: torch.Tensor,
+ col_stats: torch.Tensor,
+ dtype=torch.float16,
+ bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
+ torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
+ torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
+
+ A_calc = A.view(-1, A.shape[-1])
+ row_stats = row_stats.reshape(-1).unsqueeze(-1)
+ col_stats = col_stats.reshape(-1).unsqueeze(0)
+
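+ # The constant is approximately 1 / (127 * 127), undoing the row-wise and
+ # column-wise scalings by 127 applied during quantization.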
+ out = A_calc * (row_stats * col_stats) * 6.200124e-05
+ if bias is not None:
+ out += bias
+
+ return out.to(dtype)
+
+
+@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
+def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ torch._check_is_size(blocksize)
+ torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")
+
+ n = A.numel()
+ blocks = -(n // -blocksize)
+
+ absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+ out = torch.empty_like(A, dtype=torch.uint8)
+
+ lib.cquantize_blockwise_cpu_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_longlong(blocksize),
+ ct.c_longlong(n),
+ )
+
+ return out, absmax
+
+
+@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
+def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
+ torch._check_is_size(blocksize)
+ torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+ torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")
+
+ out = torch.empty_like(A, dtype=dtype)
+
+ lib.cdequantize_blockwise_cpu_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_longlong(blocksize),
+ ct.c_longlong(A.numel()),
+ )
+
+ return out
diff --git a/bitsandbytes/backends/cuda/__init__.py b/bitsandbytes/backends/cuda/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
new file mode 100644
index 000000000..88bd872be
--- /dev/null
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -0,0 +1,520 @@
+import ctypes as ct
+from math import prod
+from typing import Optional, Sequence, Tuple
+
+import torch
+
+from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
+
+from ..._ops import register_kernel
+from ...cextension import lib
+
+
+@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
+def _(A: torch.Tensor, B: torch.Tensor):
+ out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32)
+ return _int8_linear_matmul_impl(A, B, out)
+
+
+@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda")
+def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+ _int8_linear_matmul_impl(A, B, out)
+
+
+def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
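+ # To use the IMMA tensor core kernels without special Turing/Ampere layouts,
+ # cublasLt requires A transposed and B non-transposed and computes C = A.T @ B
+ # with col-major operands. With row-major tensors we therefore swap A and B so
+ # that the linear layer C = A @ B.T is computed without any layout transformations.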
+ A, B = B, A
+
+ shapeA = A.shape
+ shapeB = B.shape
+
+ torch._check(A.dtype == torch.int8, lambda: "B must be int8")
+ torch._check(B.dtype == torch.int8, lambda: "A must be int8")
+ torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B")
+ torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A")
+ torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}")
+ torch._check(out.dtype == torch.int32)
+
+ shapeC = (*shapeB[:-1], shapeA[0])
+ torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}")
+
+ k, m = shapeA
+ n = prod(shapeB[:-1])
+ lda = shapeA[-1] # Weights (outputs, inputs)
+ ldb = shapeB[-1] # Activations (batch, tokens, inputs)
+ ldc = shapeC[-1] # Output (batch, tokens, outputs)
+
+ torch._check(
+ lda == ldb,
+ lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
+ )
+
+ # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
+ # We'll fall back to a slower fp32 calculation in this circumstance.
+ # Fortunately, this should not be very common.
+ if lda % 4 != 0:
+ result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
+ return out.copy_(result)
+
+ with _cuda_device_of(A):
+ ctx = CUBLAS_Context.get_instance().get_context(A.device)
+ ptrA = get_ptr(A)
+ ptrB = get_ptr(B)
+ ptrC = get_ptr(out)
+ ptrRowScale = None
+ m = ct.c_int32(m)
+ n = ct.c_int32(n)
+ k = ct.c_int32(k)
+ lda = ct.c_int32(lda)
+ ldb = ct.c_int32(ldb)
+ ldc = ct.c_int32(ldc)
+ stream = _get_tensor_stream(A)
+
+ has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
+
+ if has_error:
+ if has_error == 100:
+ # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
+ # TODO: Warn and implement a fallback to fp32 compute?
+ raise NotImplementedError("int8_linear_matmul not implemented!")
+ else:
+ raise RuntimeError(
+ f"cublasLt ran into an error!\n"
+ f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
+ f"\t{(lda, ldb, ldc)=}\n"
+ f"\t{(m, n, k)=}"
+ )
+
+ return out
+
+
+@register_kernel("bitsandbytes::int8_mm_dequant", "cuda")
+def _(
+ A: torch.Tensor,
+ row_stats: torch.Tensor,
+ col_stats: torch.Tensor,
+ dtype=torch.float16,
+ bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
+ torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
+ torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
+
+ # Note: cuda kernel only currently supports fp16 output.
+ # We'll later cast to desired dtype if needed.
+ out = torch.empty_like(A, dtype=torch.float16)
+
+ ptrA = get_ptr(A)
+ ptrOut = get_ptr(out)
+ ptrRowStats = get_ptr(row_stats)
+ ptrColStats = get_ptr(col_stats)
+ numRows = ct.c_int32(prod(A.shape[:-1]))
+ numCols = ct.c_int32(A.shape[-1])
+
+ # Note: fused bias in the kernel is only supported for fp16
+ # TODO(matthewdouglas): Consider supporting bf16 fused bias
+ ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None
+
+ with _cuda_device_of(A):
+ lib.cdequant_mm_int32_fp16(
+ ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
+ )
+
+ # Add bias separately if not fused in kernel
+ if bias is not None and bias.dtype != torch.float16:
+ out.add_(bias)
+
+ return out.to(dtype)
+
+
+@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
+def _(A: torch.Tensor, threshold=0.0):
+ torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}")
+ torch._check(threshold >= 0.0, lambda: "threshold must be non-negative")
+
+ rows = prod(A.shape[:-1])
+ cols = A.shape[-1]
+
+ row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
+ out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
+
+ outlier_cols = None
+
+ if threshold > 0.0:
+ # TODO: the outlier detection below could be made more efficient
+ outliers = A.abs() >= threshold
+
+ if outliers.any():
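+ # Indices of the columns that contain at least one outlier.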
+ outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+
+ with _cuda_device_of(A):
+ lib.cint8_vector_quant(
+ get_ptr(A),
+ get_ptr(out_row),
+ get_ptr(row_stats),
+ ct.c_float(threshold),
+ ct.c_int32(rows),
+ ct.c_int32(cols),
+ _get_tensor_stream(A),
+ )
+
+ # Zero out values from outlier columns across all rows.
+ # The kernel zeroes the individual outlier values it encounters, so when there is
+ # only one row this is already complete; for rows > 1 we must also clear the
+ # non-outlier entries in those columns.
+ if rows > 1 and outlier_cols is not None:
+ out_row[:, outlier_cols] = 0
+
+ return out_row, row_stats, outlier_cols
+
+
+@register_kernel("bitsandbytes::int8_double_quant", "cuda")
+def _(
+ A: torch.Tensor,
+ threshold=0.0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ # Use the CUDA kernel for the row-wise quantization and outlier column extraction
+ quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default(
+ A,
+ threshold=threshold,
+ )
+
+ # PyTorch impl for colwise
+ col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold)
+ if threshold > 0.0 and outlier_mask is not None:
+ A = A.masked_fill(outlier_mask, 0.0)
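+ # Symmetric column-wise quantization: scale each column into [-127, 127].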
+ quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8)
+
+ return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
+
+
+def _get_col_absmax(
+ A: torch.Tensor,
+ threshold=0.0,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ torch._check(A.is_floating_point())
+
+ outlier_mask = None
+
+ absA = A.abs().view(-1, A.shape[-1])
+
+ if threshold > 0.0:
+ # Filter outliers from stats when enabled
+ outlier_mask = absA >= threshold
+ absA.masked_fill_(outlier_mask, 0.0)
+
+ # shape [cols]; unsqueeze(0) gives [1,cols]
+ col_stats = absA.amax(dim=0, keepdim=False).float()
+
+ return col_stats, outlier_mask
+
+
+@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
+def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ torch._check_is_size(blocksize)
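+ # Only these block sizes are supported by the CUDA blockwise kernels.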
+ torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+ torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
+
+ n = A.numel()
+ blocks = -(n // -blocksize)
+ absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+ out = torch.empty_like(A, dtype=torch.uint8)
+
+ with _cuda_device_of(A):
+ args = (
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int32(blocksize),
+ ct.c_int(A.numel()),
+ )
+
+ if A.dtype == torch.float16:
+ lib.cquantize_blockwise_fp16(*args)
+ elif A.dtype == torch.bfloat16:
+ lib.cquantize_blockwise_bf16(*args)
+ elif A.dtype == torch.float32:
+ lib.cquantize_blockwise_fp32(*args)
+ else:
+ raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+
+ return out, absmax
+
+
+@register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
+def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
+ out = torch.empty_like(A, dtype=dtype)
+ _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
+ return out
+
+
+@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda")
+def _(
+ A: torch.Tensor,
+ absmax: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
+ dtype: torch.dtype,
+ out: torch.Tensor,
+) -> None:
+ torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+ torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
+ _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
+
+
+def _dequantize_blockwise_impl(
+ A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+) -> None:
+ torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+ torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+ torch._check(
+ dtype in [torch.float16, torch.bfloat16, torch.float32],
+ lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
+ )
+
+ with _cuda_device_of(A):
+ args = (
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int(blocksize),
+ ct.c_int(A.numel()),
+ _get_tensor_stream(A),
+ )
+
+ if dtype == torch.float16:
+ lib.cdequantize_blockwise_fp16(*args)
+ elif dtype == torch.bfloat16:
+ lib.cdequantize_blockwise_bf16(*args)
+ elif dtype == torch.float32:
+ lib.cdequantize_blockwise_fp32(*args)
+
+
+@register_kernel("bitsandbytes::quantize_4bit", "cuda")
+def _(
+ A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+ torch._check(quant_type in ["fp4", "nf4"])
+ torch._check(
+ A.dtype in [torch.bfloat16, torch.float16, torch.float32],
+ lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
+ )
+
+ n = A.numel()
+ blocks = -(n // -blocksize)
+ absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+ out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
+
+ with _cuda_device_of(A):
+ args = (
+ None,
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int32(blocksize),
+ ct.c_int(n),
+ )
+
+ if A.dtype == torch.bfloat16:
+ if quant_type == "fp4":
+ lib.cquantize_blockwise_bf16_fp4(*args)
+ else:
+ lib.cquantize_blockwise_bf16_nf4(*args)
+ elif A.dtype == torch.float16:
+ if quant_type == "fp4":
+ lib.cquantize_blockwise_fp16_fp4(*args)
+ else:
+ lib.cquantize_blockwise_fp16_nf4(*args)
+ elif A.dtype == torch.float32:
+ if quant_type == "fp4":
+ lib.cquantize_blockwise_fp32_fp4(*args)
+ else:
+ lib.cquantize_blockwise_fp32_nf4(*args)
+
+ return out, absmax
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "cuda")
+def _(
+ A: torch.Tensor,
+ absmax: torch.Tensor,
+ blocksize: int,
+ quant_type: str,
+ shape: Sequence[int],
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ out = torch.empty(shape, dtype=dtype, device=A.device)
+ _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+ return out
+
+
+@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda")
+def _(
+ A: torch.Tensor,
+ absmax: torch.Tensor,
+ blocksize: int,
+ quant_type: str,
+ shape: Sequence[int],
+ dtype: torch.dtype,
+ out: torch.Tensor,
+) -> None:
+ torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
+ torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+ _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+
+
+def _dequantize_4bit_impl(
+ A: torch.Tensor,
+ absmax: torch.Tensor,
+ blocksize: int,
+ quant_type: str,
+ dtype: torch.dtype,
+ out: torch.Tensor,
+) -> None:
+ torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+ torch._check(quant_type in ["fp4", "nf4"])
+ torch._check(
+ dtype in [torch.bfloat16, torch.float16, torch.float32],
+ lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
+ )
+
+ with _cuda_device_of(A):
+ args = (
+ None,
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int(blocksize),
+ ct.c_int(out.numel()),
+ _get_tensor_stream(A),
+ )
+
+ if out.dtype == torch.bfloat16:
+ if quant_type == "fp4":
+ lib.cdequantize_blockwise_bf16_fp4(*args)
+ else:
+ lib.cdequantize_blockwise_bf16_nf4(*args)
+ elif out.dtype == torch.float16:
+ if quant_type == "fp4":
+ lib.cdequantize_blockwise_fp16_fp4(*args)
+ else:
+ lib.cdequantize_blockwise_fp16_nf4(*args)
+ elif out.dtype == torch.float32:
+ if quant_type == "fp4":
+ lib.cdequantize_blockwise_fp32_fp4(*args)
+ else:
+ lib.cdequantize_blockwise_fp32_nf4(*args)
+
+
+@register_kernel("bitsandbytes::gemv_4bit", "cuda")
+def _(
+ A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
+) -> torch.Tensor:
+ shape = (*A.shape[:-1], shapeB[0])
+ out = torch.empty(shape, device=A.device, dtype=A.dtype)
+ _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
+ return out
+
+
+@register_kernel("bitsandbytes::gemv_4bit.out", "cuda")
+def _(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ shapeB: Sequence[int],
+ absmax: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
+ out: torch.Tensor,
+) -> None:
+ torch._check(
+ out.shape == (*A.shape[:-1], shapeB[0]),
+ lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
+ )
+ torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
+ _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
+
+
+def _gemv_4bit_impl(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ shapeB: Sequence[int],
+ absmax: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
+ out: torch.Tensor,
+) -> None:
+ torch._check_is_size(blocksize)
+ torch._check(
+ A.numel() == A.size(-1),
+ lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
+ )
+ torch._check(
+ A.dtype in [torch.float16, torch.bfloat16, torch.float32],
+ lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
+ )
+ torch._check(
+ B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
+ lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
+ )
+ torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
+ torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
+
+ m = ct.c_int32(shapeB[0])
+ n = ct.c_int32(1)
+ k = ct.c_int32(shapeB[1])
+
+ lda = m
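+ # B holds packed 4-bit values (two per byte), so its leading dimension is ceil(k / 2).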
+ ldb = ct.c_int32((A.shape[-1] + 1) // 2)
+ ldc = m
+
+ stream = _get_tensor_stream(A)
+
+ with _cuda_device_of(A):
+ if A.dtype == torch.float16:
+ lib.cgemm_4bit_inference_naive_fp16(
+ m,
+ n,
+ k,
+ get_ptr(A),
+ get_ptr(B),
+ get_ptr(absmax),
+ get_ptr(code),
+ get_ptr(out),
+ lda,
+ ldb,
+ ldc,
+ ct.c_int32(blocksize),
+ stream,
+ )
+ elif A.dtype == torch.bfloat16:
+ lib.cgemm_4bit_inference_naive_bf16(
+ m,
+ n,
+ k,
+ get_ptr(A),
+ get_ptr(B),
+ get_ptr(absmax),
+ get_ptr(code),
+ get_ptr(out),
+ lda,
+ ldb,
+ ldc,
+ ct.c_int32(blocksize),
+ stream,
+ )
+ elif A.dtype == torch.float32:
+ lib.cgemm_4bit_inference_naive_fp32(
+ m,
+ n,
+ k,
+ get_ptr(A),
+ get_ptr(B),
+ get_ptr(absmax),
+ get_ptr(code),
+ get_ptr(out),
+ lda,
+ ldb,
+ ldc,
+ ct.c_int32(blocksize),
+ stream,
+ )
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 5894e52dd..c0e139e03 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -182,13 +182,6 @@ def get_instance(cls):
return cls._instance
-dtype2bytes = {}
-dtype2bytes[torch.float32] = 4
-dtype2bytes[torch.float16] = 2
-dtype2bytes[torch.bfloat16] = 2
-dtype2bytes[torch.uint8] = 1
-dtype2bytes[torch.int8] = 1
-
FIRST_CUDA_DEVICE = torch.device("cuda", index=0)
# When multiple GPUs are present, we use a context manager to
@@ -207,7 +200,7 @@ def _cuda_device_of(a: torch.Tensor):
def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
- num_bytes = dtype2bytes[dtype] * prod(shape)
+ num_bytes = dtype.itemsize * prod(shape)
cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
@@ -217,15 +210,14 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
return out
-def prefetch_tensor(A, to_cpu=False):
+def prefetch_tensor(A: torch.Tensor, to_cpu=False):
assert A.is_paged, "Only paged tensors can be prefetched!"
if to_cpu:
deviceid = -1
else:
deviceid = A.page_deviceid
- num_bytes = dtype2bytes[A.dtype] * A.numel()
- lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))
+ lib.cprefetch(get_ptr(A), ct.c_size_t(A.nbytes), ct.c_int32(deviceid))
def elementwise_func(func_name, A, B, value, prefetch=True):
@@ -431,11 +423,6 @@ def create_quantile_map(A, total_bits=8):
return q
-@deprecated("This function is deprecated and will be removed in a future version.", category=FutureWarning)
-def get_special_format_str():
- return "row"
-
-
def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
"""Verifies that the input tensors are all on the same device.
@@ -472,11 +459,6 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
return on_gpu
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream:
- return torch.cuda.current_stream(tensor.device)
-
-
def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
# We use the raw stream for performance reasons.
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
@@ -509,106 +491,6 @@ def post_call(prev_device):
torch.cuda.set_device(prev_device)
-@deprecated(
- "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
- category=FutureWarning,
-)
-def get_transform_func(dtype, orderA, orderOut, transpose=False):
- name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
- if not hasattr(lib, name):
- print(name)
- raise ValueError(
- f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}",
- )
- else:
- return getattr(lib, name)
-
-
-@deprecated(
- "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
- category=FutureWarning,
-)
-def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
- # init_func = torch.empty
- init_func = torch.zeros
- dims = len(shape)
-
- if dims == 2:
- rows = shape[0]
- elif dims == 3:
- rows = shape[0] * shape[1]
- cols = shape[-1]
-
- state = (shape, to_order)
- if transpose:
- # swap dims
- tmp = rows
- rows = cols
- cols = tmp
- state = (shape[::-1], to_order)
-
- if to_order == "row" or to_order == "col":
- return init_func(shape, dtype=dtype, device=device), state
- elif to_order == "col32":
- # blocks of 32 columns (padded)
- cols = 32 * ((cols + 31) // 32)
- return init_func((rows, cols), dtype=dtype, device=device), state
- elif to_order == "col_turing":
- # blocks of 32 columns and 8 rows
- cols = 32 * ((cols + 31) // 32)
- rows = 8 * ((rows + 7) // 8)
- return init_func((rows, cols), dtype=dtype, device=device), state
- elif to_order == "col_ampere":
- # blocks of 32 columns and 32 rows
- cols = 32 * ((cols + 31) // 32)
- rows = 32 * ((rows + 31) // 32)
- return init_func((rows, cols), dtype=dtype, device=device), state
- else:
- raise NotImplementedError(f"To_order not supported: {to_order}")
-
-
-@deprecated(
- "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
- category=FutureWarning,
-)
-def nvidia_transform(
- A,
- to_order,
- from_order="row",
- out=None,
- transpose=False,
- state=None,
- ld=None,
-):
- if state is None:
- state = (A.shape, from_order)
- else:
- from_order = state[1]
- if out is None:
- out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
- else:
- new_state = (state[1], to_order)
- func = get_transform_func(A.dtype, from_order, to_order, transpose)
-
- shape = state[0]
- if len(shape) == 2:
- dim1 = ct.c_int32(shape[0])
- dim2 = ct.c_int32(shape[1])
- elif ld is not None:
- n = prod(shape)
- dim1 = prod([shape[i] for i in ld])
- dim2 = ct.c_int32(n // dim1)
- dim1 = ct.c_int32(dim1)
- else:
- dim1 = ct.c_int32(shape[0] * shape[1])
- dim2 = ct.c_int32(shape[2])
-
- ptr = CUBLAS_Context.get_instance().get_context(A.device)
- func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)
-
- return out, new_state
-
-
def estimate_quantiles(
A: Tensor,
out: Optional[torch.Tensor] = None,
@@ -892,56 +774,16 @@ def quantize_blockwise(
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
- if absmax is None:
- n = A.numel()
- blocks = -(n // -blocksize)
- absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
-
- if out is None:
- out = torch.zeros_like(A, dtype=torch.uint8)
-
- if A.device.type != "cpu":
- assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
-
- code = code.to(A.device)
-
- is_on_gpu([A, out, absmax])
-
- with _cuda_device_of(A):
- args = (
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_int32(blocksize),
- ct.c_int(A.numel()),
- )
-
- if A.dtype == torch.float16:
- lib.cquantize_blockwise_fp16(*args)
- elif A.dtype == torch.bfloat16:
- lib.cquantize_blockwise_bf16(*args)
- elif A.dtype == torch.float32:
- lib.cquantize_blockwise_fp32(*args)
- else:
- raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
-
- else:
- # cpu
- code = code.cpu()
- lib.cquantize_blockwise_cpu_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_longlong(blocksize),
- ct.c_longlong(A.numel()),
- )
+ _out, _absmax = torch.ops.bitsandbytes.quantize_blockwise.default(
+ A,
+ code.to(A.device),
+ blocksize,
+ )
if nested:
- offset = absmax.mean()
- absmax -= offset
- qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
+ offset = _absmax.mean()
+ _absmax -= offset
+ qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False)
quant_state = QuantState(
absmax=qabsmax,
code=code,
@@ -951,7 +793,14 @@ def quantize_blockwise(
state2=state2,
)
else:
- quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype)
+ quant_state = QuantState(absmax=_absmax, code=code, blocksize=blocksize, dtype=A.dtype)
+
+ # TODO(matthewdouglas): Deprecate out kwarg
+ out = out.copy_(_out) if out is not None else _out
+
+ # TODO(matthewdouglas): Deprecate absmax kwarg
+ if absmax is not None:
+ quant_state.absmax = absmax.copy_(quant_state.absmax)
return out, quant_state
@@ -1013,49 +862,24 @@ def dequantize_blockwise(
if absmax.dtype != torch.float32:
absmax = absmax.float()
- if out is None:
- out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)
-
- if A.device.type != "cpu":
- code = quant_state.code.to(A.device)
- if quant_state.blocksize not in [4096, 2048, 1024, 512, 256, 128, 64]:
- raise ValueError(
- f"The blocksize of {quant_state.blocksize} is not supported. Supported values: [4096, 2048, 1024, 512, 256, 128, 64]",
- )
-
- is_on_gpu([A, absmax, out])
-
- with _cuda_device_of(A):
- args = (
- get_ptr(quant_state.code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_int(quant_state.blocksize),
- ct.c_int(A.numel()),
- _get_tensor_stream(A),
- )
-
- if out.dtype == torch.float16:
- lib.cdequantize_blockwise_fp16(*args)
- elif out.dtype == torch.bfloat16:
- lib.cdequantize_blockwise_bf16(*args)
- elif out.dtype == torch.float32:
- lib.cdequantize_blockwise_fp32(*args)
- else:
- raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
- else:
- code = quant_state.code.cpu()
- lib.cdequantize_blockwise_cpu_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_longlong(quant_state.blocksize),
- ct.c_longlong(A.numel()),
+ if out is not None:
+ torch.ops.bitsandbytes.dequantize_blockwise.out(
+ A,
+ absmax,
+ quant_state.code.to(A.device),
+ quant_state.blocksize,
+ quant_state.dtype,
+ out=out,
)
+ return out
- return out
+ return torch.ops.bitsandbytes.dequantize_blockwise.default(
+ A,
+ absmax,
+ quant_state.code.to(A.device),
+ quant_state.blocksize,
+ quant_state.dtype,
+ )
def get_4bit_type(typename, device=None, blocksize=64):
@@ -1194,62 +1018,21 @@ def quantize_4bit(
- `torch.Tensor`: The quantized tensor with packed 4-bit values.
- [`QuantState`]: The state object used to undo the quantization.
"""
-
- if A.device.type != "cuda":
- raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
- if quant_type not in ["fp4", "nf4"]:
- raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
-
- n = A.numel()
input_shape = A.shape
- if absmax is None:
- blocks = -(n // -blocksize)
- absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
-
- if out is None:
- mod = dtype2bytes[quant_storage] * 2
- out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device)
-
- assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
-
- is_on_gpu([A, out, absmax])
-
- with _cuda_device_of(A):
- args = (
- None,
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_int32(blocksize),
- ct.c_int(n),
- )
-
- if A.dtype == torch.bfloat16:
- if quant_type == "fp4":
- lib.cquantize_blockwise_bf16_fp4(*args)
- else:
- lib.cquantize_blockwise_bf16_nf4(*args)
- elif A.dtype == torch.float16:
- if quant_type == "fp4":
- lib.cquantize_blockwise_fp16_fp4(*args)
- else:
- lib.cquantize_blockwise_fp16_nf4(*args)
- elif A.dtype == torch.float32:
- if quant_type == "fp4":
- lib.cquantize_blockwise_fp32_fp4(*args)
- else:
- lib.cquantize_blockwise_fp32_nf4(*args)
- else:
- raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+ _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default(
+ A,
+ blocksize,
+ quant_type,
+ quant_storage,
+ )
code = get_4bit_type(quant_type, device=A.device)
if compress_statistics:
- offset = absmax.mean()
- absmax -= offset
- qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
- del absmax
+ offset = _absmax.mean()
+ qabsmax, state2 = quantize_blockwise(_absmax - offset, blocksize=256)
+ del _absmax
state = QuantState(
absmax=qabsmax,
shape=input_shape,
@@ -1262,7 +1045,7 @@ def quantize_4bit(
)
else:
state = QuantState(
- absmax=absmax,
+ absmax=_absmax,
shape=input_shape,
dtype=A.dtype,
blocksize=blocksize,
@@ -1270,6 +1053,13 @@ def quantize_4bit(
quant_type=quant_type,
)
+ # TODO(matthewdouglas): Deprecate out kwarg
+ out = out.copy_(_out) if out is not None else _out
+
+ # TODO(matthewdouglas): Deprecate absmax kwarg
+ if absmax is not None:
+ state.absmax = absmax.copy_(state.absmax)
+
return out, state
@@ -1327,14 +1117,6 @@ def dequantize_4bit(
Returns:
`torch.Tensor`: The dequantized tensor.
"""
-
- if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
- raise ValueError(
- f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
- )
- if quant_type not in ["fp4", "nf4"]:
- raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
-
if quant_state is None:
assert absmax is not None and out is not None
@@ -1355,42 +1137,19 @@ def dequantize_4bit(
if absmax.dtype != torch.float32:
absmax = absmax.float()
- if out is None:
- out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)
-
- n = out.numel()
-
- is_on_gpu([A, absmax, out])
- stream = _get_tensor_stream(A)
-
- with _cuda_device_of(A):
- args = (
- None,
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_int(quant_state.blocksize),
- ct.c_int(n),
- stream,
+ if out is not None:
+ torch.ops.bitsandbytes.dequantize_4bit.out(
+ A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out
+ )
+ else:
+ out = torch.ops.bitsandbytes.dequantize_4bit.default(
+ A,
+ absmax,
+ quant_state.blocksize,
+ quant_state.quant_type,
+ quant_state.shape,
+ quant_state.dtype,
)
-
- if out.dtype == torch.bfloat16:
- if quant_state.quant_type == "fp4":
- lib.cdequantize_blockwise_bf16_fp4(*args)
- else:
- lib.cdequantize_blockwise_bf16_nf4(*args)
- elif out.dtype == torch.float16:
- if quant_state.quant_type == "fp4":
- lib.cdequantize_blockwise_fp16_fp4(*args)
- else:
- lib.cdequantize_blockwise_fp16_nf4(*args)
- elif out.dtype == torch.float32:
- if quant_state.quant_type == "fp4":
- lib.cdequantize_blockwise_fp32_fp4(*args)
- else:
- lib.cdequantize_blockwise_fp32_nf4(*args)
- else:
- raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
if A.shape[0] == 1: # is transposed, transpose back
return out.t()
@@ -1849,6 +1608,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
return current_gnorm, clip_value, gnorm_scale
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
assert len(histogram.shape) == 2
assert histogram.dtype == torch.float32
@@ -1959,100 +1719,33 @@ def gemv_4bit(
transposed_B=False,
state=None,
):
- # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
if state is None:
raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
- if A.numel() != A.shape[-1]:
- raise ValueError(
- 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. [1, 1, 2048]',
- )
-
- Bshape = state.shape
- bout = Bshape[0]
absmax = state.absmax
if state.nested:
- absmax = dequantize_blockwise(state.absmax, state.state2)
- absmax += state.offset
-
- if out is None:
- if len(A.shape) == 3:
- out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
- else:
- out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
-
- n = 1
- m = Bshape[0]
- k = Bshape[1]
- lda = Bshape[0]
- ldc = Bshape[0]
- ldb = (A.shape[-1] + 1) // 2
- is_on_gpu([B, A, out, absmax, state.code])
- m = ct.c_int32(m)
- n = ct.c_int32(n)
- k = ct.c_int32(k)
- lda = ct.c_int32(lda)
- ldb = ct.c_int32(ldb)
- ldc = ct.c_int32(ldc)
- stream = _get_tensor_stream(A)
-
- with _cuda_device_of(A):
- if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
- if A.dtype == torch.float16:
- lib.cgemm_4bit_inference_naive_fp16(
- m,
- n,
- k,
- get_ptr(A),
- get_ptr(B),
- get_ptr(absmax),
- get_ptr(state.code),
- get_ptr(out),
- lda,
- ldb,
- ldc,
- ct.c_int32(state.blocksize),
- stream,
- )
- elif A.dtype == torch.bfloat16:
- lib.cgemm_4bit_inference_naive_bf16(
- m,
- n,
- k,
- get_ptr(A),
- get_ptr(B),
- get_ptr(absmax),
- get_ptr(state.code),
- get_ptr(out),
- lda,
- ldb,
- ldc,
- ct.c_int32(state.blocksize),
- stream,
- )
- elif A.dtype == torch.float32:
- lib.cgemm_4bit_inference_naive_fp32(
- m,
- n,
- k,
- get_ptr(A),
- get_ptr(B),
- get_ptr(absmax),
- get_ptr(state.code),
- get_ptr(out),
- lda,
- ldb,
- ldc,
- ct.c_int32(state.blocksize),
- stream,
- )
- else:
- raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
+ absmax = dequantize_blockwise(absmax, state.state2) + state.offset
- else:
- raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
+ if out is not None:
+ torch.ops.bitsandbytes.gemv_4bit.out(
+ A,
+ B,
+ state.shape,
+ absmax,
+ state.code,
+ state.blocksize,
+ out=out,
+ )
+ return out
- return out
+ return torch.ops.bitsandbytes.gemv_4bit.default(
+ A,
+ B,
+ state.shape,
+ absmax,
+ state.code,
+ state.blocksize,
+ )
def igemm(
@@ -2252,27 +1945,6 @@ def batched_igemm(
return out
-@deprecated(
- "igemmlt is deprecated and will be removed in a future release. Please use int8_linear_matmul instead.",
- category=FutureWarning,
-)
-def igemmlt(
- A: torch.Tensor,
- B: torch.Tensor,
- SA: Tuple[torch.Size, str],
- SB: Tuple[torch.Size, str],
- out: Optional[torch.Tensor] = None,
- Sout: Optional[Tuple[torch.Size, str]] = None,
- dtype=torch.int32,
-):
- if SA is not None and SA[1] != "row":
- raise NotImplementedError(f"Only row-major format inputs are supported, but got format `{SA[1]}`")
- if SB is not None and SB[1] != "row":
- raise NotImplementedError(f"Only row-major format is supported for matrix B, but got format `{SB[1]}`")
- result = int8_linear_matmul(A, B, out=out, dtype=dtype)
- return result, (result.shape, "row")
-
-
def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
"""Performs an 8-bit integer matrix multiplication.
@@ -2292,88 +1964,11 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten
Returns:
`torch.Tensor`: The result of the operation.
"""
+ if out is not None:
+ torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)
+ return out
- #
- # To use the IMMA tensor core kernels without special Turing/Ampere layouts,
- # cublasLt has some rules, namely: A must be transposed, B must not be transposed.
- # The C++ API will calculate `C = A.T @ B` in with A, B, C in col-major.
- # This will typically be used with row-major tensors to efficiently
- # calculate the linear layer with `C = B @ A.T` without any transformations.
- # We will swap A and B in the API invocation, so that we get `C = A @ B.T`.
- #
- # Quick explanation:
- # With row-major A and B tensors, `C = A.T.T @ B.T = A @ B.T`.
- # To get row-major output, `C.T = (A @ B.T).T = B @ A.T`.
- #
- A, B = B, A
-
- shapeA = A.shape
- shapeB = B.shape
-
- assert A.dtype == torch.int8
- assert B.dtype == torch.int8
- assert A.ndim == 2, "Only two dimensional matrices are supported for argument B"
- assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A"
- assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}"
- assert out is None or out.dtype == dtype
-
- shapeC = (*shapeB[:-1], shapeA[0])
-
- k, m = shapeA
- n = prod(shapeB[:-1])
- lda = shapeA[-1] # Weights (outputs, inputs)
- ldb = shapeB[-1] # Activations (batch, tokens, inputs)
- ldc = shapeC[-1] # Output (batch, tokens, outputs)
-
- assert (
- lda == ldb
- ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
-
- # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
- # We'll fall back to a slower fp32 calculation in this circumstance.
- # Fortunately, this should not be very common.
- if lda % 4 != 0:
- result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
- if out is not None:
- result = out.copy_(result)
- return result
-
- if out is None:
- out = torch.empty(shapeC, device=A.device, dtype=dtype)
-
- is_on_gpu([A, B, out])
-
- with _cuda_device_of(A):
- ctx = CUBLAS_Context.get_instance().get_context(A.device)
- ptrA = get_ptr(A)
- ptrB = get_ptr(B)
- ptrC = get_ptr(out)
- ptrRowScale = None
- m = ct.c_int32(m)
- n = ct.c_int32(n)
- k = ct.c_int32(k)
- lda = ct.c_int32(lda)
- ldb = ct.c_int32(ldb)
- ldc = ct.c_int32(ldc)
- stream = _get_tensor_stream(A)
-
- if dtype == torch.int32:
- has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
- else:
- has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
-
- if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
- raise NotImplementedError("int8_linear_matmul not implemented!")
-
- if has_error:
- raise RuntimeError(
- f"cublasLt ran into an error!\n"
- f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
- f"\t{(lda, ldb, ldc)=}\n"
- f"\t{(m, n, k)=}"
- )
-
- return out
+ return torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
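# Illustrative usage sketch (editor's addition, not part of the upstream change). int8_linear_matmul
# now forwards to the bitsandbytes::int8_linear_matmul op; it computes C = A @ B.T with int8 inputs
# and an int32 accumulator, so the output shape is (*A.shape[:-1], B.shape[0]). Assumes a CUDA build.
import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 128, (16, 64), dtype=torch.int8, device="cuda")  # activations (tokens, in)
B = torch.randint(-128, 128, (32, 64), dtype=torch.int8, device="cuda")  # weight (out, in)
C = F.int8_linear_matmul(A, B)  # -> shape (16, 32), dtype torch.int32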
def int8_mm_dequant(
@@ -2395,47 +1990,16 @@ def int8_mm_dequant(
Returns:
`torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`.
"""
+ result = torch.ops.bitsandbytes.int8_mm_dequant.default(A, row_stats, col_stats, dtype=torch.float16, bias=bias)
- assert A.dtype == torch.int32
-
- if bias is not None:
- assert bias.dtype == torch.float16
-
- if out is None:
- out = torch.empty_like(A, dtype=torch.float16)
-
- ptrA = get_ptr(A)
- ptrOut = get_ptr(out)
- ptrRowStats = get_ptr(row_stats)
- ptrColStats = get_ptr(col_stats)
- ptrBias = get_ptr(bias)
- numRows = ct.c_int32(prod(A.shape[:-1]))
- numCols = ct.c_int32(A.shape[-1])
-
- is_on_gpu([A, row_stats, col_stats, out, bias])
-
- with _cuda_device_of(A):
- lib.cdequant_mm_int32_fp16(
- ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
- )
+ # TODO(matthewdouglas): Deprecate out kwarg
+ if out is not None:
+ return out.copy_(result)
- return out
-
-
-@deprecated("mm_dequant is deprecated. Please use int8_mm_dequant() instead.", category=FutureWarning)
-def mm_dequant(
- A: torch.Tensor,
- quant_state: Optional[Tuple[torch.Size, str]], # Not used
- row_stats: torch.Tensor,
- col_stats: torch.Tensor,
- out: Optional[torch.Tensor] = None,
- new_row_stats=None, # Not used
- new_col_stats=None, # Not used
- bias: Optional[torch.Tensor] = None,
-):
- return int8_mm_dequant(A, row_stats, col_stats, out, bias)
+ return result
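# Illustrative sketch (editor's addition, not part of the upstream change). These int8 wrappers
# compose into the full LLM.int8() matmul path: quantize -> int8 matmul -> dequantize. CUDA build assumed.
import torch
import bitsandbytes.functional as F

x = torch.randn(16, 64, dtype=torch.float16, device="cuda")
w = torch.randn(32, 64, dtype=torch.float16, device="cuda")

x_i8, x_stats, _ = F.int8_vectorwise_quant(x)  # row-wise scales for the activations
w_i8, w_stats, _ = F.int8_vectorwise_quant(w)  # row-wise scales for the weight rows (output features)

acc_i32 = F.int8_linear_matmul(x_i8, w_i8)
y = F.int8_mm_dequant(acc_i32, x_stats, w_stats)  # approximately x @ w.t(), in float16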
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_colrow_absmax(
A: torch.Tensor,
row_stats: Optional[torch.Tensor] = None,
@@ -2493,6 +2057,7 @@ def get_colrow_absmax(
return row_stats, col_stats, outlier_mask
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_row_absmax(A: torch.Tensor, threshold=0.0):
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
@@ -2611,72 +2176,6 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
-@deprecated("This function is deprecated. Please use `int8_double_quant` instead.", category=FutureWarning)
-def double_quant(
- A: torch.Tensor,
- col_stats: Optional[torch.Tensor] = None,
- row_stats: Optional[torch.Tensor] = None,
- out_col: Optional[torch.Tensor] = None,
- out_row: Optional[torch.Tensor] = None,
- threshold=0.0,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[COOSparseTensor]]:
- """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
-
- The statistics are determined both row-wise and column-wise (transposed).
-
- For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
-
-
- This function exists for backwards compatibility only. It is advised to use [`int8_double_quant`] instead.
- The difference is that this function will return a [`COOSparseTensor`] for outliers instead of a column index.
-
-
- Args:
- A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
- col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales.
- row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales.
- out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data.
- out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data.
- threshold (`float`, *optional*):
- An optional threshold for sparse decomposition of outlier features.
-
- No outliers are held back when 0.0. Defaults to 0.0.
-
- Returns:
- `Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
- - `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data.
- - `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data.
- - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales.
- - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales.
- - `COOSparseTensor`, *optional*: A structure representing the outlier values from the input tensor.
- """
-
- coo_tensor = None
- quant_row, quant_col, row_stats, col_stats, outlier_cols = int8_double_quant(
- A,
- col_stats,
- row_stats,
- out_col,
- out_row,
- threshold=threshold,
- )
-
- if threshold > 0.0 and outlier_cols is not None:
- # Build a COO tensor including all of the outlier columns.
- outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32)
- outliers = A[:, outlier_cols]
- coo_tensor = COOSparseTensor(
- A.shape[0],
- A.shape[1],
- outliers.numel(),
- outlier_rows.repeat_interleave(outliers.size(1)),
- outlier_cols.repeat(outliers.size(0)).int(),
- outliers,
- )
-
- return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor
-
-
def int8_double_quant(
A: torch.Tensor,
col_stats: Optional[torch.Tensor] = None,
@@ -2716,23 +2215,16 @@ def int8_double_quant(
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
"""
- # TODO: Optimize/write CUDA kernel for this?
-
- # Use CUDA kernel for rowwise and COO tensor
- quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold)
-
- # PyTorch impl for colwise
- _, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
- if threshold > 0.0 and outlier_mask is not None:
- A = A.masked_fill(outlier_mask, 0.0)
- quant_col = torch.round(A.mul(C) / col_stats.unsqueeze(0)).to(torch.int8)
-
- if out_row is not None:
- quant_row = out_row.copy_(quant_row)
+ if row_stats is not None:
+ raise ValueError("row_stats must be None. int8_double_quant() does not support pre-allocated row_stats.")
+ if col_stats is not None:
+ raise ValueError("col_stats must be None. int8_double_quant() does not support pre-allocated col_stats.")
if out_col is not None:
- quant_col = out_col.copy_(quant_col)
+ raise ValueError("out_col must be None. int8_double_quant() does not support pre-allocated out_col.")
+ if out_row is not None:
+ raise ValueError("out_row must be None. int8_double_quant() does not support pre-allocated out_row.")
- return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
+ return torch.ops.bitsandbytes.int8_double_quant.default(A, threshold=threshold)
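# Illustrative sketch (editor's addition, not part of the upstream change). With the pre-allocated
# output arguments now rejected, int8_double_quant is a plain function of (A, threshold) returning a
# 5-tuple; a nonzero threshold additionally reports outlier feature columns. CUDA build assumed.
import torch
import bitsandbytes.functional as F

A = torch.randn(16, 64, dtype=torch.float16, device="cuda")
q_row, q_col, row_stats, col_stats, outlier_cols = F.int8_double_quant(A, threshold=6.0)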
def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor):
@@ -2746,7 +2238,7 @@ def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor):
`torch.Tensor` with dtype `torch.float32`: The dequantized tensor.
"""
# To dequantize we divide by 127, or multiply by the reciprocal.
- return A * stats.view(-1, 1) * 7.874015718698502e-3
+ return torch.ops.bitsandbytes.int8_vectorwise_dequant.default(A, stats)
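# Illustrative sketch (editor's addition, not part of the upstream change). The op performs the same
# arithmetic as the removed line: scale each row by its statistic and by the reciprocal of 127.
import torch
import bitsandbytes.functional as F

A = torch.tensor([[127, -127]], dtype=torch.int8, device="cuda")
stats = torch.tensor([2.0], dtype=torch.float32, device="cuda")
deq = F.int8_vectorwise_dequant(A, stats)  # ~[[2., -2.]], i.e. A * stats.view(-1, 1) / 127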
def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
@@ -2767,94 +2259,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
- `torch.Tensor` with dtype `torch.float32`: The quantization scales.
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
"""
-
- assert A.dtype == torch.half
- is_on_gpu([A])
-
- rows = prod(A.shape[:-1])
- cols = A.shape[-1]
-
- row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
- out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
-
- outlier_cols = None
-
- if threshold > 0.0:
- # TODO we could improve perf of this
- outliers = A.abs() >= threshold
-
- if outliers.any():
- outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
-
- with _cuda_device_of(A):
- lib.cint8_vector_quant(
- get_ptr(A),
- get_ptr(out_row),
- get_ptr(row_stats),
- ct.c_float(threshold),
- ct.c_int32(rows),
- ct.c_int32(cols),
- _get_tensor_stream(A),
- )
-
- # Zero out values from outlier columns across all rows.
- # The kernel will handle this for outliers themselves, so we can optimize for rows=1.
- if rows > 1 and outlier_cols is not None:
- out_row[:, outlier_cols] = 0
-
- return out_row, row_stats, outlier_cols
-
-
-@deprecated(
- "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
- category=FutureWarning,
-)
-def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
- prev_device = pre_call(A.device)
- if state is None:
- state = (A.shape, from_order)
- else:
- from_order = state[1]
- if out is None:
- out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
- else:
- new_state = (state[0], to_order) # (shape, order)
-
- shape = state[0]
- if len(shape) == 2:
- dim1 = ct.c_int32(shape[0])
- dim2 = ct.c_int32(shape[1])
- else:
- dim1 = ct.c_int32(shape[0] * shape[1])
- dim2 = ct.c_int32(shape[2])
-
- is_on_gpu([A, out])
- if to_order == "col32":
- if transpose:
- lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
- else:
- lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
- elif to_order == "col_turing":
- if transpose:
- lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
- else:
- lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
- elif to_order == "col_ampere":
- if transpose:
- lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
- else:
- lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
- elif to_order == "row":
- if from_order == "col_turing":
- lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
- elif from_order == "col_ampere":
- lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
- else:
- raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}")
-
- post_call(prev_device)
-
- return out, new_state
+ return torch.ops.bitsandbytes.int8_vectorwise_quant.default(A, threshold)
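# Illustrative sketch (editor's addition, not part of the upstream change). int8_vectorwise_quant now
# calls the bitsandbytes::int8_vectorwise_quant op directly; with a threshold, columns holding outlier
# features are reported (and zeroed in the int8 output) for the sparse decomposition. CUDA build assumed.
import torch
import bitsandbytes.functional as F

A = torch.randn(16, 64, dtype=torch.float16, device="cuda")
A[:, 3] = 12.0  # inject an outlier feature column
A_i8, row_stats, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)  # outlier_cols holds column 3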
def spmm_coo(
@@ -3059,7 +2464,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
@deprecated(
- "This function is deprecated and will be removed in a future release. Consider using `int8_vectorwise_dequant` instead.",
+ "This function is deprecated and will be removed in a future release.",
category=FutureWarning,
)
def vectorwise_dequant(xq, max1, quant_type="vector"):
@@ -3071,7 +2476,7 @@ def vectorwise_dequant(xq, max1, quant_type="vector"):
@deprecated(
- "This function is deprecated and will be removed in a future release. Consider using `int8_mm_dequant` instead.",
+ "This function is deprecated and will be removed in a future release.",
category=FutureWarning,
)
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
@@ -3131,51 +2536,3 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
return x.to(dtype)
else:
return None
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
- offset = B.float().t().sum(0) * (SA[0] + SA[1])
- x = xq.float()
- if len(xq.shape) == 2 and len(SB.shape) == 3:
- SB = SB.squeeze(0)
- if len(SB.shape) == 2:
- x *= SB.t() / 127
- else:
- x *= SB / 127
- x *= SA[1] / 127
- x += offset
- return x.to(dtype)
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def extract_outliers(A, SA, idx):
- shapeA = SA[0]
- formatA = SA[1]
- assert formatA in ["col_turing", "col_ampere"]
- assert A.device.type == "cuda"
-
- out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
-
- idx_size = ct.c_int32(idx.numel())
- rows = ct.c_int32(shapeA[0])
- cols = ct.c_int32(shapeA[1])
- ptrA = get_ptr(A)
- ptrIdx = get_ptr(idx)
- ptrOut = get_ptr(out)
-
- prev_device = pre_call(A.device)
- if formatA == "col_turing":
- lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
- elif formatA == "col_ampere":
- lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
- post_call(prev_device)
-
- return out
-
-
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def pipeline_test(A, batch_size):
- out = torch.zeros_like(A)
- lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
- return out
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index e63cd8db9..f4d838d48 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -11,7 +11,6 @@
import torch.nn.functional as F
import bitsandbytes as bnb
-from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import (
@@ -654,8 +653,7 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]
if weight_format != "row":
- tile_indices = get_tile_inds(weight_format, weight.device)
- state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
+ raise ValueError(f"Only 'row' weight format is supported, got {weight_format}")
class Embedding8bit(nn.Embedding):
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index fdf1d02c0..22ee756d9 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2205,333 +2205,6 @@ __global__ void kdequant_mm_int32_fp16(
}
}
-template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
-{
-
- // 0. Load data into 32*32 shared memory tiles
- // 1. transpose / reorder in shared memory
- // 2. store
-
- // COL32 FORMAT:
- // rows*32 tiles
-
- // TURING FORMAT:
- // 8*32 tiles with 4*4 subtiles
- // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
- // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
- // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
- // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
- // index increases by 32
-
- // AMPERE FORMAT:
- // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
- // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
- // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
-
-
- // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values
- // As such we need:
- // at least 32*4 shared memory tiles for col32; preferably 32*32
- // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
- // at least 32*8 shared memory tiles for col4_turing: preferably 32*32
- // for efficient loading of row major we need to load 128 elements and repeat this 32 items
- // this would imply a 32x128 shared memory tile -> 4kb
- // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
- // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
- // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
- // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM
- //
- // to make the shared memory work with that occupancy we might need to union the block loads/stores
-
- // each block loads TILE_COLs columns and TILE_ROW rows
- // after reading a tile the row counter increase by TILE_ROWS
- // the col counter reset after reading TILE_COL elements
- const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
- // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
- const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
- const int base_idx = (base_row*cols) + base_col;
-
- // we load 128 bytes per warp with
- // 32 rows for transposes that fill col32 types
- // so that we can have contiguous stores
- __shared__ char smem_data[32*33*ITEMS_PER_THREAD];
- char local_data[ITEMS_PER_THREAD];
- typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
-
- // we load row after row from the base_position
- // Load data row by row
- int warps = blockDim.x/32;
- int warp_id = threadIdx.x/32;
- int warp_lane = threadIdx.x % 32;
- int offset = 0;
-
- int smem_row = 0;
- // each warp loads one row of 128 bytes
- for(int row = warp_id; row < TILE_ROWS; row+=warps)
- {
- int i = base_idx + (row*cols);
- // we load up to 128 bytes/items per load
- int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col;
-
- // 0. Load data into 32*32 shared memory tiles
- if(base_row + row < rows)
- {
- #pragma unroll ITEMS_PER_THREAD
- for(int j = 0; j < ITEMS_PER_THREAD; j++)
- {
- int col_idx = warp_lane+(j*32);
- if(col_idx < valid_items)
- local_data[j] = A[i+col_idx];
- else
- local_data[j] = 0;
- }
- }
- else
- {
- #pragma unroll ITEMS_PER_THREAD
- for(int j = 0; j < ITEMS_PER_THREAD; j++)
- local_data[j] = 0;
- }
-
- if(TRANSPOSE)
- {
- #pragma unroll ITEMS_PER_THREAD
- for(int j = 0; j < ITEMS_PER_THREAD; j++)
- {
- int local_col = (32*j)+warp_lane;
- //int local_row = row;
- // store as 256x32
- smem_data[(local_col*33) + row] = local_data[j];
- }
- }
- else
- {
- // treat smem as 32x256, that is 32 rows and 256 columns
- #pragma unroll ITEMS_PER_THREAD
- for(int j = 0; j < ITEMS_PER_THREAD; j++)
- smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j];
- }
-
-
-
- smem_row += warps;
-
- // 1. transpose / reorder in shared memory
- if(smem_row % 32 == 0)
- {
- smem_row = 0;
- __syncthreads();
-
- for(int subrow = warp_id; subrow < 32; subrow+=warps)
- {
- for(int j = 0; j < ITEMS_PER_THREAD; j++)
- {
-
- switch(FORMAT)
- {
- case COL32:
- if(TRANSPOSE)
- {
- // data lies in shared memory in the following way:
- // row0 [col0 col1 ... col31]
- // row1 [col0 col1 ... col31]
- // ...
- //
- // As such we read consecutive entries with 256 threads (8rows x 32 columns)
- // as j increase, the row increase by a factor of 8
- // We load 8 rows per subrow loop, and subrow increase by 8 per loop
- // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8
- const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
- const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
- //const int local_row = warp_id; // each warp_id is one row
- //const int block_row = base_col; // block offset for row
- //const int local_col = warp_lane
- //const int global_col = base_row; // block offset for col
- if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
- {
- // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
- char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
-
- // each 32 columns we have new tile
- // each tile has size outRows*32 and base_row is done in increments of 32
- offset = base_row*outRows;
- out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
- }
- }
- else
- {
- if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
- {
- offset = (base_col/32)*(32*rows);
- char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
- out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data;
- }
- }
- break;
- case COL_TURING:
- // TURING FORMAT:
- // 8*32 tiles with 4*4 subtiles
- // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
- // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
- // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
- // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
- // index increases by 32
- //
- // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
- if(TRANSPOSE)
- {
- const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
- const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
- //const int local_row = warp_id; // each warp_id is one row
- //const int block_row = base_col; // block offset for row
- //const int local_col = warp_lane
- //const int global_col = base_row; // block offset for col
- if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
- {
- // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
- char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
-
- // each 32 columns we have new tile
- // each tile has size 8*32 = 256 elements offset
- // for each row offset of 8 we increaes the tile first
- // after all rows are exhausted, we increase the col
- int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
-
- // we increase by row_tile_column every 32 columns
- // base_row increase in increments of 32
- //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
- //int col_offset = (base_row/32)*row_tile_column;
- // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
- // 256*outRows/8*base_row/32 = outRows*base_row
- int col_offset = outRows*base_row;
-
- offset = row_offset+col_offset;
-
- // since we process even number of rows with each j (8) and with each subrow (8j) we can determine
- // odd or even rows with the warp_id (each warp processes one row)
- // the col is warp_lane (max 32 columns per row) and the row warp_id
- if(warp_id % 2 == 1)
- // odd
- offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2);
- else
- // even
- offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2);
-
- out[offset] = data;
- }
- }
- else
- {
- if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
- {
- char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
- // set offset designates the tile offset among the 8*32 tiles
- // we first increase rows and then columns. Since we load 128 columns at once
- // we increase the offset by outRows*32 every 32 columns
- // additionally, we increase the offset by 8*32=256 every 8 rows
- offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile)
- // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd
- // each of these has 32 values in total for 32*4 = 128 as offset if odd
- // every set of 4 columns increases the total offset by 16
- // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2
- // this happens every 8 rows anew (subrow % 8)
- // one writes 4 columns at once that is (col % 4) for the particular index in the subtile
- int subcol = warp_lane;
-
- // add local offset (4x4 sub-tile)
- if(subrow % 2 == 1)
- // odd
- offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2);
- else
- // even
- offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2);
-
- out[offset] = data;
- }
- }
- break;
- case COL_AMPERE:
- // AMPERE FORMAT:
- // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
- // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
- // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
- if(TRANSPOSE)
- {
- const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
- const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
- //const int local_row = warp_id; // each warp_id is one row
- //const int block_row = base_col; // block offset for row
- //const int local_col = warp_lane
- //const int global_col = base_row; // block offset for col
- if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
- {
- // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
- char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
-
- // each 32 columns we have new tile
- // each tile has size 32*32 = 1024 elements offset
- // for each row offset of 32 we increaes the tile first
- // after all rows are exhausted, we increase the col
- int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
-
- // we increase by row_tile_column every 32 columns
- // base_row increase in increments of 32
- //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
- //int col_offset = (base_row/32)*row_tile_column;
- // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
- // 1024*outRows/32*base_row/32 = outRows*base_row
- int col_offset = outRows*base_row;
-
- offset = row_offset+col_offset;
-
-
- // same as in the non-transpose case (see below)
- // the difference is that now rows = cols
- // in this case warp_id = subrow
-
- // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
- // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
- // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
- // every 2 rows, the offset increases by two [0, 1, 8, 9...]
- // every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
- int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset
- int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);
-
- // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
- out[offset + (ampere_row*32) + warp_lane] = data;
- }
- }
- else
- {
- if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
- {
- char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
-
- // set offset designates the tile offset among the 32*32 tiles
- // we first increase rows and then columns. Since we load 128 columns at once
- // we increase the offset by outRows*32 every 32 columns
- // additionally, we increase the offset by 32*32=1024 every 32 rows
- offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile)
-
- // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
- // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
- // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
- // every 2 rows, the offset increases by two [0, 1, 8, 9...]
- // every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
- int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);
-
- // global offset + row with 32 cols each + 32 cols per j + col_idx
- out[offset + (local_row*32) + warp_lane] = data;
- }
- }
- break;
- }
- }
- }
- }
- }
-}
-
#define DENORM 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
@@ -2679,69 +2352,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}
}
-template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
-{
- int local_colidx = idx[blockIdx.x];
-
- if(FORMAT==COL_TURING)
- {
- // TURING FORMAT:
- // 8*32 tiles with 4*4 subtiles
- // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements)
- // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
- // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
- // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
- // index increases by 32
- // columns are grouped in increments of 4, meaning that one has the following rows and columns
- // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
- // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...]
-
- // each thread reads 1 element = 1 row
- for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
- {
- int offset_per_col_tile = ((rowsA+7)/8)*32*8;
- int tile_offset_rows = (row/8)*32*8;
- int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
- int offset = 0;
- int subtile_col_idx = local_colidx%32;
- int subtile_row_idx = row % 8;
- if(row % 2 == 1)
- offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2);
- else
- // even
- offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2);
-
- offset += tile_offset_rows + tile_offset_cols;
-
- char val = A[offset];
-
- int out_idx = (row*idx_size) + blockIdx.x;
- out[out_idx] = val;
- }
- }
- else if(FORMAT == COL_AMPERE)
- {
-
- for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
- {
- // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
- // within each tile.
- int offset_per_col_tile = ((rowsA+31)/32)*32*32;
- int tile_offset_rows = (row/32)*32*32;
- int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
- int subtile_col_idx = local_colidx%32;
- int subtile_row_idx = row % 32;
- // this magic is taken from the cublasLt doc (search for COL32)
- int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
- offset += tile_offset_cols + tile_offset_rows;
-
- char val = A[offset];
- int out_idx = (row*idx_size) + blockIdx.x;
- out[out_idx] = val;
- }
- }
-}
-
#define WARPS 3
template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{
@@ -3376,9 +2986,6 @@ template __global__ void kgemm_4bit_inference_naive(int M, int N,
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<float, 128, 16>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
-template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
-template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
-
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
@@ -3386,13 +2993,6 @@ template __global__ void kspmm_coo_very_sparse_naive(int *max
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-
template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 18017c4d2..a701481d3 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -121,8 +121,6 @@ template __global__ void kInt8Vector
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
-
template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index e6c2bb443..775984553 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -374,48 +374,6 @@ template int get_leading_dim(int dim1, int dim2)
}
}
-template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
-{
- cublasLtOrder_t orderA = get_order<SRC>();
- cublasLtOrder_t orderOut = get_order<TARGET>();
- int ldA = get_leading_dim<SRC>(dim1, dim2);
- int ldOut = get_leading_dim<TARGET>(dim1, dim2);
-
- cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
- cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
- cublasOperation_t opTranspose = CUBLAS_OP_T;
- float transformAlpha = 1.0f, transformBeta = 0.0f;
-
-
- if(DTYPE == 8)
- {
- checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
- checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
- }
- else if(DTYPE == 32)
- {
- checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
- checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
- }
- else
- {
- printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
- }
-
- checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
- checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));
-
- checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));
-
- if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }
-
- checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));
-
- if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
- if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
- if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
-}
-
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
cublasLtHandle_t ltHandle,
int m, int n, int k,
@@ -542,50 +500,6 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
-template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
-{
- int threads = 256;
- int items_per_thread = 8;
- // we load 128 column values per warp
- int tile_cols = 32*items_per_thread;
- int tile_rows = 32;
- int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
- int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
- int row_tiles = (tiledRows/tile_rows);
- int col_tiles = (tiledCols/tile_cols);
- row_tiles = row_tiles > 0 ? row_tiles : 1;
- col_tiles = col_tiles > 0 ? col_tiles : 1;
- int num_blocks = row_tiles * col_tiles;
-
- int outCols = fill_up_to_nearest_multiple(cols, 32);
- int outRows = fill_up_to_nearest_multiple(rows, 32);
- if(FORMAT == COL_TURING)
- {
- if(TRANSPOSE)
- outRows = fill_up_to_nearest_multiple(cols, 8);
- else
- outRows = fill_up_to_nearest_multiple(rows, 8);
- }
- else if(FORMAT == COL_AMPERE)
- {
- if(TRANSPOSE)
- outRows = fill_up_to_nearest_multiple(cols, 32);
- else
- outRows = fill_up_to_nearest_multiple(rows, 32);
- }
- else
- {
- if(TRANSPOSE)
- {
- outCols = fill_up_to_nearest_multiple(rows, 32);
- outRows = cols;
- }
- }
-
- kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
- CUDA_CHECK_RETURN(cudaPeekAtLastError());
-}
-
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
cusparseSpMatDescr_t descA;
@@ -643,32 +557,6 @@ template void spmm_coo_very_sparse_naive(int *max_count,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
-
-template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
-{
- int threads = 256;
- // we load 128 column values per warp
- int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32);
- int tiledRows = 0;
-
- int num_blocks = idx_size;
-
- if(FORMAT == COL_TURING)
- {
- tiledRows = fill_up_to_nearest_multiple(rows, 8);
- }
- else if(FORMAT == COL_AMPERE)
- {
- tiledRows = fill_up_to_nearest_multiple(rows, 32);
- }
-
- kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
- CUDA_CHECK_RETURN(cudaPeekAtLastError());
-}
-
-
-
-
template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{
@@ -722,8 +610,6 @@ template void gemm_4bit_inference_naive(int m, int n, int k, float *
//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
-template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
-template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
@@ -732,13 +618,6 @@ template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, cons
template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
-template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
-template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
-template void transformRowToFormat<COL_TURING, 0>(char * A, char *out, int rows, int cols);
-template void transformRowToFormat<COL_TURING, 1>(char * A, char *out, int rows, int cols);
-template void transformRowToFormat<COL_AMPERE, 0>(char * A, char *out, int rows, int cols);
-template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, int cols);
-
template void estimateQuantiles<half>(half *A, float *code, float offset, int n);
template void estimateQuantiles<float>(float *A, float *code, float offset, int n);
@@ -840,15 +719,6 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
template void percentileClipping<float>(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping<half>(half * g, float *gnorm_vec, int step, const int n);
-template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
-template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
-
template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<COL32>(int dim1, int dim2);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 1170237e1..48a6a3c74 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -173,20 +173,15 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream);
-template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream);
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream);
-template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
-
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
-template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
-
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index 0ced0394c..56bec82e8 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -149,32 +149,6 @@ void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv
void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); }
-
-#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
-void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
-{ \
- transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
-} \
-
-MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
-MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
-MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
-MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
-MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
-MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
-MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
-MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
-
-void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
-void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 1>(A, out, rows, cols); }
-void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
-void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
-void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
-void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
-
-void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
-void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
-
int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
@@ -317,22 +291,6 @@ extern "C"
int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
-
- #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
- void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
- { \
- transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \
- } \
-
- MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
- MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
- MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
- MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
- MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8)
- MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
- MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
- MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
-
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream)
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); }
void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) {
@@ -342,24 +300,6 @@ extern "C"
int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream);
}
- void ctransform_row2col32(char * A, char *out, int rows, int cols)
- { transform_row2col32(A, out, rows, cols); }
-
- void ctransform_row2col32T(char * A, char *out, int rows, int cols)
- { transform_row2col32T(A, out, rows, cols); }
-
- void ctransform_row2turing(char * A, char *out, int rows, int cols)
- { transform_row2turing(A, out, rows, cols); }
-
- void ctransform_row2turingT(char * A, char *out, int rows, int cols)
- { transform_row2turingT(A, out, rows, cols); }
-
- void ctransform_row2ampere(char * A, char *out, int rows, int cols)
- { transform_row2ampere(A, out, rows, cols); }
-
- void ctransform_row2ampereT(char * A, char *out, int rows, int cols)
- { transform_row2ampereT(A, out, rows, cols); }
-
void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{ spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
@@ -369,9 +309,6 @@ extern "C"
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
- void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
- void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
-
//void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index 1eea3247b..6fa8c3b29 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -174,7 +174,7 @@ export BNB_CUDA_VERSION=126
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-12.6
```
-3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 11.7) and a different bitsandbytes library is loaded.
+3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 12.6) and a different bitsandbytes library is loaded.
## Multi-backend Support (Alpha Release)[[multi-backend]]
diff --git a/pyproject.toml b/pyproject.toml
index 6e5c6dde3..f4ae66a8e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
- "torch>=2.0,<3",
+ "torch>=2.2,<3",
"numpy>=1.17"
]
diff --git a/tests/helpers.py b/tests/helpers.py
index 02cb881a3..de11f4f66 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -36,6 +36,8 @@ def format_with_label(label: str, value: Any) -> str:
formatted = "T" if value else "F"
elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value):
formatted = "".join("T" if b else "F" for b in value)
+ elif isinstance(value, torch.dtype):
+ formatted = describe_dtype(value)
else:
formatted = str(value)
return f"{label}={formatted}"
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index ae2529542..4b93ebcbe 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -1,12 +1,9 @@
-from typing import Tuple
-
import pytest
import torch
import bitsandbytes as bnb
from tests.helpers import (
BOOLEAN_TRIPLES,
- BOOLEAN_TUPLES,
TRUE_FALSE,
describe_dtype,
get_test_dims,
@@ -16,189 +13,6 @@
TRANSPOSE_VALS = [(False, True), (False, False)]
-@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4"))
-@pytest.mark.parametrize(
- "funcs",
- [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)],
- ids=["func=bmm", "func=matmul"],
-)
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
-@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad"))
-@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
-@pytest.mark.deprecated
-def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]):
- if dim2 > 0:
- dim2 = dim2 - (dim2 % 16)
- dim3 = dim3 - (dim3 % 16)
- dim4 = dim4 - (dim4 % 16)
- for i in range(25):
- # normal multiply
- if funcs[0] in [torch.mm, torch.matmul]:
- dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
- dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
- A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
- B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
- target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1])
- torch.nn.init.xavier_uniform_(B)
-
- if not transpose[0] and not transpose[1]:
- out_torch = funcs[0](A, B)
- out_bnb = funcs[1](A, B)
- elif not transpose[0] and transpose[1]:
- out_torch = funcs[0](A, B.t())
- out_bnb = funcs[1](A, B.t())
- elif transpose[0] and not transpose[1]:
- out_torch = funcs[0](A.t(), B)
- out_bnb = funcs[1](A.t(), B)
- elif transpose[0] and transpose[1]:
- out_torch = funcs[0](A.t(), B.t())
- out_bnb = funcs[1](A.t(), B.t())
-
- n = out_bnb.numel()
- idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx == 0).sum().item() < n * 0.0175
- idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx == 0).sum().item() < n * 0.001
-
- if any(req_grad):
- out_bnb.data.copy_(out_torch)
- torch.cuda.synchronize()
- loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
- loss_bnb.backward()
- gradA1 = A.grad
- gradB1 = B.grad
- A.grad = None
- B.grad = None
-
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
- loss_torch.backward()
- gradA2 = A.grad
- gradB2 = B.grad
- A.grad = None
- B.grad = None
-
- if req_grad[0]:
- torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
- if req_grad[1]:
- n = gradB1.numel()
- idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx == 0).sum().item() < n * 0.1
- idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx == 0).sum().item() < n * 0.02
- torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
-
- # batched matrix multiply
- if funcs[0] in [torch.bmm, torch.matmul]:
- A = torch.randn(
- size=(dim1, dim2, dim3),
- device="cuda",
- requires_grad=req_grad[0],
- )
- B = torch.randn(
- size=(dim1, dim3, dim4),
- device="cuda",
- requires_grad=req_grad[1],
- )
- target = torch.randn(
- size=(dim1, dim2, dim4),
- device="cuda",
- requires_grad=req_grad[1],
- )
- torch.nn.init.xavier_uniform_(B)
-
- out_torch = funcs[0](A, B)
- out_bnb = funcs[1](A, B)
-
- n = out_bnb.numel()
- idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx == 0).sum().item() < n * 0.01
- torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2)
-
- if any(req_grad):
- out_bnb.data.copy_(out_torch)
- torch.cuda.synchronize()
- loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
- loss_bnb.backward()
- gradA1 = A.grad
- gradB1 = B.grad
- A.grad = None
- B.grad = None
-
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
- loss_torch.backward()
- gradA2 = A.grad
- gradB2 = B.grad
- A.grad = None
- B.grad = None
-
- if req_grad[0]:
- torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
- if req_grad[1]:
- n = gradB1.numel()
- idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx == 0).sum().item() < n * 0.1
- idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx == 0).sum().item() < n * 0.02
-
- if funcs[0] in [torch.matmul]:
- dim1 = dim1 - (dim1 % 16)
- A = torch.randn(
- size=(dim1, dim2, dim3),
- device="cuda",
- requires_grad=req_grad[0],
- )
- dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
- B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
- target = torch.randn(
- size=(dim1, dim2, dim4),
- device="cuda",
- requires_grad=req_grad[1],
- )
- torch.nn.init.xavier_uniform_(B)
-
- if transpose[1]:
- out_torch = funcs[0](A, B.t())
- out_bnb = funcs[1](A, B.t())
- else:
- out_torch = funcs[0](A, B)
- out_bnb = funcs[1](A, B)
-
- n = out_bnb.numel()
- idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx == 0).sum().item() < n * 0.0175
- idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx == 0).sum().item() < n * 0.001
-
- if any(req_grad):
- out_bnb.data.copy_(out_torch)
- torch.cuda.synchronize()
- loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
- loss_bnb.backward()
- gradA1 = A.grad
- gradB1 = B.grad
- A.grad = None
- B.grad = None
-
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
- loss_torch.backward()
- gradA2 = A.grad
- gradB2 = B.grad
- A.grad = None
- B.grad = None
-
- if req_grad[0]:
- torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
- if req_grad[1]:
- n = gradB1.numel()
- idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx == 0).sum().item() < n * 0.1
- idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx == 0).sum().item() < n * 0.02
-
-
@pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py
new file mode 100644
index 000000000..9872cdfca
--- /dev/null
+++ b/tests/test_deprecated.py
@@ -0,0 +1,123 @@
+import numpy as np
+import pytest
+from scipy.stats import norm
+import torch
+
+from bitsandbytes import functional as F
+
+
+@pytest.mark.deprecated
+def test_kbit_quantile_estimation():
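+    # Reference: exact Gaussian quantiles from scipy's norm.ppf; the 1.3e-4 offset keeps the
+    # 0/1 tails finite, and F.estimate_quantiles should match within a loose mean-error bound.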
+ for i in range(100):
+ data = torch.randn(1024, 1024, device="cuda")
+ for bits in range(2, 9):
+ p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
+ val1 = torch.Tensor(norm.ppf(p)).cuda()
+ val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
+ err = torch.abs(val1 - val2).mean()
+ assert err < 0.038
+
+ for i in range(100):
+ data = torch.randn(1024, 1024, device="cuda")
+ for bits in range(2, 4):
+ total_values = 2**bits - 1
+ p = np.linspace(0, 1, 2 * total_values + 1)
+ idx = np.arange(1, 2 * total_values + 1, 2)
+ p = p[idx]
+ offset = 1 / (2 * total_values)
+ p = np.linspace(offset, 1 - offset, total_values)
+ val1 = torch.Tensor(norm.ppf(p)).cuda()
+ val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
+ err = torch.abs(val1 - val2).mean()
+ assert err < 0.035
+
+
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
+@pytest.mark.deprecated
+def test_estimate_quantiles(dtype):
+ A = torch.rand(1024, 1024, device="cuda")
+ A = A.to(dtype)
+ code = F.estimate_quantiles(A)
+
+ percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
+ torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)
+
+ A = torch.randn(1024, 1024, device="cuda")
+ A = A.to(dtype)
+ code = F.estimate_quantiles(A)
+
+ quantiles = torch.quantile(A.float(), percs)
+ diff = torch.abs(code - quantiles)
+ assert (diff > 5e-02).sum().item() == 0
+
+
+@pytest.mark.deprecated
+def test_quantile_quantization():
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device="cuda")
+ code = F.estimate_quantiles(A1)
+ C = F.quantize_no_absmax(A1, code)
+ A2 = F.dequantize_no_absmax(C, code)
+ diff = torch.abs(A1 - A2).mean().item()
+ assert diff < 0.0075
+
+ A1 = torch.rand(1024, 1024, device="cuda")
+ code = F.estimate_quantiles(A1)
+ C = F.quantize_no_absmax(A1, code)
+ A2 = F.dequantize_no_absmax(C, code)
+ diff = torch.abs(A1 - A2).mean().item()
+ torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
+ assert diff < 0.001
+
+
+@pytest.mark.deprecated
+def test_dynamic_quantization():
+ diffs = []
+ reldiffs = []
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device="cuda")
+ C, S = F.quantize(A1)
+ A2 = F.dequantize(C, S)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ assert diff.mean().item() < 0.0135
+ print(sum(diffs) / len(diffs))
+ print(sum(reldiffs) / len(reldiffs))
+
+ for i in range(100):
+ A1 = torch.rand(1024, 1024, device="cuda")
+ C, S = F.quantize(A1)
+ A2 = F.dequantize(C, S)
+ diff = torch.abs(A1 - A2).mean().item()
+ torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
+ assert diff < 0.004
+
+
+@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
+@pytest.mark.deprecated
+def test_percentile_clipping(gtype):
+ gnorm_vec1 = torch.zeros(100, device="cuda")
+ gnorm_vec2 = torch.zeros(100, device="cuda")
+ n = 4
+ step = 0
+ percentile = 5
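+    # Host-side reference for percentile clipping: keep a rolling window of gradient norms,
+    # take the sorted entry at index `percentile` as the clip value, and compare it (and the
+    # resulting scale) with what F.percentile_clipping returns.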
+ for i in range(20):
+ step += 1
+ g = torch.randn(n, n, dtype=gtype, device="cuda")
+ gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
+        assert gnorm_scale == (1.0 if gnorm1 < clip2 else clip2 / gnorm1)
+
+ gnorm2 = torch.norm(g.float())
+ if step == 1:
+ gnorm_vec1[:] = gnorm2
+ else:
+ gnorm_vec1[step % 100] = gnorm2
+
+ vals, idx = torch.sort(gnorm_vec1)
+ clip1 = vals[percentile]
+
+ torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
+ torch.testing.assert_close(clip1, clip2)
+ torch.testing.assert_close(gnorm1, gnorm2)
diff --git a/tests/test_functional.py b/tests/test_functional.py
index c8ac20896..95d5cd6dc 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,4 +1,3 @@
-from itertools import product
import math
import random
import time
@@ -6,7 +5,6 @@
import einops
import numpy as np
import pytest
-from scipy.stats import norm
import torch
import bitsandbytes as bnb
@@ -88,77 +86,194 @@ def reset(self):
print("Resetting benchmark data")
-def setup():
- pass
-
-
-def teardown():
- pass
-
-
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
-def test_estimate_quantiles(dtype):
- A = torch.rand(1024, 1024, device="cuda")
- A = A.to(dtype)
- code = F.estimate_quantiles(A)
-
- percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
- torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)
-
- A = torch.randn(1024, 1024, device="cuda")
- A = A.to(dtype)
- code = F.estimate_quantiles(A)
-
- quantiles = torch.quantile(A.float(), percs)
- diff = torch.abs(code - quantiles)
- assert (diff > 5e-02).sum().item() == 0
+class Test8BitBlockwiseQuantizeFunctional:
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+ @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
+ @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
+ @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
+ def test_dynamic_blockwise_quantization(self, dtype, nested, blocksize, signed):
+ diffs = []
+ reldiffs = []
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
+ C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
+ A2 = F.dequantize_blockwise(C, S)
+ diff = torch.abs(A1 - A2).float()
+ reldiff = diff / torch.abs(A1.float() + 1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ abserr = sum(diffs) / len(diffs)
+ relerr = sum(reldiffs) / len(reldiffs)
+ # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
+ # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
+ assert abserr < 0.011
+ assert relerr < 0.018
+ assert A2.dtype == dtype
+
+ diffs = []
+ code = F.create_dynamic_map(signed=signed)
+ for i in range(100):
+ A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
+ C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
+ A2 = F.dequantize_blockwise(C, S)
+ diff = torch.abs(A1 - A2).float()
+ reldiff = diff / torch.abs(A1.float() + 1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
+ abserr = sum(diffs) / len(diffs)
+ relerr = sum(reldiffs) / len(reldiffs)
+ if signed:
+ assert abserr < 0.0035
+ assert relerr < 0.015
+ else:
+ assert abserr < 0.00175
+ assert relerr < 0.012
+ assert A2.dtype == dtype
+ # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
+ # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
+
+ def test_blockwise_cpu_large(self):
+ diffs = []
+ reldiffs = []
+ batch = 128
+ seq = 128
+ for hidden in [128]: # , 14336]:
+ for blocksize in [4096, 16384]:
+ for i in range(2):
+ A1 = torch.randn(batch, seq, hidden, device="cpu")
+ t0 = time.time()
+ C, S = F.quantize_blockwise(A1, blocksize=blocksize)
+ A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
+ print(time.time() - t0)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ assert diffs[-1] < 0.011
+ # print(sum(diffs)/len(diffs))
+ # print(sum(reldiffs)/len(reldiffs))
+
+ @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
+ @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"])
+ def test_few_bit_quant(self, bits, method):
+ abserrs = []
+ relerrs = []
+ code = None
+ if method == "linear":
+ code = F.create_linear_map(True, total_bits=bits).cuda()
+ elif method == "fp8":
+ ebits = math.ceil(bits / 2)
+ pbits = bits - ebits - 1
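+            # 1 sign bit + `ebits` exponent bits + `pbits` mantissa bits == `bits` total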
+ code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
+ elif method == "dynamic":
+ code = F.create_dynamic_map(True, bits - 0, bits).cuda()
+ elif method == "quantile":
+ values = torch.randn(2048, 2048, device="cuda")
+ code = F.create_quantile_map(values, bits).cuda()
+ # for some data types we have no zero
+ # for some data types we have one zero
+ # for some data types we have two zeros
+ assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
+ # print(method, (code==0).sum())
+ assert code.numel() == 256
+ for i in range(10):
+ values = torch.randn(1, 32, device="cuda")
+ values /= values.abs().max()
+ # values[values.abs() < 1e-6] += 1e-5
+
+ q1 = []
+ v1 = []
+ for v in values[0]:
+ idx = torch.abs(v - code).argmin()
+ q1.append(idx.item())
+ v1.append(code[idx].item())
+
+ q1 = torch.Tensor(q1).cuda()
+ v1 = torch.Tensor(v1).cuda()
+
+ q2, S2 = F.quantize_blockwise(values, code=code)
+ v2 = F.dequantize_blockwise(q2, S2)
+
+ idx = torch.isclose(q1.int(), q2.int())
+ err2 = torch.abs(v2 - values)
+ abserrs.append(err2.mean().item())
+ relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
+ if idx.sum():
+ # some weird cases
+ err1 = torch.abs(v1 - values).mean()
+ # assert err2.mean() <= err1
+ else:
+ torch.testing.assert_close(q1, q2)
+
+ def test_fp8_quant(self):
+ for e_bits in range(1, 7):
+ p_bits = 7 - e_bits
+ code = F.create_fp8_map(True, e_bits, p_bits).cuda()
+
+ abserr = []
+ relerr = []
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device="cuda")
+ C, SC = F.quantize_blockwise(A1, code=code)
+ A2 = F.dequantize_blockwise(C, SC)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
+ abserr.append(diff.mean().item())
+ relerr.append(reldiff.mean().item())
+ # assert diff < 0.0075
+ # print(sum(abserr)/len(abserr))
+ # print(sum(relerr)/len(relerr))
+
+ abserr = []
+ relerr = []
+ for i in range(100):
+ A1 = torch.rand(1024, 1024, device="cuda")
+ C, SC = F.quantize_blockwise(A1, code=code)
+ A2 = F.dequantize_blockwise(C, SC)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
+ abserr.append(diff.mean().item())
+ relerr.append(reldiff.mean().item())
+ # assert diff < 0.0075
+ # print(sum(abserr)/len(abserr))
+ # print(sum(relerr)/len(relerr))
+
+ abserr = []
+ relerr = []
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device="cuda")
+ C, SC = F.quantize_blockwise(A1)
+ A2 = F.dequantize_blockwise(C, SC)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
+ abserr.append(diff.mean().item())
+ relerr.append(reldiff.mean().item())
+ # assert diff < 0.0075
+ # print(3, sum(abserr)/len(abserr))
+ # print(3, sum(relerr)/len(relerr))
+
+ @pytest.mark.benchmark
+ def test_bench_dequantization(self):
+ a = torch.rand(1024, 1024, device="cuda").half()
+ code = F.create_fp8_map(True, 3, 0, 4).cuda()
+ qa, SA = F.quantize_blockwise(a, code=code)
+ print(qa.max())
+
+ max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
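+        # Rough lower bound in microseconds: 1024*1024 half-precision values (2 bytes each)
+        # moved over an assumed ~672 GB/s of memory bandwidth.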
+ # print(max_theoretical_mu)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ qa, SA = F.quantize_blockwise(a)
+ torch.cuda.synchronize()
+ # print((time.time()-t0)/1e6)
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
-@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
-@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
-@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
-def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed):
- diffs = []
- reldiffs = []
- for i in range(100):
- A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
- C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
- A2 = F.dequantize_blockwise(C, S)
- diff = torch.abs(A1 - A2).float()
- reldiff = diff / torch.abs(A1.float() + 1e-8)
- diffs.append(diff.mean().item())
- reldiffs.append(reldiff.mean().item())
- abserr = sum(diffs) / len(diffs)
- relerr = sum(reldiffs) / len(reldiffs)
- # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs))
- # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs))
- assert abserr < 0.011
- assert relerr < 0.018
- assert A2.dtype == dtype
-
- diffs = []
- code = F.create_dynamic_map(signed=signed)
- for i in range(100):
- A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype)
- C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code)
- A2 = F.dequantize_blockwise(C, S)
- diff = torch.abs(A1 - A2).float()
- reldiff = diff / torch.abs(A1.float() + 1e-8)
- diffs.append(diff.mean().item())
- reldiffs.append(reldiff.mean().item())
- # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
- abserr = sum(diffs) / len(diffs)
- relerr = sum(reldiffs) / len(reldiffs)
- if signed:
- assert abserr < 0.0035
- assert relerr < 0.015
- else:
- assert abserr < 0.00175
- assert relerr < 0.012
- assert A2.dtype == dtype
- # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
- # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
+def test_stable_embedding():
+ layer = bnb.nn.StableEmbedding(1024, 1024)
+ layer.reset_parameters()
def quant(x):
@@ -198,11 +313,6 @@ def quant_multi_chunk(x, dim, chunk_size=32):
return max1, x.to(torch.int8)
-def quant_minmax(A):
- minA = A.min()
- maxA = A.max()
-
-
def mean(xx):
return sum(xx) / float(len(xx))
@@ -219,531 +329,617 @@ def mean(xx):
}
-@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
-@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
-def test_approx_igemm(dim1, dim2, quant_methods, batched):
- dim1 = dim1 - (dim1 % 32)
- dim2 = dim2 - (dim2 % 32)
- errors = []
- relerrors = []
- # print("")
- for i in range(5):
- if batched:
- A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
- B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
- maxA, Ac = quant_methods[0](A, 2)
- maxB, Bc = quant_methods[1](B, 1)
- else:
- A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
- B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
- maxA, Ac = quant_methods[0](A, 1)
- maxB, Bc = quant_methods[1](B, 0)
- torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
- if batched:
- out2 = torch.bmm(A, B)
- C = torch.bmm(Ac.float(), Bc.float())
- else:
- out2 = torch.mm(A, B)
- C = F.igemm(Ac, Bc)
- out = quant_methods[4](maxA, maxB, C)
- std = out2.std()
- out /= std
- out2 /= std
- err = torch.abs(out - out2)
- relerr = err / torch.abs(out2)
- errors.append(err.mean().item())
- relerrors.append(relerr.mean().item())
- # print(mean(errors))
- # print(mean(relerrors))
-
-
-def test_stable_embedding():
- layer = bnb.nn.StableEmbedding(1024, 1024)
- layer.reset_parameters()
-
-
-@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
-@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
-@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
-@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
-def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
- hidden_dim = hidden_dim - (hidden_dim % 32)
- batch_dim = batch_dim - (batch_dim % 16)
- seq_dim = seq_dim - (seq_dim % 16)
- for i in range(k):
- shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
- shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
- A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
- B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
- if not transpose[0] and not transpose[1]:
- out2 = torch.matmul(A.float(), B.float())
- out = F.igemm(A, B)
- elif not transpose[0] and transpose[1]:
- out2 = torch.matmul(A.float(), B.t().float())
- out = F.igemm(A, B.t())
- elif transpose[0] and not transpose[1]:
- out2 = torch.matmul(A.t().float(), B.float())
- out = F.igemm(A.t(), B)
- elif transpose[0] and transpose[1]:
- out2 = torch.matmul(A.t().float(), B.t().float())
- out = F.igemm(A.t(), B.t())
-
- torch.testing.assert_close(out.float(), out2)
-
- for i in range(k):
- shapeA = (batch_dim, seq_dim, hidden_dim)
- shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
- A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
- B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
- if not transpose[0] and not transpose[1]:
- out2 = torch.matmul(A.float(), B.float())
- out = F.igemm(A, B)
- elif not transpose[0] and transpose[1]:
- out2 = torch.matmul(A.float(), B.t().float())
- out = F.igemm(A, B.t())
-
- torch.testing.assert_close(out.float(), out2)
-
-
-@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
-@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
-@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
-def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
- seq_dim = seq_dim - (seq_dim % 32)
- hidden_dim = hidden_dim - (hidden_dim % 32)
- batch_dim = batch_dim - (batch_dim % 2)
- for i in range(25):
- A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8)
- B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8)
- out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
- iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
- out = F.igemm(A, B, out=iout)
-
- torch.testing.assert_close(out.float(), out2)
-
-
-@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
-@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
-@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
-@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
-def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
- def min_max(x):
- maxA = torch.amax(x, dim=2, keepdim=True)
- minA = torch.amin(x, dim=2, keepdim=True)
- scale = (maxA - minA) / 2.0
- return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
-
- seq_dim = seq_dim - (seq_dim % 16)
- hidden_dim = hidden_dim - (hidden_dim % 16)
- batch_dim = batch_dim - (batch_dim % 2)
- errs = []
- relerrs = []
- errs2 = []
- relerrs2 = []
- for i in range(k):
- A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
- if transpose:
- B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
- else:
- B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
- Ac, minA, scale = min_max(A)
- if transpose:
- maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
- out = F.igemm(Ac, Bc.t())
- out2 = torch.matmul(A, B.t())
- offset = B.t().sum(0) * (minA + scale)
- out = out.float()
- out = (out * maxB.t() * scale / (127 * 127)) + offset
-
- maxA, Ac = quant_multi(A, dim=2)
- out3 = F.igemm(Ac, Bc.t())
- out3 = mm_dequant(maxA, maxB.t(), out3)
- else:
- maxB, Bc = quant_multi(B, dim=0)
- offset = B.sum(0) * (minA + scale)
- out = F.igemm(Ac, Bc)
- out2 = torch.matmul(A, B)
- out = out.float()
- out = (out * maxB * scale / (127 * 127)) + offset
-
- maxA, Ac = quant_multi(A, dim=2)
- out3 = F.igemm(Ac, Bc)
- out3 = mm_dequant(maxA, maxB, out3)
-
- std = out2.std()
- out2 /= std
- out /= std
- out3 /= std
-
- err = torch.abs(out - out2)
- relerr = err / (torch.abs(out2) + 1e-7)
-
- err2 = torch.abs(out3 - out2)
- relerr2 = err2 / (torch.abs(out2) + 1e-7)
-
- errs.append(err.mean().item())
- relerrs.append(relerr.mean().item())
- errs2.append(err2.mean().item())
- relerrs2.append(relerr2.mean().item())
- # print(mean(errs))
- # print(mean(relerrs))
- # print(mean(errs2))
- # print(mean(relerrs2))
- assert mean(errs) < 0.015
- assert mean(relerrs) < 0.3
-
-
-@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
-@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
-def test_ibmm(dim1, dim2, dim3, dim4, transpose):
- dim2 = dim2 - (dim2 % 16)
- dim3 = dim3 - (dim3 % 16)
- dim4 = dim4 - (dim4 % 16)
- for i in range(k):
- shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
- shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
- A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
- B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
-
- if not transpose[0] and not transpose[1]:
- out2 = torch.bmm(A.float(), B.float())
- out = F.igemm(A, B)
- elif not transpose[0] and transpose[1]:
- out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
- out = F.igemm(A, B.permute([0, 2, 1]))
- elif transpose[0] and not transpose[1]:
- out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
- out = F.igemm(A.permute([0, 2, 1]), B)
- elif transpose[0] and transpose[1]:
- out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
- out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
- torch.testing.assert_close(out.float(), out2.float())
-
-
-@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4"))
-@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
-@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
-def test_int8_linear_matmul(dim1, dim2, dim3, dim4, dims, ldb):
- for i in range(k):
- if dims == 2:
- A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
- elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
- B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
- C1 = torch.matmul(A.float(), B.t().float())
-
- C2 = F.int8_linear_matmul(A, B)
- torch.testing.assert_close(C1, C2.float())
-
-
-@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
-@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
-def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims):
- for i in range(k):
- if dims == 2:
- A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
- elif dims == 3:
- A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
- B = torch.randn((dim4, dim3), device="cuda").half()
- torch.nn.init.xavier_uniform_(B)
- C1 = torch.matmul(A, B.t())
-
- A = A.view(-1, A.shape[-1])
-
- CA, _, statsA, _, _ = F.int8_double_quant(A)
- CB, statsB, _ = F.int8_vectorwise_quant(B)
- output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)
-
- torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
-
-
-@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4"))
-@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
-@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
-def test_dequant_mm(dim1, dim4, dims, has_bias):
- inner = 128
- bias = None
- if has_bias:
- bias = torch.randn(dim4, device="cuda", dtype=torch.float16)
-
- for i in range(1):
- A = torch.randn(dim1, inner, device="cuda")
- B = torch.randn(dim4, inner, device="cuda")
- C1 = torch.matmul(A.half(), B.t().half())
- if has_bias:
- C1 += bias
-
- A1, maxA = F.vectorwise_quant(A, dim=1)
- B1, maxB = F.vectorwise_quant(B, dim=1)
-
- C2 = F.int8_linear_matmul(A1, B1)
-
- C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
+class TestIGEMMFunctional:
+ @pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
+ @pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
+ @pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
+ def test_approx_igemm(self, dim1, dim2, quant_methods, batched):
+ dim1 = dim1 - (dim1 % 32)
+ dim2 = dim2 - (dim2 % 32)
+ errors = []
+ relerrors = []
+ # print("")
+ for i in range(5):
+ if batched:
+ A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
+ B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
+ maxA, Ac = quant_methods[0](A, 2)
+ maxB, Bc = quant_methods[1](B, 1)
+ else:
+ A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
+ B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
+ maxA, Ac = quant_methods[0](A, 1)
+ maxB, Bc = quant_methods[1](B, 0)
+ torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
+ if batched:
+ out2 = torch.bmm(A, B)
+ C = torch.bmm(Ac.float(), Bc.float())
+ else:
+ out2 = torch.mm(A, B)
+ C = F.igemm(Ac, Bc)
+ out = quant_methods[4](maxA, maxB, C)
+ std = out2.std()
+ out /= std
+ out2 /= std
+ err = torch.abs(out - out2)
+ relerr = err / torch.abs(out2)
+ errors.append(err.mean().item())
+ relerrors.append(relerr.mean().item())
+ # print(mean(errors))
+ # print(mean(relerrors))
+
+ @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim"))
+ @pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim"))
+ @pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim"))
+ @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
+ def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim):
+ hidden_dim = hidden_dim - (hidden_dim % 32)
+ batch_dim = batch_dim - (batch_dim % 16)
+ seq_dim = seq_dim - (seq_dim % 16)
+ for i in range(k):
+ shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
+ shapeB = (
+ (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
+ )
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
+ if not transpose[0] and not transpose[1]:
+ out2 = torch.matmul(A.float(), B.float())
+ out = F.igemm(A, B)
+ elif not transpose[0] and transpose[1]:
+ out2 = torch.matmul(A.float(), B.t().float())
+ out = F.igemm(A, B.t())
+ elif transpose[0] and not transpose[1]:
+ out2 = torch.matmul(A.t().float(), B.float())
+ out = F.igemm(A.t(), B)
+ elif transpose[0] and transpose[1]:
+ out2 = torch.matmul(A.t().float(), B.t().float())
+ out = F.igemm(A.t(), B.t())
+
+ torch.testing.assert_close(out.float(), out2)
+
+ for i in range(k):
+ shapeA = (batch_dim, seq_dim, hidden_dim)
+ shapeB = (
+ (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
+ )
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
+ if not transpose[0] and not transpose[1]:
+ out2 = torch.matmul(A.float(), B.float())
+ out = F.igemm(A, B)
+ elif not transpose[0] and transpose[1]:
+ out2 = torch.matmul(A.float(), B.t().float())
+ out = F.igemm(A, B.t())
+
+ torch.testing.assert_close(out.float(), out2)
+
+ @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim"))
+ @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim"))
+ @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim"))
+ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
+ seq_dim = seq_dim - (seq_dim % 32)
+ hidden_dim = hidden_dim - (hidden_dim % 32)
+ batch_dim = batch_dim - (batch_dim % 2)
+ for i in range(25):
+ A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8)
+ out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
+ iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
+ out = F.igemm(A, B, out=iout)
+
+ torch.testing.assert_close(out.float(), out2)
+
+ @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim"))
+ @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim"))
+ @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim"))
+ @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
+ def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
+ def min_max(x):
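+            # Asymmetric min-max quantization: shift x so it is centered on its range along the
+            # last dim, then map the half-range `scale` onto 127 to land in int8.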
+ maxA = torch.amax(x, dim=2, keepdim=True)
+ minA = torch.amin(x, dim=2, keepdim=True)
+ scale = (maxA - minA) / 2.0
+ return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
+
+ seq_dim = seq_dim - (seq_dim % 16)
+ hidden_dim = hidden_dim - (hidden_dim % 16)
+ batch_dim = batch_dim - (batch_dim % 2)
+ errs = []
+ relerrs = []
+ errs2 = []
+ relerrs2 = []
+ for i in range(k):
+ A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
+ if transpose:
+ B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
+ else:
+ B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
+ Ac, minA, scale = min_max(A)
+ if transpose:
+ maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
+ out = F.igemm(Ac, Bc.t())
+ out2 = torch.matmul(A, B.t())
+ offset = B.t().sum(0) * (minA + scale)
+ out = out.float()
+ out = (out * maxB.t() * scale / (127 * 127)) + offset
+
+ maxA, Ac = quant_multi(A, dim=2)
+ out3 = F.igemm(Ac, Bc.t())
+ out3 = mm_dequant(maxA, maxB.t(), out3)
+ else:
+ maxB, Bc = quant_multi(B, dim=0)
+ offset = B.sum(0) * (minA + scale)
+ out = F.igemm(Ac, Bc)
+ out2 = torch.matmul(A, B)
+ out = out.float()
+ out = (out * maxB * scale / (127 * 127)) + offset
+
+ maxA, Ac = quant_multi(A, dim=2)
+ out3 = F.igemm(Ac, Bc)
+ out3 = mm_dequant(maxA, maxB, out3)
+
+ std = out2.std()
+ out2 /= std
+ out /= std
+ out3 /= std
+
+ err = torch.abs(out - out2)
+ relerr = err / (torch.abs(out2) + 1e-7)
+
+ err2 = torch.abs(out3 - out2)
+ relerr2 = err2 / (torch.abs(out2) + 1e-7)
+
+ errs.append(err.mean().item())
+ relerrs.append(relerr.mean().item())
+ errs2.append(err2.mean().item())
+ relerrs2.append(relerr2.mean().item())
+ # print(mean(errs))
+ # print(mean(relerrs))
+ # print(mean(errs2))
+ # print(mean(relerrs2))
+ assert mean(errs) < 0.015
+ assert mean(relerrs) < 0.3
+
+ @pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2"))
+ @pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3"))
+ @pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4"))
+ @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
+ def test_ibmm(self, dim1, dim2, dim3, dim4, transpose):
+ dim2 = dim2 - (dim2 % 16)
+ dim3 = dim3 - (dim3 % 16)
+ dim4 = dim4 - (dim4 % 16)
+ for i in range(k):
+ shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
+ shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
+
+ if not transpose[0] and not transpose[1]:
+ out2 = torch.bmm(A.float(), B.float())
+ out = F.igemm(A, B)
+ elif not transpose[0] and transpose[1]:
+ out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
+ out = F.igemm(A, B.permute([0, 2, 1]))
+ elif transpose[0] and not transpose[1]:
+ out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
+ out = F.igemm(A.permute([0, 2, 1]), B)
+ elif transpose[0] and transpose[1]:
+ out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
+ out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
+ torch.testing.assert_close(out.float(), out2.float())
+
+
+class TestLLMInt8Functional:
+ @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2"))
+ @pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3"))
+ @pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4"))
+ @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims"))
+ @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb"))
+ def test_int8_linear_matmul(self, dim1, dim2, dim3, dim4, dims, ldb):
+ for i in range(k):
+ if dims == 2:
+ A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8)
+ elif dims == 3:
+ A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
+ C1 = torch.matmul(A.float(), B.t().float())
+
+ C2 = F.int8_linear_matmul(A, B)
+ torch.testing.assert_close(C1, C2.float())
+
+ @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2"))
+ @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3"))
+ @pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4"))
+ @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
+ def test_int8_linear_matmul_half(self, dim1, dim2, dim3, dim4, dims):
+ for i in range(k):
+ if dims == 2:
+ A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
+ elif dims == 3:
+ A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
+ B = torch.randn((dim4, dim3), device="cuda").half()
+ torch.nn.init.xavier_uniform_(B)
+ C1 = torch.matmul(A, B.t())
+
+ A = A.view(-1, A.shape[-1])
+
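+            # int8 path: quantize A and B to int8, run the int8 GEMM, then dequantize the
+            # int32 result back to fp16 using the per-row / per-column absmax statistics.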
+ CA, _, statsA, _, _ = F.int8_double_quant(A)
+ CB, statsB, _ = F.int8_vectorwise_quant(B)
+ output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB)
+
+ torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
+
+ @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4"))
+ @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
+ @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
+ def test_dequant_mm(self, dim1, dim4, dims, has_bias):
+ inner = 128
+ bias = None
if has_bias:
- C4 += bias
-
- # TODO: is something wrong here? If so, the problem goes deeper
- # n = C1.numel()
- # p = 0.06
- std = C1.std(0).view(1, -1)
- C1 /= std
- C4 /= std
- # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
- # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
-
- C5 = F.int8_mm_dequant(C2, maxA, maxB, bias=bias)
- C5 /= std
- torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
- n = C5.numel()
- assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))
-
-
-@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
-@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
-def test_colrow_absmax(dim1, dim2, dims, threshold):
- for i in range(k):
- A = torch.randn(dim1, dim2, device="cuda").half()
-
- assert dims == 2
-
- row_stats1, _ = torch.abs(A.float()).max(1)
- col_stats1, _ = torch.abs(A.float()).max(0)
-
- if threshold > 0.0:
- A_truncated = A.clone()
- A_truncated[torch.abs(A_truncated) >= threshold] = 0.0
- row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
- col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
-
- row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
-
- nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten()
- nnz_block_ptr1 = torch.zeros(
- nnz_rows1_counts.shape[0] + 1,
- dtype=nnz_rows1_counts.dtype,
- device=nnz_rows1_counts.device,
+ bias = torch.randn(dim4, device="cuda", dtype=torch.float16)
+
+ for i in range(1):
+ A = torch.randn(dim1, inner, device="cuda")
+ B = torch.randn(dim4, inner, device="cuda")
+ C1 = torch.matmul(A.half(), B.t().half())
+ if has_bias:
+ C1 += bias
+
+ A1, maxA = F.vectorwise_quant(A, dim=1)
+ B1, maxB = F.vectorwise_quant(B, dim=1)
+
+ C2 = F.int8_linear_matmul(A1, B1)
+
+ C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
+ if has_bias:
+ C4 += bias
+
+ # TODO: is something wrong here? If so, the problem goes deeper
+ # n = C1.numel()
+ # p = 0.06
+ std = C1.std(0).view(1, -1)
+ C1 /= std
+ C4 /= std
+ # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06))
+ # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
+
+ C5 = F.int8_mm_dequant(C2, maxA, maxB, bias=bias)
+ C5 /= std
+ torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
+ n = C5.numel()
+ assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))
+
+ @pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2"))
+ @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims"))
+ @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp"))
+ def test_colrow_absmax(self, dim1, dim2, dims, threshold):
+ for i in range(k):
+ A = torch.randn(dim1, dim2, device="cuda").half()
+
+ assert dims == 2
+
+ row_stats1, _ = torch.abs(A.float()).max(1)
+ col_stats1, _ = torch.abs(A.float()).max(0)
+
+ if threshold > 0.0:
+ A_truncated = A.clone()
+ A_truncated[torch.abs(A_truncated) >= threshold] = 0.0
+ row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
+ col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
+
+ row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
+
+ nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten()
+ nnz_block_ptr1 = torch.zeros(
+ nnz_rows1_counts.shape[0] + 1,
+ dtype=nnz_rows1_counts.dtype,
+ device=nnz_rows1_counts.device,
+ )
+ nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
+
+ torch.testing.assert_close(col_stats1_trunc, col_stats2)
+ torch.testing.assert_close(row_stats1_trunc, row_stats2)
+ # torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2)
+ else:
+ row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
+ assert nnz_block_ptr2 is None
+ torch.testing.assert_close(col_stats1, col_stats2)
+ torch.testing.assert_close(row_stats1, row_stats2)
+
+ @pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
+ def test_int8_double_quant(self, dim1, dim2):
+ for i in range(k):
+ A = torch.randn(dim1, dim2, device="cuda").half()
+ out_col1, Scol = F.vectorwise_quant(A, dim=0)
+ out_row1, Srow = F.vectorwise_quant(A, dim=1)
+
+ CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A)
+
+ # max difference is 1 due to rounding differences
+ torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
+ torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
+
+ n = CAt.numel()
+ num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
+ num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
+
+ # allow for 1:500 error due to rounding differences
+ min_error = 1 / 500
+ if num_not_close_cols > (min_error * n):
+ print(
+                    f"Min error exceeded: {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
+ )
+ assert False
+ if num_not_close_rows > (min_error * n):
+ print(
+                    f"Min error exceeded: {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
+ )
+ assert False
+
+ torch.testing.assert_close(Srow.flatten().float(), statsA)
+ torch.testing.assert_close(Scol.flatten().float(), statsAt)
+
+ @pytest.mark.parametrize(
+ ("dim1", "dim4", "inner"),
+ (
+ pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
+ for (dim1, dim4, inner) in zip(
+ (1, 8, 2048, 4096),
+ (2, 128, 2048, 4096),
+ (4, 256, 512, 4096),
)
- nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
-
- torch.testing.assert_close(col_stats1_trunc, col_stats2)
- torch.testing.assert_close(row_stats1_trunc, row_stats2)
- # torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2)
- else:
- row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
- assert nnz_block_ptr2 is None
- torch.testing.assert_close(col_stats1, col_stats2)
- torch.testing.assert_close(row_stats1, row_stats2)
-
-
-@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
-def test_int8_double_quant(dim1, dim2):
- for i in range(k):
+ ),
+ )
+ def test_integrated_int8_linear_matmul(self, dim1, dim4, inner):
+ for i in range(k):
+ A = torch.randn(dim1, inner, device="cuda").half()
+ B = torch.randn(dim4, inner, device="cuda").half()
+
+ out1 = torch.matmul(A.half(), B.t().half())
+
+ C1a, stats1a, _ = F.int8_vectorwise_quant(A)
+ C2a, stats2a, _ = F.int8_vectorwise_quant(B)
+ A1, maxA = F.vectorwise_quant(A, dim=1)
+ B1, maxB = F.vectorwise_quant(B, dim=1)
+
+ torch.testing.assert_close(maxA.flatten().float(), stats1a)
+ torch.testing.assert_close(maxB.flatten().float(), stats2a)
+ torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
+ torch.testing.assert_close(C2a, B1, rtol=0, atol=1)
+
+ out2 = F.int8_linear_matmul(A1, B1)
+
+ C2 = F.int8_linear_matmul(A1, B1)
+
+ out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
+
+ err1 = torch.abs(out1 - out2).mean().item()
+ err2 = torch.abs(out1 - out3).mean().item()
+ assert err2 <= err1 * 1.025
+
+ @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
+ def test_coo_double_quant(self, dim1, dim2):
+ threshold = 2.00
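+        # Columns containing outliers (|A| >= threshold) are reported via outlier_cols; the check
+        # zeroes those columns in the reference and verifies the dequantized result matches within
+        # tolerance.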
+ for i in range(k):
+ A = torch.randn(dim1, dim2, device="cuda").half()
+
+ idx = torch.abs(A) >= threshold
+ CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
+
+ if outlier_cols is not None:
+ A1 = A * idx
+ A2 = torch.zeros_like(A) + A1
+ torch.testing.assert_close(A1, A2)
+
+ A[:, outlier_cols] = 0
+ A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
+ torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)
+
+ @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
+ def test_coo_int8_vectorwise_quant(self, dim1, dim2):
+ threshold = 3.00
+ for i in range(k):
+ A = torch.randn(dim1, dim2, device="cuda").half()
+
+ idx = torch.abs(A) >= threshold
+ CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
+
+ if outlier_cols is not None:
+ A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
+ A[:, outlier_cols] = 0
+ torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
+
+
+class TestSpMMFunctional:
+ @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
+ @pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
+ def test_spmm_coo(self, dim1, dim2, transposed_B):
+ threshold = 1.5
+ dim3 = torch.randint(32, 128, size=(1,)).item()
+ # dim3 = 17
+ for i in range(k):
+ A = torch.randn(dim1, dim2).cuda().half()
+ if transposed_B:
+ B = torch.randn(dim3, dim2).cuda().half()
+ else:
+ B = torch.randn(dim2, dim3).cuda().half()
+
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ A2 = A * idx
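+            # cooA holds only the above-threshold entries of A; A2 is the matching dense reference.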
+
+ if transposed_B:
+ out2 = F.spmm_coo(cooA, B.t())
+ out1 = torch.matmul(A2, B.t())
+ else:
+ out2 = F.spmm_coo(cooA, B)
+ out1 = torch.matmul(A2, B)
+
+ assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
+
+ @pytest.mark.benchmark
+ def test_spmm_bench(self):
+ batch = 2
+ model = 1024 * 1
+ hidden = model * 4
+ seq = 1024
+ dim1 = batch * seq
+ dim2 = model
+ dim3 = hidden
+ threshold = 4
A = torch.randn(dim1, dim2, device="cuda").half()
- out_col1, Scol = F.vectorwise_quant(A, dim=0)
- out_row1, Srow = F.vectorwise_quant(A, dim=1)
-
- CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A)
-
- # max difference is 1 due to rounding differences
- torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
- torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
-
- n = CAt.numel()
- num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
- num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
-
- # allow for 1:500 error due to rounding differences
- min_error = 1 / 500
- if num_not_close_cols > (min_error * n):
- print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}")
- assert False
- if num_not_close_rows > (min_error * n):
- print(f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}")
- assert False
-
- torch.testing.assert_close(Srow.flatten().float(), statsA)
- torch.testing.assert_close(Scol.flatten().float(), statsAt)
-
-
-@pytest.mark.parametrize(
- ("dim1", "dim4", "inner"),
- (
- pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
- for (dim1, dim4, inner) in zip(
- (1, 8, 2048, 4096),
- (2, 128, 2048, 4096),
- (4, 256, 512, 4096),
- )
- ),
-)
-def test_integrated_int8_linear_matmul(dim1, dim4, inner):
- for i in range(k):
- A = torch.randn(dim1, inner, device="cuda").half()
- B = torch.randn(dim4, inner, device="cuda").half()
-
- out1 = torch.matmul(A.half(), B.t().half())
-
- C1a, stats1a, _ = F.int8_vectorwise_quant(A)
- C2a, stats2a, _ = F.int8_vectorwise_quant(B)
- A1, maxA = F.vectorwise_quant(A, dim=1)
- B1, maxB = F.vectorwise_quant(B, dim=1)
-
- torch.testing.assert_close(maxA.flatten().float(), stats1a)
- torch.testing.assert_close(maxB.flatten().float(), stats2a)
- torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
- torch.testing.assert_close(C2a, B1, rtol=0, atol=1)
-
- out2 = F.int8_linear_matmul(A1, B1)
-
- C2 = F.int8_linear_matmul(A1, B1)
+ B = torch.randn(dim2, dim3, device="cuda").half()
+ for i in range(10):
+ C1 = bnb.matmul(A, B.t())
- out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t())
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(k):
+ C1 = bnb.matmul(A, B.t())
+ torch.cuda.synchronize()
+ t8 = time.time() - t0
- err1 = torch.abs(out1 - out2).mean().item()
- err2 = torch.abs(out1 - out3).mean().item()
- assert err2 <= err1 * 1.025
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ print(nnz / idx.numel())
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ for i in range(10):
+ out2 = F.spmm_coo(cooA, B)
-@pytest.mark.parametrize(
- ("dim1", "dim4", "inner"),
- (
- pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}")
- for (dim1, dim4, inner) in zip(
- get_test_dims(1, 4 * 1024, n=6),
- get_test_dims(1, 4 * 1024, n=6),
- get_test_dims(1, 4 * 1024, n=6),
- )
- ),
-)
-@pytest.mark.skip("Row scale has some bugs for ampere")
-def test_igemmlt_row_scale(dim1, dim4, inner):
- formatB = F.get_special_format_str()
- err1, err2, err3 = [], [], []
- relerr1, relerr2 = [], []
- scale = 1
- for i in range(k):
- A = torch.randn(dim1, inner, device="cuda").half()
- B = torch.randn(dim4, inner, device="cuda").half()
- torch.nn.init.xavier_uniform_(B)
- C1 = torch.matmul(A, B.t())
-
- out1 = torch.matmul(A.half(), B.t().half())
-
- C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
- CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
- A2, SA = F.nvidia_transform(C1a, "col32")
- B2, SB = F.nvidia_transform(CB, formatB)
- A1, maxA = F.vectorwise_quant(A, dim=1)
-
- c = 10.0 * inner * scale
- row_scale = torch.ones_like(maxA) / c
- outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale)
- # C3, S = F.nvidia_transform(outC32, "row", state=SC)
- C3 = outC32
- maxval = torch.abs(C3).max()
- if maxval == 127:
- scale = 1.5
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(k):
+ out2 = F.spmm_coo(cooA, B)
+ torch.cuda.synchronize()
+ tsp = time.time() - t0
+ print(tsp, t8)
+ print(tsp / t8)
+
+ @pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
+ @pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
+ @pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
+ def test_spmm_coo_very_sparse(self, dim1, dim2, dtype, out_func):
+ out_func = getattr(torch, out_func)
+
+ threshold = 3.3
+ # threshold = 2.8
+ # threshold = 0.0
+ A = torch.randn(dim1, dim2, device="cuda").half()
+ if dtype == torch.float16:
+ B = torch.randn(dim2, dim2 * 4, device="cuda").half()
+ torch.nn.init.xavier_uniform_(B)
else:
- scale = maxval / 120
- out3 = C3 * maxA * absmaxB * c / (127 * 127)
-
- C4 = torch.matmul(C1a.float(), CB.float().t())
-
- C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
- B2, SB = F.nvidia_transform(C2a, formatB)
- outC32 = F.int8_linear_matmul(A2, B2)
- out2 = F.int8_mm_dequant(outC32, stats1a, stats2a)
-
- CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
- CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")
-
- C = torch.matmul(CA.float(), CB.t().float())
- out4 = C * SA * SB / (127 * 127)
- # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
-
- # print('='*80)
- # print(out1)
- # print(out2)
- # print(out3)
+ B = torch.randn(dim2, dim2 * 4, device="cuda").half()
+ torch.nn.init.xavier_uniform_(B)
+ B, SB = F.vectorwise_quant(B, quant_type="linear")
+ # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
+ print("")
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ A2 = A * idx
+ out1 = torch.matmul(A2.half(), B.half())
+ out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
+ out1 += out.clone()
+ out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
+ # print(B)
# print(out1)
# print(out2)
- # print(out3)
- err1.append(torch.abs(out1 - out2).mean().item())
- err2.append(torch.abs(out1 - out3).mean().item())
- err3.append(torch.abs(out1 - out4).mean().item())
-
- # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
- print("")
- print(sum(err1) / len(err1))
- print(sum(err2) / len(err2))
- print(sum(err3) / len(err3))
-
-
-@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
-def test_coo_double_quant(dim1, dim2):
- threshold = 2.00
- for i in range(k):
- A = torch.randn(dim1, dim2, device="cuda").half()
-
- idx = torch.abs(A) >= threshold
- CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
-
- if outlier_cols is not None:
- A1 = A * idx
- A2 = torch.zeros_like(A) + A1
- torch.testing.assert_close(A1, A2)
-
- A[:, outlier_cols] = 0
- A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
- torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2)
-
-
-@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2"))
-def test_coo_int8_vectorwise_quant(dim1, dim2):
- threshold = 3.00
- for i in range(k):
+ p = 200 / (2048 * 12288 * 4)
+ n = out1.numel()
+ count = math.ceil(p * n)
+ std = out1.std()
+ out1 /= std
+ out2 /= std
+ assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
+ # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
+
+ idx_col = torch.randint(0, A2.shape[-1], size=(15,))
+
+ # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)
+
+ # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
+ # torch.cuda.synchronize()
+ # t0 = time.time()
+ # print(A2.shape, B.shape)
+ # for i in range(100):
+ # #out3 = F.spmm_coo(cooA, Bt.t())
+ # #out2 = F.spmm_coo(cooA, B)
+ # #out2 = F.spmm_coo_very_sparse(cooA, B)
+ # #out1 = torch.matmul(A, Bt.t())
+
+ # torch.cuda.synchronize()
+ # print(time.time() - t0)
+
+ @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
+ @pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
+ def test_integrated_sparse_decomp(self, dim1, dim2):
+ threshold = 3.0
+ for _ in range(k):
+ A = torch.randn(dim1, dim2).cuda().half()
+ w1 = torch.randn(dim1, dim2).cuda().half()
+ out1 = torch.matmul(A, w1.t())
+
+ Cw1, statsw1, _ = F.int8_vectorwise_quant(w1)
+ CA, statsA, _ = F.int8_vectorwise_quant(A)
+
+ out1_32 = F.int8_linear_matmul(CA, Cw1)
+ out2 = F.int8_mm_dequant(out1_32, statsA, statsw1)
+
+ # CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
+ CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
+
+ out1_32 = F.int8_linear_matmul(CA, Cw1)
+ out3 = F.int8_mm_dequant(out1_32, statsA, statsw1)
+
+ assert coo_tensor is not None
+
+ out4 = F.spmm_coo(coo_tensor, w1.t())
+ # idx = torch.unique(coo_tensor._indices()[1]).long()
+ # out4 = torch.matmul(A, w1.t())
+ out5 = out3 + out4
+
+ err1 = torch.abs(out1 - out2).mean().item()
+ err2 = torch.abs(out1 - out5).mean().item()
+ assert err2 < err1
+
+ @pytest.mark.parametrize("dim1", [1 * 2048])
+ @pytest.mark.parametrize("dim2", [2048])
+ @pytest.mark.parametrize("dtype", [torch.int8])
+ def test_spmm_coo_dequant(self, dim1, dim2, dtype):
+ threshold = 6.0
+ # threshold = 2.8
+ # threshold = 0.0
A = torch.randn(dim1, dim2, device="cuda").half()
+ B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
+ torch.nn.init.xavier_uniform_(B)
+ Bt = B.t().contiguous()
- idx = torch.abs(A) >= threshold
- CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
-
- if outlier_cols is not None:
- A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
- A[:, outlier_cols] = 0
- torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
+ CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B)
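+        # statsBt are the quantization statistics matched to CBt; they are passed as dequant_stats so the spmm kernel can dequantize on the fly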
+ rowidx = torch.randint(0, A.shape[-1], size=(15,))
-@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2"))
-@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B"))
-def test_spmm_coo(dim1, dim2, transposed_B):
- threshold = 1.5
- dim3 = torch.randint(32, 128, size=(1,)).item()
- # dim3 = 17
- for i in range(k):
- A = torch.randn(dim1, dim2).cuda().half()
- if transposed_B:
- B = torch.randn(dim3, dim2).cuda().half()
- else:
- B = torch.randn(dim2, dim3).cuda().half()
+ A[:, rowidx] = 8.0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
@@ -751,712 +947,381 @@ def test_spmm_coo(dim1, dim2, transposed_B):
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
A2 = A * idx
-
- if transposed_B:
- out2 = F.spmm_coo(cooA, B.t())
- out1 = torch.matmul(A2, B.t())
- else:
- out2 = F.spmm_coo(cooA, B)
- out1 = torch.matmul(A2, B)
-
- assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
-
-
-@pytest.mark.benchmark
-def test_spmm_bench():
- batch = 2
- model = 1024 * 1
- hidden = model * 4
- seq = 1024
- dim1 = batch * seq
- dim2 = model
- dim3 = hidden
- threshold = 4
- A = torch.randn(dim1, dim2, device="cuda").half()
- B = torch.randn(dim2, dim3, device="cuda").half()
- for i in range(10):
- C1 = bnb.matmul(A, B.t())
-
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(k):
- C1 = bnb.matmul(A, B.t())
- torch.cuda.synchronize()
- t8 = time.time() - t0
-
- idx = torch.abs(A) >= threshold
- nnz = (idx == 1).sum().item()
- print(nnz / idx.numel())
- rows, cols = torch.where(idx)
- values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
-
- for i in range(10):
- out2 = F.spmm_coo(cooA, B)
-
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(k):
- out2 = F.spmm_coo(cooA, B)
- torch.cuda.synchronize()
- tsp = time.time() - t0
- print(tsp, t8)
- print(tsp / t8)
-
-
-@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2"))
-def test_integrated_sparse_decomp(dim1, dim2):
- threshold = 3.0
- for _ in range(k):
- A = torch.randn(dim1, dim2).cuda().half()
- w1 = torch.randn(dim1, dim2).cuda().half()
- out1 = torch.matmul(A, w1.t())
-
- Cw1, statsw1, _ = F.int8_vectorwise_quant(w1)
- CA, statsA, _ = F.int8_vectorwise_quant(A)
-
- out1_32 = F.int8_linear_matmul(CA, Cw1)
- out2 = F.int8_mm_dequant(out1_32, statsA, statsw1)
-
- # CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
- CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold)
-
- out1_32 = F.int8_linear_matmul(CA, Cw1)
- out3 = F.int8_mm_dequant(out1_32, statsA, statsw1)
-
- assert coo_tensor is not None
-
- out4 = F.spmm_coo(coo_tensor, w1.t())
- # idx = torch.unique(coo_tensor._indices()[1]).long()
- # out4 = torch.matmul(A, w1.t())
- out5 = out3 + out4
-
- err1 = torch.abs(out1 - out2).mean().item()
- err2 = torch.abs(out1 - out5).mean().item()
- assert err2 < err1
-
-
-def test_matmuls():
- a = torch.randn(256, 512).half().cuda()
- b = torch.randn(256, 512).half().cuda()
- c1 = torch.matmul(a, b.t())
- c2 = bnb.matmul(a, b)
- c3 = bnb.matmul_cublas(a, b.t())
-
- err1 = torch.abs(c1 - c2).mean().item()
- err2 = torch.abs(c1 - c3).mean().item()
- assert err1 < 0.2
- assert err2 < 0.2
- print(err1, err2)
-
-
-@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype)
-@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func"))
-def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
- out_func = getattr(torch, out_func)
-
- threshold = 3.3
- # threshold = 2.8
- # threshold = 0.0
- A = torch.randn(dim1, dim2, device="cuda").half()
- if dtype == torch.float16:
- B = torch.randn(dim2, dim2 * 4, device="cuda").half()
- torch.nn.init.xavier_uniform_(B)
- else:
- B = torch.randn(dim2, dim2 * 4, device="cuda").half()
- torch.nn.init.xavier_uniform_(B)
- B, SB = F.vectorwise_quant(B, quant_type="linear")
- # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
-
- print("")
- idx = torch.abs(A) >= threshold
- nnz = (idx == 1).sum().item()
- rows, cols = torch.where(idx)
- values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A * idx
- out1 = torch.matmul(A2.half(), B.half())
- out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
- out1 += out.clone()
- out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
- # print(B)
- # print(out1)
- # print(out2)
- p = 200 / (2048 * 12288 * 4)
- n = out1.numel()
- count = math.ceil(p * n)
- std = out1.std()
- out1 /= std
- out2 /= std
- assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
- # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
-
- idx_col = torch.randint(0, A2.shape[-1], size=(15,))
-
- # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)
-
- # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
- # torch.cuda.synchronize()
- # t0 = time.time()
- # print(A2.shape, B.shape)
- # for i in range(100):
- # #out3 = F.spmm_coo(cooA, Bt.t())
- # #out2 = F.spmm_coo(cooA, B)
- # #out2 = F.spmm_coo_very_sparse(cooA, B)
- # #out1 = torch.matmul(A, Bt.t())
-
- # torch.cuda.synchronize()
- # print(time.time() - t0)
-
-
-def test_coo2csr():
- threshold = 1
- A = torch.randn(128, 128).half().cuda()
- idx = torch.abs(A) >= threshold
- nnz = (idx == 1).sum().item()
- rows, cols = torch.where(idx)
- values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A * idx
- csrA = F.coo2csr(cooA)
- counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
- assert counts.numel() == A.shape[0]
-
- torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
- idx = A2 != 0
- torch.testing.assert_close(A2[idx], csrA.values)
-
-
-def test_coo2csc():
- threshold = 1
- A = torch.randn(128, 128).half().cuda()
- idx = torch.abs(A) >= threshold
- nnz = (idx == 1).sum().item()
- rows, cols = torch.where(idx)
- values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A * idx
- cscA = F.coo2csc(cooA)
- counts = cscA.colptr[1:] - cscA.colptr[:-1]
- assert counts.numel() == A.shape[1]
-
- torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
- # torch uses row-major -> use transpose to transfer to col-major
- idx = A2.t() != 0
- torch.testing.assert_close(A2.t()[idx], cscA.values)
-
-
-@pytest.mark.parametrize("dim1", [1 * 2048])
-@pytest.mark.parametrize("dim2", [2048])
-@pytest.mark.parametrize("dtype", [torch.int8])
-def test_spmm_coo_dequant(dim1, dim2, dtype):
- threshold = 6.0
- # threshold = 2.8
- # threshold = 0.0
- A = torch.randn(dim1, dim2, device="cuda").half()
- B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
- torch.nn.init.xavier_uniform_(B)
- Bt = B.t().contiguous()
-
- CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B)
-
- rowidx = torch.randint(0, A.shape[-1], size=(15,))
-
- A[:, rowidx] = 8.0
-
- idx = torch.abs(A) >= threshold
- nnz = (idx == 1).sum().item()
- rows, cols = torch.where(idx)
- values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A * idx
- out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
- out1 = torch.matmul(A2, B.half())
- out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
- out3 = out3 * statsBt.half() / 127
-
- values, counts = torch.unique(cooA.rowidx, return_counts=True)
- offset = counts.cumsum(0).int()
- max_count, max_idx = torch.sort(counts, descending=True)
- print(torch.median(max_count.float()))
-
- torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)
-
- p = 200 / (2048 * 12288 * 4)
- n = out1.numel()
- count = math.ceil(p * n)
- assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
-
- # torch.cuda.synchronize()
- # t0 = time.time()
- # for i in range(100):
- # out2 = F.spmm_coo_very_sparse(cooA, B)
- # torch.cuda.synchronize()
- # print('fp16', time.time() - t0)
-
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(100):
- out2 = F.spmm_coo(cooA, B)
- torch.cuda.synchronize()
- print("cusparse fp16", time.time() - t0)
-
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(100):
- out2 = F.spmm_coo_very_sparse(cooA, CBt)
- torch.cuda.synchronize()
- print("int8", time.time() - t0)
-
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(100):
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
- torch.cuda.synchronize()
- print("int8+dequant", time.time() - t0)
-
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(100):
- out2 = torch.matmul(A, B)
- torch.cuda.synchronize()
- print("matmul", time.time() - t0)
-
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(100):
- out1 = bnb.matmul(A, Bt)
- out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
- out = out1 + out2
- torch.cuda.synchronize()
- print("sparse+ matmul", time.time() - t0)
-
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(100):
- out1 = bnb.matmul(A, Bt)
- torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
- torch.cuda.synchronize()
- print("partial matmul", time.time() - t0)
-
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(100):
- out1 = bnb.matmul(A, Bt)
- torch.cuda.synchronize()
- print("partial matmul", time.time() - t0)
-
-
-def test_zeropoint():
- def quant_zp(x):
- dtype = x.dtype
- x = x.float()
- dyna = x.max() - x.min()
- if dyna == 0:
- dyna = 1
- qx = 254.0 / dyna
- minx = x.min()
- # zpx = torch.round(minx* qx)
- # zpx = 127 - torch.round(x.max()* qx)
- zpx = torch.round(x.min() * qx) - 127
- x = (qx * x) + zpx
- return x, qx, zpx
-
- batch = 2
- seq = 512
- model = 1024
- hidden = 4 * model
- A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
- B = torch.randn(model, hidden, device="cuda").half() * 0.1
-
- C0 = torch.matmul(A, B)
-
- # A, SA = F.vectorwise_quant(A, quant_type='linear')
- # B, SB = F.vectorwise_quant(B, quant_type='linear')
- A = A.float()
- B = B.float()
-
- C1 = torch.matmul(A, B)
- C3 = bnb.matmul(A.half(), B.t().contiguous().half())
-
- zp = 1
- # C2 = torch.matmul(A-zp, B)
- # C2 += B.sum(0).view(1, -1)*zp
- C2 = torch.matmul(A, B - zp)
- C2 -= A.sum(1).view(-1, 1) * zp
-
- ca, cqa, cza = quant_zp(A)
- # print(ca.min(), ca.max())
- # print((ca - cza).min(), (ca - cza).max())
-
- zp = 1
- scale = 2.0
- C5 = torch.matmul((A * scale) - zp, B)
- C5 += B.sum(0) * zp
- C5 /= scale
-
- CA, qa, zpa = quant_zp(A)
- C4 = torch.matmul(CA, B)
- C4 -= B.sum(0) * zpa
- C4 /= qa
-
- zpb = 1
- zpa = 1
- qa = 2
- qb = 2
- C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
- C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
- C6 -= zpa * zpb * A.shape[1]
- C6 /= qa * qb
-
- CA, qa, zpa = quant_zp(A)
- CB, qb, zpb = quant_zp(B)
- C7 = torch.matmul(CA, CB)
- C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
- C7 -= zpa * zpb * A.shape[1]
- C7 /= qa * qb
-
- # print("")
- # print(C0.flatten()[:10])
- # print(C1.flatten()[:10])
- # print(C2.flatten()[:10])
- # print(C3.flatten()[:10])
- # print(C5.flatten()[:10])
- # print(C6.flatten()[:10])
- # print(C7.flatten()[:10])
- err1 = torch.abs(C1 - C2).mean().item()
- err2 = torch.abs(C1 - C3).mean().item()
- err3 = torch.abs(C1 - C4).mean().item()
- err4 = torch.abs(C1 - C5).mean().item()
- err5 = torch.abs(C1 - C6).mean().item()
- err6 = torch.abs(C1 - C7).mean().item()
- print(err1, err2, err3, err4, err5, err6)
-
-
-@pytest.mark.deprecated
-def test_extract_outliers():
- for i in range(k):
- shapeA = (4096, 4096 * 4)
- idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
- # idx = torch.Tensor([0]).int().cuda()
- A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
- outliers1 = A[:, idx.long()]
-
- CA, SA = F.transform(A, "col_turing")
-
- outliers2 = F.extract_outliers(CA, SA, idx)
+ out1 = torch.matmul(A2, B.half())
+ out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
+ out3 = out3 * statsBt.half() / 127
+
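+        # count how many nonzeros each sparse row holds and print the median row occupancy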
+ values, counts = torch.unique(cooA.rowidx, return_counts=True)
+ offset = counts.cumsum(0).int()
+ max_count, max_idx = torch.sort(counts, descending=True)
+ print(torch.median(max_count.float()))
+
+ torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)
+
+ p = 200 / (2048 * 12288 * 4)
+ n = out1.numel()
+ count = math.ceil(p * n)
+ assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
+
+ # torch.cuda.synchronize()
+ # t0 = time.time()
+ # for i in range(100):
+ # out2 = F.spmm_coo_very_sparse(cooA, B)
+ # torch.cuda.synchronize()
+ # print('fp16', time.time() - t0)
+
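+        # rough wall-clock comparison (100 iterations each): cuSPARSE fp16 spmm, int8 very-sparse kernels, dense matmul, and mixed sparse+dense variants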
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out2 = F.spmm_coo(cooA, B)
+ torch.cuda.synchronize()
+ print("cusparse fp16", time.time() - t0)
- assert outliers2.shape[0] == shapeA[0]
- assert outliers2.shape[1] == idx.numel()
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out2 = F.spmm_coo_very_sparse(cooA, CBt)
+ torch.cuda.synchronize()
+ print("int8", time.time() - t0)
- torch.testing.assert_close(outliers1, outliers2)
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
+ torch.cuda.synchronize()
+ print("int8+dequant", time.time() - t0)
- CA, SA = F.transform(A, "col_ampere")
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out2 = torch.matmul(A, B)
+ torch.cuda.synchronize()
+ print("matmul", time.time() - t0)
- outliers2 = F.extract_outliers(CA, SA, idx)
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out1 = bnb.matmul(A, Bt)
+ out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
+ out = out1 + out2
+ torch.cuda.synchronize()
+        print("sparse + matmul", time.time() - t0)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out1 = bnb.matmul(A, Bt)
+ torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
+ torch.cuda.synchronize()
+ print("partial matmul", time.time() - t0)
- assert outliers2.shape[0] == shapeA[0]
- assert outliers2.shape[1] == idx.numel()
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out1 = bnb.matmul(A, Bt)
+ torch.cuda.synchronize()
+        print("matmul only (no sparse part)", time.time() - t0)
- torch.testing.assert_close(outliers1, outliers2)
+class TestSparseTensorFunctional:
+ def test_coo2csr(self):
+ threshold = 1
+ A = torch.randn(128, 128).half().cuda()
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ A2 = A * idx
+ csrA = F.coo2csr(cooA)
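+        # consecutive rowptr entries differ by the number of nonzeros stored in each row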
+ counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
+ assert counts.numel() == A.shape[0]
-def test_blockwise_cpu_large():
- diffs = []
- reldiffs = []
- batch = 128
- seq = 128
- for hidden in [128]: # , 14336]:
- for blocksize in [4096, 16384]:
- for i in range(2):
- A1 = torch.randn(batch, seq, hidden, device="cpu")
- t0 = time.time()
- C, S = F.quantize_blockwise(A1, blocksize=blocksize)
- A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
- print(time.time() - t0)
- diff = torch.abs(A1 - A2)
- reldiff = diff / torch.abs(A1 + 1e-8)
- diffs.append(diff.mean().item())
- reldiffs.append(reldiff.mean().item())
- assert diffs[-1] < 0.011
- # print(sum(diffs)/len(diffs))
- # print(sum(reldiffs)/len(reldiffs))
+ torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
+ idx = A2 != 0
+ torch.testing.assert_close(A2[idx], csrA.values)
+ def test_coo2csc(self):
+ threshold = 1
+ A = torch.randn(128, 128).half().cuda()
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ A2 = A * idx
+ cscA = F.coo2csc(cooA)
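+        # consecutive colptr entries differ by the number of nonzeros stored in each column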
+ counts = cscA.colptr[1:] - cscA.colptr[:-1]
+ assert counts.numel() == A.shape[1]
-def test_fp8_quant():
- for e_bits in range(1, 7):
- p_bits = 7 - e_bits
- code = F.create_fp8_map(True, e_bits, p_bits).cuda()
+ torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
+        # torch tensors are row-major -> transpose to compare against the col-major values
+ idx = A2.t() != 0
+ torch.testing.assert_close(A2.t()[idx], cscA.values)
- abserr = []
- relerr = []
- for i in range(100):
- A1 = torch.randn(1024, 1024, device="cuda")
- C, SC = F.quantize_blockwise(A1, code=code)
- A2 = F.dequantize_blockwise(C, SC)
- diff = torch.abs(A1 - A2)
- reldiff = diff / torch.abs(A1 + 1e-8)
- abserr.append(diff.mean().item())
- relerr.append(reldiff.mean().item())
- # assert diff < 0.0075
- # print(sum(abserr)/len(abserr))
- # print(sum(relerr)/len(relerr))
-
- abserr = []
- relerr = []
- for i in range(100):
- A1 = torch.rand(1024, 1024, device="cuda")
- C, SC = F.quantize_blockwise(A1, code=code)
- A2 = F.dequantize_blockwise(C, SC)
- diff = torch.abs(A1 - A2)
- reldiff = diff / torch.abs(A1 + 1e-8)
- abserr.append(diff.mean().item())
- relerr.append(reldiff.mean().item())
- # assert diff < 0.0075
- # print(sum(abserr)/len(abserr))
- # print(sum(relerr)/len(relerr))
-
- abserr = []
- relerr = []
- for i in range(100):
- A1 = torch.randn(1024, 1024, device="cuda")
- C, SC = F.quantize_blockwise(A1)
- A2 = F.dequantize_blockwise(C, SC)
- diff = torch.abs(A1 - A2)
- reldiff = diff / torch.abs(A1 + 1e-8)
- abserr.append(diff.mean().item())
- relerr.append(reldiff.mean().item())
- # assert diff < 0.0075
- # print(3, sum(abserr)/len(abserr))
- # print(3, sum(relerr)/len(relerr))
-
-
-def test_few_bit_quant():
- # print('')
- for bits in range(2, 9):
- # print('='*30, bits, '='*30)
- for method in ["linear", "fp8", "dynamic", "quantile"]:
- abserrs = []
- relerrs = []
- code = None
- if method == "linear":
- code = F.create_linear_map(True, total_bits=bits).cuda()
- elif method == "fp8":
- ebits = math.ceil(bits / 2)
- pbits = bits - ebits - 1
- code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
- elif method == "dynamic":
- code = F.create_dynamic_map(True, bits - 0, bits).cuda()
- elif method == "quantile":
- values = torch.randn(2048, 2048, device="cuda")
- code = F.create_quantile_map(values, bits).cuda()
- # for some data types we have no zero
- # for some data types we have one zero
- # for some data types we have two zeros
- assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}"
- # print(method, (code==0).sum())
- assert code.numel() == 256
- for i in range(10):
- values = torch.randn(1, 32, device="cuda")
- values /= values.abs().max()
- # values[values.abs() < 1e-6] += 1e-5
-
- q1 = []
- v1 = []
- for v in values[0]:
- idx = torch.abs(v - code).argmin()
- q1.append(idx.item())
- v1.append(code[idx].item())
-
- q1 = torch.Tensor(q1).cuda()
- v1 = torch.Tensor(v1).cuda()
-
- q2, S2 = F.quantize_blockwise(values, code=code)
- v2 = F.dequantize_blockwise(q2, S2)
-
- idx = torch.isclose(q1.int(), q2.int())
- err2 = torch.abs(v2 - values)
- abserrs.append(err2.mean().item())
- relerrs.append((err2 / (1e-10 + values).abs()).mean().item())
- if idx.sum():
- # some weird cases
- err1 = torch.abs(v1 - values).mean()
- # assert err2.mean() <= err1
- else:
- torch.testing.assert_close(q1, q2)
- # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
- # assert False
-
-
-def test_kbit_quantile_estimation():
- for i in range(100):
- data = torch.randn(1024, 1024, device="cuda")
- for bits in range(2, 9):
- p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits)
- val1 = torch.Tensor(norm.ppf(p)).cuda()
- val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
- err = torch.abs(val1 - val2).mean()
- assert err < 0.038
-
- for i in range(100):
- data = torch.randn(1024, 1024, device="cuda")
- for bits in range(2, 4):
- total_values = 2**bits - 1
- p = np.linspace(0, 1, 2 * total_values + 1)
- idx = np.arange(1, 2 * total_values + 1, 2)
- p = p[idx]
- offset = 1 / (2 * total_values)
- p = np.linspace(offset, 1 - offset, total_values)
- val1 = torch.Tensor(norm.ppf(p)).cuda()
- val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1)
- err = torch.abs(val1 - val2).mean()
- assert err < 0.035
-
-
-@pytest.mark.benchmark
-def test_bench_dequantization():
- a = torch.rand(1024, 1024, device="cuda").half()
- code = F.create_fp8_map(True, 3, 0, 4).cuda()
- qa, SA = F.quantize_blockwise(a, code=code)
- print(qa.max())
-
- max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000
- # print(max_theoretical_mu)
-
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(100):
- qa, SA = F.quantize_blockwise(a)
- torch.cuda.synchronize()
- # print((time.time()-t0)/1e6)
-
-
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
-@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
-@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
-def test_4bit_quant(dtype, quant_type, blocksize):
- vals = list(product([0, 1], repeat=4))
-
- code = {}
- for bits in vals:
- result = 0
- bias = 3
- sign, e1, e2, p1 = bits
- idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1
- sign = -1.0 if sign else 1.0
- exp = e1 * 2 + e2 * 1
- if exp == 0:
- # sub-normal
- if p1 == 0:
- result = 0
- else:
- result = sign * 0.0625
- else:
- # normal
- exp = 2 ** (-exp + bias + 1)
- frac = 1.5 if p1 else 1.0
- result = sign * exp * frac
- code[idx] = result
-
- A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
- qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
- A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
-
- err = (A1 - A2).abs().float()
- relerr = (err / (A1.abs().float() + 1e-8)).mean()
- idx = err > 1.0
- err = err.mean()
-
- assert A2.dtype == dtype
-
- # With larger block sizes, we can expect this to blow up.
- # At blocksize>=1024, don't even bother looking at relerr.
- if blocksize <= 64:
- assert err.item() < 0.1
- assert relerr.item() < 0.28
- elif blocksize <= 256:
- assert err.item() < 0.11
- assert relerr.item() < 0.30
- elif blocksize <= 512:
- assert err.item() < 0.12
- assert relerr.item() < 0.31
- elif quant_type == "fp4":
- # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
- assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
- else:
- # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
- assert err.item() < math.log2(blocksize) * 8e-2
-
-
-@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
-def test_4bit_compressed_stats(quant_type):
- for blocksize in [128, 64]:
- errs1 = []
- errs2 = []
- for i in range(10):
- A1 = torch.randn(1024, 1024, device="cuda").half()
- q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
- q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
- A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
- A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
+class TestQuantize4BitFunctional:
+ @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+ @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+ @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
+ def test_4bit_quant(self, dtype, quant_type, blocksize):
+ A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
+ qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
+ A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
- err = (A1 - A2).abs().float()
- relerr = (err / (A1.abs().float() + 1e-15)).mean()
- err = err.mean()
+ err = (A1 - A2).abs().float()
+ relerr = (err / (A1.abs().float() + 1e-8)).mean()
+ err = err.mean()
- errs1.append(err.item())
+ assert A2.dtype == dtype
- assert err.item() < 0.11
+ # With larger block sizes, we can expect this to blow up.
+ # At blocksize>=1024, don't even bother looking at relerr.
+ if blocksize <= 64:
+ assert err.item() < 0.1
assert relerr.item() < 0.28
-
- err = (A1 - A3).abs().float()
- relerr = (err / (A1.abs().float() + 1e-15)).mean()
- err = err.mean()
-
- errs2.append(err.item())
-
+ elif blocksize <= 256:
assert err.item() < 0.11
- assert relerr.item() < 0.28
+ assert relerr.item() < 0.30
+ elif blocksize <= 512:
+ assert err.item() < 0.12
+ assert relerr.item() < 0.31
+ elif quant_type == "fp4":
+ # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
+ assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
+ else:
+ # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
+ assert err.item() < math.log2(blocksize) * 8e-2
+
+ @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+ def test_4bit_compressed_stats(self, quant_type):
+ for blocksize in [128, 64]:
+ errs1 = []
+ errs2 = []
+ for i in range(10):
+ A1 = torch.randn(1024, 1024, device="cuda").half()
+ q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
+ q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
+ A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
+ A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
+
+ err = (A1 - A2).abs().float()
+ relerr = (err / (A1.abs().float() + 1e-15)).mean()
+ err = err.mean()
+
+ errs1.append(err.item())
+
+ assert err.item() < 0.11
+ assert relerr.item() < 0.28
+
+ err = (A1 - A3).abs().float()
+ relerr = (err / (A1.abs().float() + 1e-15)).mean()
+ err = err.mean()
+
+ errs2.append(err.item())
+
+ assert err.item() < 0.11
+ assert relerr.item() < 0.28
+
+ # print(sum(errs1)/len(errs1), blocksize, quant_type)
+ # print(sum(errs2)/len(errs2), blocksize, quant_type)
+
+ # @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
+ @pytest.mark.parametrize("quant_type", ["nf4"])
+ @pytest.mark.benchmark
+ def test_bench_4bit_dequant(self, quant_type):
+ blocksize = 256
+ a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
+ qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
+
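+        # bandwidth-bound estimate: ~0.5 byte per packed 4-bit input element plus 2 bytes per fp16 output element, at the assumed ~768 GB/s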
+ input_size = a.numel() / 2
+ output_size = a.numel() * 2
+ num_bytes = input_size + output_size
+ GB = num_bytes / 1e9
+ max_theoretical_s = GB / 768
+ # print(max_theoretical_s*1e6)
+ b = torch.randn(128, 1024 * 12, device="cuda").half()
+
+ iters = 100
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(iters):
+ F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
+ # b.copy_(a)
+ torch.cuda.synchronize()
+ # print((time.time()-t0)/iters*1e6)
+
+ # torch.cuda.synchronize()
+ # t0 = time.time()
+ # for i in range(iters):
+ # torch.matmul(b, a.t())
+ # torch.cuda.synchronize()
+ # print((time.time()-t0)/iters*1e6)
+
+ @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
+ @pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
+ @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
+ @pytest.mark.parametrize(
+ "quant_storage",
+ [torch.uint8, torch.float16, torch.bfloat16, torch.float32],
+ ids=describe_dtype,
+ )
+ def test_gemv_4bit(self, dtype, storage_type, quant_storage, double_quant, kind):
+ for dim in [128, 256, 512, 1024]:
+ # for dim in [4*1024]:
+ # for dim in [1*16]:
+ errs1 = []
+ errs2 = []
+ errs3 = []
+ relerrs1 = []
+ relerrs2 = []
+ relerrs3 = []
+ max_errs1 = []
+ max_errs2 = []
+ max_errs3 = []
+
+ for i in range(100):
+ if kind == "fc1":
+ A = torch.randn(1, dim, dtype=dtype, device="cuda")
+ B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
+ elif kind == "fc2":
+ A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
+ B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
+ elif kind == "attn":
+ A = torch.randn(1, dim, dtype=dtype, device="cuda")
+ B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
+ elif kind == "attn_packed":
+ A = torch.randn(1, dim, dtype=dtype, device="cuda")
+ B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
+
+ qB, state = F.quantize_4bit(
+ B,
+ quant_type=storage_type,
+ compress_statistics=double_quant,
+ quant_storage=quant_storage,
+ )
+ C3 = torch.matmul(A, B.t())
+ C2 = F.gemv_4bit(A, qB.t(), state=state)
+ A.requires_grad = True
+ C1 = bnb.matmul_4bit(A, qB.t(), state)
+
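+                # C1: matmul_4bit (training path, A requires grad), C2: gemv_4bit (inference path), C3: fp reference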
+ err1 = (C1 - C2).abs().float()
+ err2 = (C3 - C2).abs().float()
+ err3 = (C3 - C1).abs().float()
+
+ mag1 = torch.abs(C1).float() + 1e-5
+ mag2 = torch.abs(C3).float() + 1e-5
+ mag3 = torch.abs(C3).float() + 1e-5
+
+ relerr1 = err1 / mag1
+ relerr2 = err2 / mag2
+ relerr3 = err3 / mag3
+
+ max_err1 = err1.max()
+ max_err2 = err2.max()
+ max_err3 = err3.max()
+
+ errs1.append(err1.mean().item())
+ errs2.append(err2.mean().item())
+ errs3.append(err3.mean().item())
+
+ relerrs1.append(relerr1.mean().item())
+ relerrs2.append(relerr2.mean().item())
+ relerrs3.append(relerr3.mean().item())
+
+ max_errs1.append(max_err1.item())
+ max_errs2.append(max_err2.item())
+ max_errs3.append(max_err3.item())
+
+ c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
+
+ c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False)
+ err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
+ err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
+ err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
+ relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
+ relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
+ relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
+ maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
+ maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
+ maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
+ absratio = err2 / err3
+ relratio = relerr2 / relerr3
+ maxratio = relerr2 / relerr3
+
+            # for debugging if the test fails
+ #
+ # print('='*80)
+ # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
+ # print(C1.flatten()[-20:])
+ # print(C2.flatten()[-20:])
+ # print(f'inference vs training abs: {err1}')
+ # print(f'inference vs training rel: {relerr1}')
+ # print(f'inference vs training max: {maxerr1}')
+ # print(f'inference vs training vs torch err ratio abs: {absratio}')
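+            # spmm_coo multiplies only the outlier entries captured by the threshold; adding them back to the int8 result should lower the error vs. the fp16 reference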
+ # print(f'inference vs training vs torch err ratio rel: {relratio}')
+ # print(f'inference vs training vs torch err ratio max: {maxratio}')
+ if dtype == torch.float16:
+ if dim <= 512:
+ assert err1 < 7e-5
+ assert relerr1 < 0.0008
+ else:
+ assert err1 < 6e-5
+ assert relerr1 < 2e-4
+ assert absratio < 1.005 and absratio > 0.995
+ assert relratio < 1.005 and relratio > 0.995
+ assert maxratio < 1.005 and maxratio > 0.995
+ elif dtype == torch.float32:
+ if dim <= 512:
+ assert err1 < 5e-8
+ assert relerr1 < 1e-6
+ assert maxerr1 < 1e-7
+ else:
+ assert err1 < 5e-8
+ assert relerr1 < 8e-6
+ assert maxerr1 < 1e-7
+ assert absratio < 1.005 and absratio > 0.995
+ assert relratio < 1.005 and relratio > 0.995
+ assert maxratio < 1.005 and maxratio > 0.995
+ elif dtype == torch.bfloat16:
+ if dim <= 512:
+ assert err1 < 6e-4
+ assert relerr1 < 0.007
+ assert maxerr1 < 0.015
+ else:
+ assert err1 < 2e-4
+ assert relerr1 < 0.002
+ assert maxerr1 < 0.0012
+ assert absratio < 1.005 and absratio > 0.995
+ assert relratio < 1.04 and relratio > 0.96
+ assert maxratio < 1.02 and maxratio > 0.98
+
+ @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
+    @pytest.mark.parametrize("double_quant", [False], ids=["DQ_False"])
+ def test_gemv_eye_4bit(self, storage_type, dtype, double_quant):
+ dims = 10
+ torch.random.manual_seed(np.random.randint(0, 412424242))
+ dims = get_test_dims(0, 8192, n=dims)
+ dims = [dim + (64 - (dim % 64)) for dim in dims]
+ # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
+ for dim in dims:
+ A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
+ B = torch.eye(dim, dtype=dtype, device="cuda")
+
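+            # B is the identity, so both 4-bit matmul paths should reproduce A within assert_close's default tolerances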
+ qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+ C3 = torch.matmul(A, B.t())
+ C2 = bnb.matmul_4bit(A, qB.t(), state)
+ A.requires_grad = True
+ C1 = bnb.matmul_4bit(A, qB.t(), state)
- # print(sum(errs1)/len(errs1), blocksize, quant_type)
- # print(sum(errs2)/len(errs2), blocksize, quant_type)
-
-
-# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
-@pytest.mark.parametrize("quant_type", ["nf4"])
-@pytest.mark.benchmark
-def test_bench_4bit_dequant(quant_type):
- blocksize = 256
- a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half()
- qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
-
- input_size = a.numel() / 2
- output_size = a.numel() * 2
- num_bytes = input_size + output_size
- GB = num_bytes / 1e9
- max_theoretical_s = GB / 768
- # print(max_theoretical_s*1e6)
- b = torch.randn(128, 1024 * 12, device="cuda").half()
-
- iters = 100
- torch.cuda.synchronize()
- t0 = time.time()
- for i in range(iters):
- F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
- # b.copy_(a)
- torch.cuda.synchronize()
- # print((time.time()-t0)/iters*1e6)
-
- # torch.cuda.synchronize()
- # t0 = time.time()
- # for i in range(iters):
- # torch.matmul(b, a.t())
- # torch.cuda.synchronize()
- # print((time.time()-t0)/iters*1e6)
+ torch.testing.assert_close(A, C3)
+ torch.testing.assert_close(A, C1)
+ torch.testing.assert_close(A, C2)
+ # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
+ # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
def test_normal_map_tree():
@@ -1474,146 +1339,6 @@ def test_normal_map_tree():
# print(pivots)
-@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
-@pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
-@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-@pytest.mark.parametrize(
- "quant_storage",
- [torch.uint8, torch.float16, torch.bfloat16, torch.float32],
- ids=describe_dtype,
-)
-def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
- for dim in [128, 256, 512, 1024]:
- # for dim in [4*1024]:
- # for dim in [1*16]:
- errs1 = []
- errs2 = []
- errs3 = []
- relerrs1 = []
- relerrs2 = []
- relerrs3 = []
- max_errs1 = []
- max_errs2 = []
- max_errs3 = []
-
- for i in range(100):
- if kind == "fc1":
- A = torch.randn(1, dim, dtype=dtype, device="cuda")
- B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
- elif kind == "fc2":
- A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda")
- B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim)
- elif kind == "attn":
- A = torch.randn(1, dim, dtype=dtype, device="cuda")
- B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
- elif kind == "attn_packed":
- A = torch.randn(1, dim, dtype=dtype, device="cuda")
- B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim)
-
- qB, state = F.quantize_4bit(
- B,
- quant_type=storage_type,
- compress_statistics=double_quant,
- quant_storage=quant_storage,
- )
- C3 = torch.matmul(A, B.t())
- C2 = F.gemv_4bit(A, qB.t(), state=state)
- A.requires_grad = True
- C1 = bnb.matmul_4bit(A, qB.t(), state)
-
- err1 = (C1 - C2).abs().float()
- err2 = (C3 - C2).abs().float()
- err3 = (C3 - C1).abs().float()
-
- mag1 = torch.abs(C1).float() + 1e-5
- mag2 = torch.abs(C3).float() + 1e-5
- mag3 = torch.abs(C3).float() + 1e-5
-
- relerr1 = err1 / mag1
- relerr2 = err2 / mag2
- relerr3 = err3 / mag3
-
- max_err1 = err1.max()
- max_err2 = err2.max()
- max_err3 = err3.max()
-
- errs1.append(err1.mean().item())
- errs2.append(err2.mean().item())
- errs3.append(err3.mean().item())
-
- relerrs1.append(relerr1.mean().item())
- relerrs2.append(relerr2.mean().item())
- relerrs3.append(relerr3.mean().item())
-
- max_errs1.append(max_err1.item())
- max_errs2.append(max_err2.item())
- max_errs3.append(max_err3.item())
-
- c = int(C1.numel() * 0.0014 * (dim / 256)) + 1
-
- c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False)
- err1 = sum(errs1) / len(errs1) / math.sqrt(dim)
- err2 = sum(errs2) / len(errs2) / math.sqrt(dim)
- err3 = sum(errs3) / len(errs3) / math.sqrt(dim)
- relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim)
- relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim)
- relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim)
- maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim)
- maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim)
- maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim)
- absratio = err2 / err3
- relratio = relerr2 / relerr3
- maxratio = relerr2 / relerr3
-
- # for debugging if the tests fails
- #
- # print('='*80)
- # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
- # print(C1.flatten()[-20:])
- # print(C2.flatten()[-20:])
- # print(f'inference vs training abs: {err1}')
- # print(f'inference vs training rel: {relerr1}')
- # print(f'inference vs training max: {maxerr1}')
- # print(f'inference vs training vs torch err ratio abs: {absratio}')
- # print(f'inference vs training vs torch err ratio rel: {relratio}')
- # print(f'inference vs training vs torch err ratio max: {maxratio}')
- if dtype == torch.float16:
- if dim <= 512:
- assert err1 < 7e-5
- assert relerr1 < 0.0008
- else:
- assert err1 < 6e-5
- assert relerr1 < 2e-4
- assert absratio < 1.005 and absratio > 0.995
- assert relratio < 1.005 and relratio > 0.995
- assert maxratio < 1.005 and maxratio > 0.995
- elif dtype == torch.float32:
- if dim <= 512:
- assert err1 < 5e-8
- assert relerr1 < 1e-6
- assert maxerr1 < 1e-7
- else:
- assert err1 < 5e-8
- assert relerr1 < 8e-6
- assert maxerr1 < 1e-7
- assert absratio < 1.005 and absratio > 0.995
- assert relratio < 1.005 and relratio > 0.995
- assert maxratio < 1.005 and maxratio > 0.995
- elif dtype == torch.bfloat16:
- if dim <= 512:
- assert err1 < 6e-4
- assert relerr1 < 0.007
- assert maxerr1 < 0.015
- else:
- assert err1 < 2e-4
- assert relerr1 < 0.002
- assert maxerr1 < 0.0012
- assert absratio < 1.005 and absratio > 0.995
- assert relratio < 1.04 and relratio > 0.96
- assert maxratio < 1.02 and maxratio > 0.98
-
-
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
n = 32 * 10
@@ -1637,32 +1362,6 @@ def test_managed():
assert (A == 17 * (2**3)).sum().item() == n * n
-@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
-def test_gemv_eye_4bit(storage_type, dtype, double_quant):
- dims = 10
- torch.random.manual_seed(np.random.randint(0, 412424242))
- dims = get_test_dims(0, 8192, n=dims)
- dims = [dim + (64 - (dim % 64)) for dim in dims]
- # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
- for dim in dims:
- A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda")
- B = torch.eye(dim, dtype=dtype, device="cuda")
-
- qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
- C3 = torch.matmul(A, B.t())
- C2 = bnb.matmul_4bit(A, qB.t(), state)
- A.requires_grad = True
- C1 = bnb.matmul_4bit(A, qB.t(), state)
-
- torch.testing.assert_close(A, C3)
- torch.testing.assert_close(A, C1)
- torch.testing.assert_close(A, C2)
- # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001)
- # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)
-
-
@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3"))
@@ -1676,169 +1375,3 @@ def test_vector_quant(dim1, dim2, dim3):
A1 = F.vectorwise_dequant(qA, SA)
n = A1.numel()
assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))
-
-
-@pytest.mark.deprecated
-def test_quantile_quantization():
- for i in range(100):
- A1 = torch.randn(1024, 1024, device="cuda")
- code = F.estimate_quantiles(A1)
- C = F.quantize_no_absmax(A1, code)
- A2 = F.dequantize_no_absmax(C, code)
- diff = torch.abs(A1 - A2).mean().item()
- assert diff < 0.0075
-
- A1 = torch.rand(1024, 1024, device="cuda")
- code = F.estimate_quantiles(A1)
- C = F.quantize_no_absmax(A1, code)
- A2 = F.dequantize_no_absmax(C, code)
- diff = torch.abs(A1 - A2).mean().item()
- torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
- assert diff < 0.001
-
-
-@pytest.mark.deprecated
-def test_dynamic_quantization():
- diffs = []
- reldiffs = []
- for i in range(100):
- A1 = torch.randn(1024, 1024, device="cuda")
- C, S = F.quantize(A1)
- A2 = F.dequantize(C, S)
- diff = torch.abs(A1 - A2)
- reldiff = diff / torch.abs(A1 + 1e-8)
- diffs.append(diff.mean().item())
- reldiffs.append(reldiff.mean().item())
- assert diff.mean().item() < 0.0135
- print(sum(diffs) / len(diffs))
- print(sum(reldiffs) / len(reldiffs))
-
- for i in range(100):
- A1 = torch.rand(1024, 1024, device="cuda")
- C, S = F.quantize(A1)
- A2 = F.dequantize(C, S)
- diff = torch.abs(A1 - A2).mean().item()
- torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
- assert diff < 0.004
-
-
-@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
-@pytest.mark.deprecated
-def test_percentile_clipping(gtype):
- gnorm_vec1 = torch.zeros(100, device="cuda")
- gnorm_vec2 = torch.zeros(100, device="cuda")
- n = 4
- step = 0
- percentile = 5
- for i in range(k):
- step += 1
- g = torch.randn(n, n, dtype=gtype, device="cuda")
- gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
- assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
-
- gnorm2 = torch.norm(g.float())
- if step == 1:
- gnorm_vec1[:] = gnorm2
- else:
- gnorm_vec1[step % 100] = gnorm2
-
- vals, idx = torch.sort(gnorm_vec1)
- clip1 = vals[percentile]
-
- torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
- torch.testing.assert_close(clip1, clip2)
- torch.testing.assert_close(gnorm1, gnorm2)
-
-
-@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims"))
-@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype)
-@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
-@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut"))
-@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
-@pytest.mark.deprecated
-def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
- for i in range(k):
- if dims == 2:
- A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
- elif dims == 3:
- A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
-
- A.view(-1)[-1] = -1
- if transpose:
- At = A.t().contiguous()
- out1, S1 = F.nvidia_transform(At, to_order=orderOut)
- else:
- out1, S1 = F.nvidia_transform(A, to_order=orderOut)
- out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)
-
- assert S1[0][0] == S2[0][0]
- assert S1[0][1] == S2[0][1]
- # print(out1)
- # print(out2)
-
- torch.testing.assert_close(out1, out2)
-
-
-@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
-@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
-@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
-@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype)
-@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA"))
-@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut"))
-@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose"))
-@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims"))
-@pytest.mark.deprecated
-def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
- if dims == 3 and orderOut != "col32":
- return
- if dtype == torch.int32 and orderOut != "col32":
- return
- try:
- func = F.get_transform_func(dtype, orderA, orderOut, transpose)
- except ValueError as ve:
- pytest.skip(str(ve)) # skip if not supported
-
- if dims == 2:
- A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
- elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)
-
- out, S = F.nvidia_transform(A, to_order=orderOut)
-
- if orderOut == "row":
- torch.testing.assert_close(A.flatten(), out.flatten())
- elif orderOut == "col":
- torch.testing.assert_close(A.t().flatten(), out.flatten())
- elif orderOut == "col32":
- if dims == 2:
- n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
- elif dims == 3:
- n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
- assert out.numel() == n
- elif orderOut == "col_turing":
- # 32 col 8 row tiles
- n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32)))
- assert out.numel() == n
- total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
- for row in range(A.shape[0]):
- for col in range(A.shape[1]):
- i = row * A.shape[1]
- j = col
-
- coltile = (col // 32) + (1 if col % 32 != 0 else 0)
- rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
- offset = 32 * 8 * (rowtile + coltile)
- col2 = col % 32
- row2 = (row % 8) * 32
-
- assert A.flatten()[i + j] == A[row, col]
- # assert A.flatten()[i+j] == out.flatten()[row2+col2]
- # torch.testing.assert_close(A.flatten()[i+j], A[row, col])
- # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
-
- if orderOut == "col32":
- out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
- torch.testing.assert_close(A, out2)
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index bc9e2600f..a9efa796f 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -8,8 +8,6 @@
import torch
import bitsandbytes as bnb
-from bitsandbytes import functional as F
-from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
TRUE_FALSE,
@@ -18,28 +16,9 @@
torch_save_to_buffer,
)
+
# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
-
-
-@pytest.mark.skipif(
- not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
- reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
-)
-def test_layout_exact_match():
- x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
- for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
- transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
- tile_indices = get_inverse_transform_indices(transform, tile_size)
- cxb = transform(x)
-
- torch.cuda.synchronize()
- restored_x = undo_layout(cxb, tile_indices)
- torch.cuda.synchronize()
- assert restored_x.is_contiguous()
- assert torch.all(torch.eq(restored_x, x))
-
-
def test_linear_no_igemmlt():
linear = torch.nn.Linear(1024, 3072)
x = torch.randn(3, 1024, dtype=torch.half)
diff --git a/tests/test_ops.py b/tests/test_ops.py
new file mode 100644
index 000000000..0e46ff9e0
--- /dev/null
+++ b/tests/test_ops.py
@@ -0,0 +1,221 @@
+from math import prod
+
+import pytest
+import torch
+
+import bitsandbytes
+from tests.helpers import TRUE_FALSE, id_formatter
+
+
+class TestLLMInt8Ops:
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ def test_int8_linear_matmul(self, device):
+ A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
+ B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
+ out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
+
+ assert out.shape == (10, 30)
+ assert out.dtype == torch.int32
+ assert out.device == A.device
+
+ torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B))
+
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ def test_int8_linear_matmul_out(self, device):
+ A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
+ B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
+
+ out = torch.empty((10, 30), dtype=torch.int32, device=device)
+ torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out)
+
+ assert out.shape == (10, 30)
+ assert out.dtype == torch.int32
+ assert out.device == A.device
+
+ torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out))
+
+ @pytest.mark.parametrize("threshold", [0.0, 6.0])
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ def test_int8_vectorwise_quant(self, threshold, device):
+ if device == "cpu":
+ pytest.skip("CPU implementation is not available")
+
+ A = torch.randn(10, 20, dtype=torch.float16, device=device)
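+        # plant one large entry so a nonzero threshold is guaranteed to flag at least one outlier column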
+ A[1][0] = 1000.0
+
+ out_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=threshold)
+
+ assert out_row.shape == (10, 20)
+ assert out_row.dtype == torch.int8
+ assert out_row.device == A.device
+ assert row_stats.shape == (10,)
+ assert row_stats.dtype == torch.float32
+ assert row_stats.device == A.device
+
+ if threshold > 0.0:
+ assert outlier_cols is not None
+ assert outlier_cols.dim() == 1
+ assert outlier_cols.shape[0] <= A.shape[1]
+ assert outlier_cols.device == A.device
+ else:
+ assert outlier_cols is None
+
+ torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A,))
+
+ torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold))
+
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ def test_int8_mm_dequant(self, device):
+ A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device)
+ row_stats = torch.randn(256, dtype=torch.float32, device=device)
+ col_stats = torch.randn(256, dtype=torch.float32, device=device)
+ out = torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats)
+
+ assert out.shape == A.shape
+ assert out.dtype == torch.float16
+ assert out.device == A.device
+
+ torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats))
+
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+ @pytest.mark.parametrize("has_bias", TRUE_FALSE)
+ def test_int8_scaled_mm(self, device, dtype, has_bias):
+ A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device)
+ B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device)
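+        # B is laid out as (out_features, in_features), so the result has shape (A.rows, B.rows)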
+ row_stats = torch.randn(10, dtype=torch.float32, device=device)
+ col_stats = torch.randn(30, dtype=torch.float32, device=device)
+ bias = torch.randn(30, dtype=dtype, device=device) if has_bias else None
+ out = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats, bias=bias, dtype=dtype)
+
+ assert out.shape == (10, 30)
+ assert out.dtype == dtype
+ assert out.device == A.device
+
+ torch.library.opcheck(torch.ops.bitsandbytes.int8_scaled_mm, (A, B, row_stats, col_stats, bias, dtype))
+
+
+class TestInt8BlockwiseQuantOps:
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+ @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
+ def test_quantize_blockwise(self, device, dtype, blocksize):
+ if device == "cpu" and dtype != torch.float32:
+ pytest.skip("CPU implementation is only available for float32")
+
+ code = bitsandbytes.functional.create_dynamic_map().to(device)
+ A = torch.randn(1024, 1024, dtype=dtype, device=device)
+ out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)
+
+ assert out.shape == A.shape
+ assert out.dtype == torch.uint8
+ assert out.device == A.device
+
+ assert absmax.device == A.device
+ assert absmax.dtype == torch.float32
+
+ torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))
+
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+ @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
+ def test_dequantize_blockwise(self, device, dtype, blocksize):
+ if device == "cpu" and dtype != torch.float32:
+ pytest.skip("CPU implementation is only available for float32")
+
+ A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device)
+ code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)
+
+ n = A.numel()
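+        # ceiling division: number of blocksize-sized blocks needed to cover all elements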
+ blocks = -(n // -blocksize)
+ absmax = torch.randn((blocks,), device=device, dtype=torch.float32)
+
+ out = torch.ops.bitsandbytes.dequantize_blockwise.default(A, absmax, code, blocksize, dtype)
+
+ assert out.shape == A.shape
+ assert out.dtype == dtype
+ assert out.device == A.device
+
+ torch.library.opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
+
+
+class Test4bitBlockwiseQuantOps:
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+ @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
+ @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+ @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
+ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
+ if device == "cpu":
+ pytest.skip("CPU implementation is not available")
+
+ A = torch.randn(1024, 1024, dtype=dtype, device=device)
+
+ out, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype)
+
+ assert out.device == A.device
+ assert out.dtype == storage_dtype
+
+ assert absmax.device == A.device
+ assert absmax.dtype == torch.float32
+
+ torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
+
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+ @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
+ @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+ @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
+ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
+ if device == "cpu":
+ pytest.skip("CPU implementation is not available")
+
+ shape = (128, 128)
+
+ n = prod(shape)
+ blocks = -(n // -blocksize)
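+        # two 4-bit values are packed per byte; dividing by itemsize*2 converts the byte count into storage_dtype elements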
+ quantized_shape = ((n + 1) // (storage_dtype.itemsize * 2), 1)
+
+ A = (
+ torch.randint(0, 255, ((n + 1) // 2,), dtype=torch.uint8, device=device)
+ .view(storage_dtype)
+ .reshape(quantized_shape)
+ .contiguous()
+ )
+
+ absmax = torch.randn((blocks,), dtype=torch.float32, device=device)
+
+ out = torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)
+
+ assert out.device == A.device
+ assert out.shape == shape
+
+ torch.library.opcheck(
+ torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype)
+ )
+
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+ @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
+ @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+ @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
+ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
+ if device == "cpu":
+ pytest.skip("CPU implementation is not available")
+
+ out_features = 1024
+ in_features = 256
+
+ A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
+ B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
+ B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype)
+ code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize)
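+        # 'code' holds the 16 dequantization values for the chosen 4-bit format; gemv_4bit uses it to dequantize B_q on the fly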
+
+ out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize)
+
+ assert out.device == A.device
+ assert out.dtype == dtype
+ assert out.shape == (1, 1, out_features)
+ assert out.isreal().all()
+
+ torch.library.opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))