diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index c549433c4..f69125aaf 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -3,15 +3,15 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from . import research, utils
+
+from . import _ops, research, utils
 from .autograd._functions import (
     MatmulLtState,
-    bmm_cublas,
     matmul,
     matmul_4bit,
-    matmul_cublas,
-    mm_cublas,
 )
+from .backends.cpu import ops as cpu_ops
+from .backends.cuda import ops as cuda_ops  ## TODO: We would guard this for CUDA only
 from .nn import modules
 from .optim import adam
 
diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
new file mode 100644
index 000000000..c519194a6
--- /dev/null
+++ b/bitsandbytes/_ops.py
@@ -0,0 +1,316 @@
+from math import prod
+from typing import Optional, Sequence, Tuple
+
+import torch
+
+_IS_TORCH_GTE_24 = False
+
+if hasattr(torch.library, "register_fake"):
+    _IS_TORCH_GTE_24 = True
+    register_fake = torch.library.register_fake
+    register_kernel = torch.library.register_kernel
+else:
+    # PyTorch <= 2.3
+    register_fake = torch.library.impl_abstract
+    register_kernel = torch.library.impl
+
+
+# Higher level op: int8 matmul + dequant + bias
+torch.library.define(
+    "bitsandbytes::int8_scaled_mm",
+    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::int8_scaled_mm")
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    dtype=torch.float16,
+) -> torch.Tensor:
+    shapeC = (*A.shape[:-1], B.shape[0])
+    return torch.empty(shapeC, device=A.device, dtype=dtype)
+
+
+@register_kernel("bitsandbytes::int8_scaled_mm", None)
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    dtype=torch.float16,
+) -> torch.Tensor:
+    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
+    out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
+    return out
+
+
+torch.library.define(
+    "bitsandbytes::int8_linear_matmul",
+    "(Tensor A, Tensor B) -> Tensor",
+)
+
+
+@register_fake("bitsandbytes::int8_linear_matmul")
+def _(A: torch.Tensor, B: torch.Tensor):
+    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
+    shapeC = (*A.shape[:-1], B.shape[0])
+    return torch.empty(shapeC, device=A.device, dtype=torch.int32)
+
+
+# More info on `out` overloads:
+# https://github.com/pytorch/pytorch/issues/125044
+torch.library.define(
+    "bitsandbytes::int8_linear_matmul.out",
+    "(Tensor A, Tensor B, Tensor! out) -> ()",
+)
+
+
+@register_fake("bitsandbytes::int8_linear_matmul.out")
+def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
+    shapeC = (*A.shape[:-1], B.shape[0])
+
+    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+    torch._check(B.dtype == torch.int8, lambda: "B must be int8")
+    torch._check(out.shape == shapeC, lambda: f"Expected out.shape == {shapeC}, got {out.shape}")
+    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
+    torch._check(out.dtype == torch.int32, lambda: f"Expected out.dtype == int32, got {out.dtype}")
+
+
+torch.library.define(
+    "bitsandbytes::int8_vectorwise_quant",
+    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor?)",
+)
+
+
+@register_fake("bitsandbytes::int8_vectorwise_quant")
+def _(A: torch.Tensor, threshold=0.0):
+    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
+    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
+
+    if threshold == 0.0:
+        return out_row, row_stats, None
+
+    outlier_cols = torch.library.get_ctx().new_dynamic_size()
+
+    return out_row, row_stats, A.new_empty(outlier_cols, dtype=torch.int64)
+
+
+torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")
+
+
+@register_fake("bitsandbytes::int8_vectorwise_dequant")
+def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
+    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
+    return torch.empty_like(A, dtype=torch.float32)
+
+
+# Default PyTorch-native implementation
+@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
+def _(A: torch.Tensor, stats: torch.Tensor):
+    # To dequantize we divide by 127, or multiply by the reciprocal.
+    return A * stats.view(-1, 1) * 7.874015718698502e-3
+
+
+torch.library.define(
+    "bitsandbytes::int8_mm_dequant",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? 
bias=None) -> Tensor", +) + + +@register_fake("bitsandbytes::int8_mm_dequant") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype=torch.float16, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: "A must be int32") + return torch.empty_like(A, dtype=dtype) + + +torch.library.define( + "bitsandbytes::int8_double_quant", + "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)", +) + + +@register_fake("bitsandbytes::int8_double_quant") +def _( + A: torch.Tensor, + threshold=0.0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + out_row = torch.empty_like(A, dtype=torch.int8) + out_col = torch.empty_like(A, dtype=torch.int8) + row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32) + col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32) + outlier_n = torch.library.get_ctx().new_dynamic_size() + outlier_cols = A.new_empty(outlier_n, dtype=torch.int64) + return out_row, out_col, row_stats, col_stats, outlier_cols + + +torch.library.define( + "bitsandbytes::dequantize_4bit", + "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype) -> Tensor", +) + + +@register_fake("bitsandbytes::dequantize_4bit") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + torch._check_is_size(blocksize) + return torch.empty(shape, dtype=dtype, device=A.device) + + +torch.library.define( + "bitsandbytes::dequantize_4bit.out", + "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()", +) + + +@register_fake("bitsandbytes::dequantize_4bit.out") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + + +torch.library.define( + "bitsandbytes::quantize_4bit", + "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)", +) + + +@register_fake("bitsandbytes::quantize_4bit") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> Tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + return out, absmax + + +torch.library.define( + "bitsandbytes::dequantize_blockwise", + "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor", +) + + +@register_fake("bitsandbytes::dequantize_blockwise") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + torch._check_is_size(blocksize) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + return torch.empty_like(A, dtype=dtype) + + +torch.library.define( + "bitsandbytes::dequantize_blockwise.out", + "(Tensor A, Tensor absmax, Tensor code, int 
blocksize, ScalarType dtype, Tensor! out) -> ()", +) + + +@register_fake("bitsandbytes::dequantize_blockwise.out") +def _( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +): + torch._check_is_size(blocksize) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + + +torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)") + + +@register_fake("bitsandbytes::quantize_blockwise") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + return out, absmax + + +torch.library.define( + "bitsandbytes::gemv_4bit", + "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize) -> Tensor", +) + + +@register_fake("bitsandbytes::gemv_4bit") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + torch._check_is_size(blocksize) + torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}") + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + shape = (*A.shape[:-1], shapeB[0]) + return torch.empty(shape, device=A.device, dtype=A.dtype) + + +torch.library.define( + "bitsandbytes::gemv_4bit.out", + "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! 
out) -> ()", +) + + +@register_fake("bitsandbytes::gemv_4bit.out") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}") + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index f66cdf68d..9f14db754 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -49,6 +49,10 @@ def get_current_outlier_idx(self): return torch.Tensor(list(self.outliers)).to(torch.int64) +@deprecated( + "This function is deprecated and will be removed in a future release.", + category=FutureWarning, +) def get_inverse_transform_indices( transform_tile: Callable[[torch.Tensor], torch.Tensor], tile_size: Tuple[int, int], @@ -80,6 +84,10 @@ def get_inverse_transform_indices( return permuted_tile_indices +@deprecated( + "This function is deprecated and will be removed in a future release.", + category=FutureWarning, +) def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor: """ Undo a tiled permutation such as turing or ampere layout @@ -98,152 +106,9 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) - return outputs.reshape(rows, cols).contiguous() -@deprecated( - "MatMul8bit is deprecated and will be removed in a future release. 
Please use MatMul8bitLt instead.", - category=FutureWarning, -) -class MatMul8bit(torch.autograd.Function): - @staticmethod - def forward(ctx, A, B, out=None, quant_type="vector", precision=None): - if precision is None: - precision = [8, 8, 8] - if precision[0] != 8: - with torch.no_grad(): - output = torch.matmul(A, B) - else: - if len(B.shape) == 2: - dim = 0 - else: - dim = 1 - qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type) - qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type) - iout = F.igemm(qA, qB) - output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type) - - if A.requires_grad or B.requires_grad: - ctx.save_for_backward(A, B) - - ctx.quant_type = quant_type - ctx.precision = precision - - return output - - @staticmethod - def backward(ctx, grad_output): - A, B = ctx.saved_tensors - quant_type = ctx.quant_type - precision = ctx.precision - grad_A = grad_B = None - - if B.requires_grad: - if len(A.shape) == 3: - dims = [0, 1] - # bsi -> ibs - permute_dim = [0, 2, 1] - else: - dims = [0] - # bs -> sb - permute_dim = [1, 0] - - if precision[1] != 8: - with torch.no_grad(): - grad_B = torch.matmul(A.permute(permute_dim), grad_output) - else: - if len(B.shape) == 2 and len(A.shape) == 3: - grad_output = grad_output.contiguous() - if not grad_output.is_contiguous(): - grad_output.contiguous() - qgrad_output, S1 = F.vectorwise_quant( - grad_output.view(-1, grad_output.shape[2]), - dim=0, - quant_type=quant_type, - ) - if not A.is_contiguous(): - A = A.contiguous() - qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type) - igrad_B = F.igemm(qA.t(), qgrad_output) - grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type) - else: - qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) - qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) - igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) - grad_B = F.vectorwise_mm_dequant( - igrad_B, - S2.permute(permute_dim), - S1, - grad_output.dtype, - quant_type, - ) - - if A.requires_grad: - if len(grad_output.shape) == 3: - dims = [2] - else: - dims = [1] - - if len(B.shape) == 3: - # bio -> boi - permute_dim = [0, 2, 1] - dim_B = dims - else: - # io -> oi - permute_dim = [1, 0] - dim_B = [1] - - if precision[2] != 8: - with torch.no_grad(): - grad_A = torch.matmul(grad_output, B.permute(permute_dim)) - else: - qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) - qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) - igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) - grad_A = F.vectorwise_mm_dequant( - igrad_A, - S1, - S3.permute(permute_dim), - grad_output.dtype, - quant_type, - ) - - return grad_A, grad_B, None, None, None - - -mm_cublas = MatMul8bit.apply -bmm_cublas = MatMul8bit.apply -matmul_cublas = MatMul8bit.apply - - -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def supports_igemmlt(device: torch.device) -> bool: - """check if this device supports the optimized int8 kernel""" - if torch.cuda.get_device_capability(device=device) < (7, 5): - return False - device_name = torch.cuda.get_device_name(device=device) - nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series - if any(model_name in device_name for model_name in nvidia16_models): - return False # these devices are technically cuda 7.5-capable, but they lack tensor cores - return True - 
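# --- Illustrative sketch (editor's note, not part of the patch) ---
# The deprecated `supports_igemmlt` helper deleted above checked the CUDA compute
# capability before choosing the optimized int8 kernel. Callers that still need an
# equivalent guard can approximate the removed logic with plain torch APIs; the
# helper name below is hypothetical and assumes a CUDA-enabled PyTorch build.
import torch

def device_supports_int8_tensor_cores(device: torch.device) -> bool:
    # Optimized int8 kernels need compute capability >= 7.5 ...
    if torch.cuda.get_device_capability(device=device) < (7, 5):
        return False
    # ... but GTX 16-series GPUs report 7.5 while lacking tensor cores.
    name = torch.cuda.get_device_name(device=device)
    return not any(model in name for model in ("GTX 1630", "GTX 1650", "GTX 1660"))
# --- end sketch ---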
- -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def _get_tile_size(format): - assert format in ( - "col_turing", - "col_ampere", - ), f"please find this assert and manually enter tile size for {format}" - return (8, 32) if format == "col_turing" else (32, 32) - - -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def get_tile_inds(format, device): - transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device) - with torch.no_grad(): - return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device) - - @dataclass class MatmulLtState: - _tile_indices: Optional[torch.Tensor] = None + _tile_indices: Optional[torch.Tensor] = None # TODO: remove force_no_igemmlt: bool = False @@ -279,9 +144,7 @@ def reset_grads(self): @property def tile_indices(self): - if self._tile_indices is None: - self._tile_indices = get_tile_inds(self.formatB, self.CxB.device) - return self._tile_indices + raise ValueError("tile_indices is no longer supported.") class MatMul8bitLt(torch.autograd.Function): @@ -360,20 +223,12 @@ def forward( # we want to divide by 127. It's however more performant to multiply # by the reciprocal. outliers = state.CB[:, state.idx] - state.subB = (outliers.t() * state.SCB * 7.874015718698502e-3).to(A.dtype) + state.subB = F.int8_vectorwise_dequant(outliers, state.SCB).to(A.dtype).t() else: subA = None - # 3. Int8 Matmul - out32 = F.int8_linear_matmul(CA, state.CB) - - # Dequantize matmul result - if bias is None or bias.dtype == torch.float16: - # we apply the fused bias here - output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype) - else: # apply bias separately - # TODO: Fused bias for fp32/bf16? - output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias) + # 3. Int8 Matmul + Dequant + Bias + output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype) # 4. 
Mixed-precision decomposition matmul if subA is not None and state.subB is not None: @@ -423,8 +278,14 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor if req_gradB: Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16)) - gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t()) - grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt) + grad_B = torch.ops.bitsandbytes.int8_scaled_mm.default( + Cgrad.t().contiguous(), + CAt.t(), + SCgradt, + SCAt, + dtype=torch.float16, + ) + if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/cpu/__init__.py b/bitsandbytes/backends/cpu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py new file mode 100644 index 000000000..07e3cafd4 --- /dev/null +++ b/bitsandbytes/backends/cpu/ops.py @@ -0,0 +1,94 @@ +import ctypes as ct +from typing import Optional, Tuple + +import torch + +from bitsandbytes.functional import get_ptr + +from ..._ops import register_kernel +from ...cextension import lib + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cpu") +def _(A: torch.Tensor, B: torch.Tensor): + return _int8_linear_matmul_impl(A, B) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cpu") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + torch._check(out.dtype == torch.int32) + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None): + # Naive implementation: perform matmul in fp32 + result = torch.matmul(A.float(), B.float().t()).to(torch.int32) + if out is not None: + result = out.copy_(result) + return result + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cpu") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype=torch.float16, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + A_calc = A.view(-1, A.shape[-1]) + row_stats = row_stats.reshape(-1).unsqueeze(-1) + col_stats = col_stats.reshape(-1).unsqueeze(0) + + out = A_calc * (row_stats * col_stats) * 6.200124e-05 + if bias is not None: + out += bias + + return out.to(dtype) + + +@register_kernel("bitsandbytes::quantize_blockwise", "cpu") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(n), + ) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cpu") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + 
torch._check_is_size(blocksize) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}") + + out = torch.empty_like(A, dtype=dtype) + + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_longlong(blocksize), + ct.c_longlong(A.numel()), + ) + + return out diff --git a/bitsandbytes/backends/cuda/__init__.py b/bitsandbytes/backends/cuda/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py new file mode 100644 index 000000000..88bd872be --- /dev/null +++ b/bitsandbytes/backends/cuda/ops.py @@ -0,0 +1,520 @@ +import ctypes as ct +from math import prod +from typing import Optional, Sequence, Tuple + +import torch + +from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr + +from ..._ops import register_kernel +from ...cextension import lib + + +@register_kernel("bitsandbytes::int8_linear_matmul", "cuda") +def _(A: torch.Tensor, B: torch.Tensor): + out = torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.int32) + return _int8_linear_matmul_impl(A, B, out) + + +@register_kernel("bitsandbytes::int8_linear_matmul.out", "cuda") +def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + _int8_linear_matmul_impl(A, B, out) + + +def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + A, B = B, A + + shapeA = A.shape + shapeB = B.shape + + torch._check(A.dtype == torch.int8, lambda: "B must be int8") + torch._check(B.dtype == torch.int8, lambda: "A must be int8") + torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B") + torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A") + torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}") + torch._check(out.dtype == torch.int32) + + shapeC = (*shapeB[:-1], shapeA[0]) + torch._check(out.shape == shapeC, lambda: f"Output shape {out.shape} does not match expected shape {shapeC}") + + k, m = shapeA + n = prod(shapeB[:-1]) + lda = shapeA[-1] # Weights (outputs, inputs) + ldb = shapeB[-1] # Activations (batch, tokens, inputs) + ldc = shapeC[-1] # Output (batch, tokens, outputs) + + torch._check( + lda == ldb, + lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}", + ) + + # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. + # We'll fall back to a slower fp32 calculation in this circumstance. + # Fortunately, this should not be very common. + if lda % 4 != 0: + result = torch.matmul(B.float(), A.float().t()).to(torch.int32) + return out.copy_(result) + + with _cuda_device_of(A): + ctx = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = None + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + stream = _get_tensor_stream(A) + + has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) + + if has_error: + if has_error == 100: + # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` + # TODO: Warn and implement a fallback to fp32 compute? 
+ raise NotImplementedError("int8_linear_matmul not implemented!") + else: + raise RuntimeError( + f"cublasLt ran into an error!\n" + f"\t{shapeA=}, {shapeB=}, {shapeC=}\n" + f"\t{(lda, ldb, ldc)=}\n" + f"\t{(m, n, k)=}" + ) + + return out + + +@register_kernel("bitsandbytes::int8_mm_dequant", "cuda") +def _( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype=torch.float16, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + # Note: cuda kernel only currently supports fp16 output. + # We'll later cast to desired dtype if needed. + out = torch.empty_like(A, dtype=torch.float16) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + numRows = ct.c_int32(prod(A.shape[:-1])) + numCols = ct.c_int32(A.shape[-1]) + + # Note: fused bias in the kernel is only supported for fp16 + # TODO(matthewdouglas): Consider supporting bf16 fused bias + ptrBias = get_ptr(bias) if bias is not None and bias.dtype == torch.float16 else None + + with _cuda_device_of(A): + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) + ) + + # Add bias separately if not fused in kernel + if bias is not None and bias.dtype != torch.float16: + out.add_(bias) + + return out.to(dtype) + + +@register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda") +def _(A: torch.Tensor, threshold=0.0): + torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}") + torch._check(threshold >= 0.0, lambda: "threshold must be non-negative") + + rows = prod(A.shape[:-1]) + cols = A.shape[-1] + + row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) + out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) + + outlier_cols = None + + if threshold > 0.0: + # TODO we could improve perf of this + outliers = A.abs() >= threshold + + if outliers.any(): + outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) + + with _cuda_device_of(A): + lib.cint8_vector_quant( + get_ptr(A), + get_ptr(out_row), + get_ptr(row_stats), + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + _get_tensor_stream(A), + ) + + # Zero out values from outlier columns across all rows. + # The kernel will handle this for outliers themselves, so we can optimize for rows=1. 
+ if rows > 1 and outlier_cols is not None: + out_row[:, outlier_cols] = 0 + + return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::int8_double_quant", "cuda") +def _( + A: torch.Tensor, + threshold=0.0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant.default( + A, + threshold=threshold, + ) + + # PyTorch impl for colwise + col_stats, outlier_mask = _get_col_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(127.0) / col_stats.unsqueeze(0)).to(torch.int8) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + + +def _get_col_absmax( + A: torch.Tensor, + threshold=0.0, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + torch._check(A.is_floating_point()) + + outlier_mask = None + + absA = A.abs().view(-1, A.shape[-1]) + + if threshold > 0.0: + # Filter outliers from stats when enabled + outlier_mask = absA >= threshold + absA.masked_fill_(outlier_mask, 0.0) + + # shape [cols]; unsqueeze(0) gives [1,cols] + col_stats = absA.amax(dim=0, keepdim=False).float() + + return col_stats, outlier_mask + + +@register_kernel("bitsandbytes::quantize_blockwise", "cuda") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty_like(A, dtype=torch.uint8) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(A.numel()), + ) + + if A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(*args) + elif A.dtype == torch.bfloat16: + lib.cquantize_blockwise_bf16(*args) + elif A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "cuda") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + out = torch.empty_like(A, dtype=dtype) + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_blockwise.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") + _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out) + + +def _dequantize_blockwise_impl( + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + torch._check( + dtype in [torch.float16, torch.bfloat16, torch.float32], 
+ lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + _get_tensor_stream(A), + ) + + if dtype == torch.float16: + lib.cdequantize_blockwise_fp16(*args) + elif dtype == torch.bfloat16: + lib.cdequantize_blockwise_bf16(*args) + elif dtype == torch.float32: + lib.cdequantize_blockwise_fp32(*args) + + +@register_kernel("bitsandbytes::quantize_4bit", "cuda") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> Tuple[torch.Tensor, torch.Tensor]: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + + if A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4(*args) + else: + lib.cquantize_blockwise_bf16_nf4(*args) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4(*args) + else: + lib.cquantize_blockwise_fp16_nf4(*args) + elif A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4(*args) + else: + lib.cquantize_blockwise_fp32_nf4(*args) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_4bit", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + out = torch.empty(shape, dtype=dtype, device=A.device) + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + return out + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "cuda") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(quant_type in ["fp4", "nf4"]) + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + with _cuda_device_of(A): + args = ( + None, + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(out.numel()), + _get_tensor_stream(A), + ) + + if out.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4(*args) + else: + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + 
else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + + +@register_kernel("bitsandbytes::gemv_4bit", "cuda") +def _( + A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int +) -> torch.Tensor: + shape = (*A.shape[:-1], shapeB[0]) + out = torch.empty(shape, device=A.device, dtype=A.dtype) + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + return out + + +@register_kernel("bitsandbytes::gemv_4bit.out", "cuda") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check( + out.shape == (*A.shape[:-1], shapeB[0]), + lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}", + ) + torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}") + _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out) + + +def _gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + out: torch.Tensor, +) -> None: + torch._check_is_size(blocksize) + torch._check( + A.numel() == A.size(-1), + lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}", + ) + torch._check( + A.dtype in [torch.float16, torch.bfloat16, torch.float32], + lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}", + ) + torch._check( + B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32], + lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}", + ) + torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}") + torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}") + + m = ct.c_int32(shapeB[0]) + n = ct.c_int32(1) + k = ct.c_int32(shapeB[1]) + + lda = m + ldb = ct.c_int32((A.shape[-1] + 1) // 2) + ldc = m + + stream = _get_tensor_stream(A) + + with _cuda_device_of(A): + if A.dtype == torch.float16: + lib.cgemm_4bit_inference_naive_fp16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.bfloat16: + lib.cgemm_4bit_inference_naive_bf16( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) + elif A.dtype == torch.float32: + lib.cgemm_4bit_inference_naive_fp32( + m, + n, + k, + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(code), + get_ptr(out), + lda, + ldb, + ldc, + ct.c_int32(blocksize), + stream, + ) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 5894e52dd..c0e139e03 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -182,13 +182,6 @@ def get_instance(cls): return cls._instance -dtype2bytes = {} -dtype2bytes[torch.float32] = 4 -dtype2bytes[torch.float16] = 2 -dtype2bytes[torch.bfloat16] = 2 -dtype2bytes[torch.uint8] = 1 -dtype2bytes[torch.int8] = 1 - FIRST_CUDA_DEVICE = torch.device("cuda", index=0) # When multiple GPUs are present, we use a context manager to @@ -207,7 +200,7 @@ def _cuda_device_of(a: torch.Tensor): def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): - 
num_bytes = dtype2bytes[dtype] * prod(shape) + num_bytes = dtype.itemsize * prod(shape) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) @@ -217,15 +210,14 @@ def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): return out -def prefetch_tensor(A, to_cpu=False): +def prefetch_tensor(A: torch.Tensor, to_cpu=False): assert A.is_paged, "Only paged tensors can be prefetched!" if to_cpu: deviceid = -1 else: deviceid = A.page_deviceid - num_bytes = dtype2bytes[A.dtype] * A.numel() - lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) + lib.cprefetch(get_ptr(A), ct.c_size_t(A.nbytes), ct.c_int32(deviceid)) def elementwise_func(func_name, A, B, value, prefetch=True): @@ -431,11 +423,6 @@ def create_quantile_map(A, total_bits=8): return q -@deprecated("This function is deprecated and will be removed in a future version.", category=FutureWarning) -def get_special_format_str(): - return "row" - - def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): """Verifies that the input tensors are all on the same device. @@ -472,11 +459,6 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): return on_gpu -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream: - return torch.cuda.current_stream(tensor.device) - - def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p: # We use the raw stream for performance reasons. return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index)) @@ -509,106 +491,6 @@ def post_call(prev_device): torch.cuda.set_device(prev_device) -@deprecated( - "The layout transformation operations will be removed in a future release. Please use row-major layout only.", - category=FutureWarning, -) -def get_transform_func(dtype, orderA, orderOut, transpose=False): - name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' - if not hasattr(lib, name): - print(name) - raise ValueError( - f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}", - ) - else: - return getattr(lib, name) - - -@deprecated( - "The layout transformation operations will be removed in a future release. 
Please use row-major layout only.", - category=FutureWarning, -) -def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False): - # init_func = torch.empty - init_func = torch.zeros - dims = len(shape) - - if dims == 2: - rows = shape[0] - elif dims == 3: - rows = shape[0] * shape[1] - cols = shape[-1] - - state = (shape, to_order) - if transpose: - # swap dims - tmp = rows - rows = cols - cols = tmp - state = (shape[::-1], to_order) - - if to_order == "row" or to_order == "col": - return init_func(shape, dtype=dtype, device=device), state - elif to_order == "col32": - # blocks of 32 columns (padded) - cols = 32 * ((cols + 31) // 32) - return init_func((rows, cols), dtype=dtype, device=device), state - elif to_order == "col_turing": - # blocks of 32 columns and 8 rows - cols = 32 * ((cols + 31) // 32) - rows = 8 * ((rows + 7) // 8) - return init_func((rows, cols), dtype=dtype, device=device), state - elif to_order == "col_ampere": - # blocks of 32 columns and 32 rows - cols = 32 * ((cols + 31) // 32) - rows = 32 * ((rows + 31) // 32) - return init_func((rows, cols), dtype=dtype, device=device), state - else: - raise NotImplementedError(f"To_order not supported: {to_order}") - - -@deprecated( - "The layout transformation operations will be removed in a future release. Please use row-major layout only.", - category=FutureWarning, -) -def nvidia_transform( - A, - to_order, - from_order="row", - out=None, - transpose=False, - state=None, - ld=None, -): - if state is None: - state = (A.shape, from_order) - else: - from_order = state[1] - if out is None: - out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1]) - else: - new_state = (state[1], to_order) - func = get_transform_func(A.dtype, from_order, to_order, transpose) - - shape = state[0] - if len(shape) == 2: - dim1 = ct.c_int32(shape[0]) - dim2 = ct.c_int32(shape[1]) - elif ld is not None: - n = prod(shape) - dim1 = prod([shape[i] for i in ld]) - dim2 = ct.c_int32(n // dim1) - dim1 = ct.c_int32(dim1) - else: - dim1 = ct.c_int32(shape[0] * shape[1]) - dim2 = ct.c_int32(shape[2]) - - ptr = CUBLAS_Context.get_instance().get_context(A.device) - func(ptr, get_ptr(A), get_ptr(out), dim1, dim2) - - return out, new_state - - def estimate_quantiles( A: Tensor, out: Optional[torch.Tensor] = None, @@ -892,56 +774,16 @@ def quantize_blockwise( name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - if absmax is None: - n = A.numel() - blocks = -(n // -blocksize) - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - - if out is None: - out = torch.zeros_like(A, dtype=torch.uint8) - - if A.device.type != "cpu": - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - - code = code.to(A.device) - - is_on_gpu([A, out, absmax]) - - with _cuda_device_of(A): - args = ( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(A.numel()), - ) - - if A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(*args) - elif A.dtype == torch.bfloat16: - lib.cquantize_blockwise_bf16(*args) - elif A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - - else: - # cpu - code = code.cpu() - lib.cquantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(blocksize), - ct.c_longlong(A.numel()), - ) + _out, _absmax = 
torch.ops.bitsandbytes.quantize_blockwise.default( + A, + code.to(A.device), + blocksize, + ) if nested: - offset = absmax.mean() - absmax -= offset - qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) + offset = _absmax.mean() + _absmax -= offset + qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False) quant_state = QuantState( absmax=qabsmax, code=code, @@ -951,7 +793,14 @@ def quantize_blockwise( state2=state2, ) else: - quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=A.dtype) + quant_state = QuantState(absmax=_absmax, code=code, blocksize=blocksize, dtype=A.dtype) + + # TODO(matthewdouglas): Deprecate out kwarg + out = out.copy_(_out) if out is not None else _out + + # TODO(matthewdouglas): Deprecate absmax kwarg + if absmax is not None: + quant_state.absmax = absmax.copy_(quant_state.absmax) return out, quant_state @@ -1013,49 +862,24 @@ def dequantize_blockwise( if absmax.dtype != torch.float32: absmax = absmax.float() - if out is None: - out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) - - if A.device.type != "cpu": - code = quant_state.code.to(A.device) - if quant_state.blocksize not in [4096, 2048, 1024, 512, 256, 128, 64]: - raise ValueError( - f"The blocksize of {quant_state.blocksize} is not supported. Supported values: [4096, 2048, 1024, 512, 256, 128, 64]", - ) - - is_on_gpu([A, absmax, out]) - - with _cuda_device_of(A): - args = ( - get_ptr(quant_state.code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(A.numel()), - _get_tensor_stream(A), - ) - - if out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(*args) - elif out.dtype == torch.bfloat16: - lib.cdequantize_blockwise_bf16(*args) - elif out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") - else: - code = quant_state.code.cpu() - lib.cdequantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_longlong(quant_state.blocksize), - ct.c_longlong(A.numel()), + if out is not None: + torch.ops.bitsandbytes.dequantize_blockwise.out( + A, + absmax, + code.to(A.device), + blocksize, + quant_state.dtype, + out=out, ) + return out - return out + return torch.ops.bitsandbytes.dequantize_blockwise.default( + A, + absmax, + quant_state.code.to(A.device), + quant_state.blocksize, + quant_state.dtype, + ) def get_4bit_type(typename, device=None, blocksize=64): @@ -1194,62 +1018,21 @@ def quantize_4bit( - `torch.Tensor`: The quantized tensor with packed 4-bit values. - [`QuantState`]: The state object used to undo the quantization. 
""" - - if A.device.type != "cuda": - raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") - if quant_type not in ["fp4", "nf4"]: - raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") - - n = A.numel() input_shape = A.shape - if absmax is None: - blocks = -(n // -blocksize) - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - - if out is None: - mod = dtype2bytes[quant_storage] * 2 - out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) - - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - - is_on_gpu([A, out, absmax]) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int32(blocksize), - ct.c_int(n), - ) - - if A.dtype == torch.bfloat16: - if quant_type == "fp4": - lib.cquantize_blockwise_bf16_fp4(*args) - else: - lib.cquantize_blockwise_bf16_nf4(*args) - elif A.dtype == torch.float16: - if quant_type == "fp4": - lib.cquantize_blockwise_fp16_fp4(*args) - else: - lib.cquantize_blockwise_fp16_nf4(*args) - elif A.dtype == torch.float32: - if quant_type == "fp4": - lib.cquantize_blockwise_fp32_fp4(*args) - else: - lib.cquantize_blockwise_fp32_nf4(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default( + A, + blocksize, + quant_type, + quant_storage, + ) code = get_4bit_type(quant_type, device=A.device) if compress_statistics: - offset = absmax.mean() - absmax -= offset - qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) - del absmax + offset = _absmax.mean() + qabsmax, state2 = quantize_blockwise(_absmax - offset, blocksize=256) + del _absmax state = QuantState( absmax=qabsmax, shape=input_shape, @@ -1262,7 +1045,7 @@ def quantize_4bit( ) else: state = QuantState( - absmax=absmax, + absmax=_absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, @@ -1270,6 +1053,13 @@ def quantize_4bit( quant_type=quant_type, ) + # TODO(matthewdouglas): Deprecate out kwarg + out = out.copy_(_out) if out is not None else _out + + # TODO(matthewdouglas): Deprecate absmax kwarg + if absmax is not None: + state.absmax = absmax.copy_(state.absmax) + return out, state @@ -1327,14 +1117,6 @@ def dequantize_4bit( Returns: `torch.Tensor`: The dequantized tensor. """ - - if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError( - f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", - ) - if quant_type not in ["fp4", "nf4"]: - raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") - if quant_state is None: assert absmax is not None and out is not None @@ -1355,42 +1137,19 @@ def dequantize_4bit( if absmax.dtype != torch.float32: absmax = absmax.float() - if out is None: - out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) - - n = out.numel() - - is_on_gpu([A, absmax, out]) - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - args = ( - None, - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, + if out is not None: + torch.ops.bitsandbytes.dequantize_4bit.out( + A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out + ) + else: + out = torch.ops.bitsandbytes.dequantize_4bit.default( + A, + absmax, + quant_state.blocksize, + quant_state.quant_type, + quant_state.shape, + quant_state.dtype, ) - - if out.dtype == torch.bfloat16: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4(*args) - else: - lib.cdequantize_blockwise_bf16_nf4(*args) - elif out.dtype == torch.float16: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_fp16_fp4(*args) - else: - lib.cdequantize_blockwise_fp16_nf4(*args) - elif out.dtype == torch.float32: - if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_fp32_fp4(*args) - else: - lib.cdequantize_blockwise_fp32_nf4(*args) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") if A.shape[0] == 1: # is transposed, transpose back return out.t() @@ -1849,6 +1608,7 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: return current_gnorm, clip_value, gnorm_scale +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): assert len(histogram.shape) == 2 assert histogram.dtype == torch.float32 @@ -1959,100 +1719,33 @@ def gemv_4bit( transposed_B=False, state=None, ): - # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) if state is None: raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()") - if A.numel() != A.shape[-1]: - raise ValueError( - 'Dimensions of A are invalid. Must be a vector with the leading dimensions of "1", e.g. 
[1, 1, 2048]', - ) - - Bshape = state.shape - bout = Bshape[0] absmax = state.absmax if state.nested: - absmax = dequantize_blockwise(state.absmax, state.state2) - absmax += state.offset - - if out is None: - if len(A.shape) == 3: - out = torch.empty(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device) - else: - out = torch.empty(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) - - n = 1 - m = Bshape[0] - k = Bshape[1] - lda = Bshape[0] - ldc = Bshape[0] - ldb = (A.shape[-1] + 1) // 2 - is_on_gpu([B, A, out, absmax, state.code]) - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - with _cuda_device_of(A): - if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: - if A.dtype == torch.float16: - lib.cgemm_4bit_inference_naive_fp16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - stream, - ) - elif A.dtype == torch.bfloat16: - lib.cgemm_4bit_inference_naive_bf16( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - stream, - ) - elif A.dtype == torch.float32: - lib.cgemm_4bit_inference_naive_fp32( - m, - n, - k, - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(state.code), - get_ptr(out), - lda, - ldb, - ldc, - ct.c_int32(state.blocksize), - stream, - ) - else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + absmax = dequantize_blockwise(absmax, state.state2) + state.offset - else: - raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") + if out is not None: + torch.ops.bitsandbytes.gemv_4bit.out( + A, + B, + state.shape, + absmax, + state.code, + state.blocksize, + out=out, + ) + return out - return out + return torch.ops.bitsandbytes.gemv_4bit.default( + A, + B, + state.shape, + absmax, + state.code, + state.blocksize, + ) def igemm( @@ -2252,27 +1945,6 @@ def batched_igemm( return out -@deprecated( - "igemmlt is deprecated and will be removed in a future release. Please use int8_linear_matmul instead.", - category=FutureWarning, -) -def igemmlt( - A: torch.Tensor, - B: torch.Tensor, - SA: Tuple[torch.Size, str], - SB: Tuple[torch.Size, str], - out: Optional[torch.Tensor] = None, - Sout: Optional[Tuple[torch.Size, str]] = None, - dtype=torch.int32, -): - if SA is not None and SA[1] != "row": - raise NotImplementedError(f"Only row-major format inputs are supported, but got format `{SA[1]}`") - if SB is not None and SB[1] != "row": - raise NotImplementedError(f"Only row-major format is supported for matrix B, but got format `{SB[1]}`") - result = int8_linear_matmul(A, B, out=out, dtype=dtype) - return result, (result.shape, "row") - - def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32): """Performs an 8-bit integer matrix multiplication. @@ -2292,88 +1964,11 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten Returns: `torch.Tensor`: The result of the operation. """ + if out is not None: + torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out) + return out - # - # To use the IMMA tensor core kernels without special Turing/Ampere layouts, - # cublasLt has some rules, namely: A must be transposed, B must not be transposed. 
- # The C++ API will calculate `C = A.T @ B` in with A, B, C in col-major. - # This will typically be used with row-major tensors to efficiently - # calculate the linear layer with `C = B @ A.T` without any transformations. - # We will swap A and B in the API invocation, so that we get `C = A @ B.T`. - # - # Quick explanation: - # With row-major A and B tensors, `C = A.T.T @ B.T = A @ B.T`. - # To get row-major output, `C.T = (A @ B.T).T = B @ A.T`. - # - A, B = B, A - - shapeA = A.shape - shapeB = B.shape - - assert A.dtype == torch.int8 - assert B.dtype == torch.int8 - assert A.ndim == 2, "Only two dimensional matrices are supported for argument B" - assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A" - assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}" - assert out is None or out.dtype == dtype - - shapeC = (*shapeB[:-1], shapeA[0]) - - k, m = shapeA - n = prod(shapeB[:-1]) - lda = shapeA[-1] # Weights (outputs, inputs) - ldb = shapeB[-1] # Activations (batch, tokens, inputs) - ldc = shapeC[-1] # Output (batch, tokens, outputs) - - assert ( - lda == ldb - ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}" - - # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4. - # We'll fall back to a slower fp32 calculation in this circumstance. - # Fortunately, this should not be very common. - if lda % 4 != 0: - result = torch.matmul(B.float(), A.float().t()).to(torch.int32) - if out is not None: - result = out.copy_(result) - return result - - if out is None: - out = torch.empty(shapeC, device=A.device, dtype=dtype) - - is_on_gpu([A, B, out]) - - with _cuda_device_of(A): - ctx = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - ptrRowScale = None - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - lda = ct.c_int32(lda) - ldb = ct.c_int32(ldb) - ldc = ct.c_int32(ldc) - stream = _get_tensor_stream(A) - - if dtype == torch.int32: - has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - else: - has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream) - - if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` - raise NotImplementedError("int8_linear_matmul not implemented!") - - if has_error: - raise RuntimeError( - f"cublasLt ran into an error!\n" - f"\t{shapeA=}, {shapeB=}, {shapeC=}\n" - f"\t{(lda, ldb, ldc)=}\n" - f"\t{(m, n, k)=}" - ) - - return out + return torch.ops.bitsandbytes.int8_linear_matmul.default(A, B) def int8_mm_dequant( @@ -2395,47 +1990,16 @@ def int8_mm_dequant( Returns: `torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`. 
""" + result = torch.ops.bitsandbytes.int8_mm_dequant.default(A, row_stats, col_stats, dtype=torch.float16, bias=bias) - assert A.dtype == torch.int32 - - if bias is not None: - assert bias.dtype == torch.float16 - - if out is None: - out = torch.empty_like(A, dtype=torch.float16) - - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - ptrBias = get_ptr(bias) - numRows = ct.c_int32(prod(A.shape[:-1])) - numCols = ct.c_int32(A.shape[-1]) - - is_on_gpu([A, row_stats, col_stats, out, bias]) - - with _cuda_device_of(A): - lib.cdequant_mm_int32_fp16( - ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) - ) + # TODO(matthewdouglas): Deprecate out kwarg + if out is not None: + return out.copy_(result) - return out - - -@deprecated("mm_dequant is deprecated. Please use int8_mm_dequant() instead.", category=FutureWarning) -def mm_dequant( - A: torch.Tensor, - quant_state: Optional[Tuple[torch.Size, str]], # Not used - row_stats: torch.Tensor, - col_stats: torch.Tensor, - out: Optional[torch.Tensor] = None, - new_row_stats=None, # Not used - new_col_stats=None, # Not used - bias: Optional[torch.Tensor] = None, -): - return int8_mm_dequant(A, row_stats, col_stats, out, bias) + return result +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def get_colrow_absmax( A: torch.Tensor, row_stats: Optional[torch.Tensor] = None, @@ -2493,6 +2057,7 @@ def get_colrow_absmax( return row_stats, col_stats, outlier_mask +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def get_row_absmax(A: torch.Tensor, threshold=0.0): """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm. @@ -2611,72 +2176,6 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -@deprecated("This function is deprecated. Please use `int8_double_quant` instead.", category=FutureWarning) -def double_quant( - A: torch.Tensor, - col_stats: Optional[torch.Tensor] = None, - row_stats: Optional[torch.Tensor] = None, - out_col: Optional[torch.Tensor] = None, - out_row: Optional[torch.Tensor] = None, - threshold=0.0, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[COOSparseTensor]]: - """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm. - - The statistics are determined both row-wise and column-wise (transposed). - - For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339). - - - This function exists for backwards compatibility only. It is advised to use [`int8_double_quant`] instead. - The difference is that this function will return a [`COOSparseTensor`] for outliers instead of a column index. - - - Args: - A (`torch.Tensor` with dtype `torch.float16`): The input matrix. - col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales. - row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales. - out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data. - out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data. - threshold (`float`, *optional*): - An optional threshold for sparse decomposition of outlier features. 
- - No outliers are held back when 0.0. Defaults to 0.0. - - Returns: - `Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics. - - `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data. - - `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data. - - `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales. - - `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales. - - `COOSparseTensor`, *optional*: A structure representing the outlier values from the input tensor. - """ - - coo_tensor = None - quant_row, quant_col, row_stats, col_stats, outlier_cols = int8_double_quant( - A, - col_stats, - row_stats, - out_col, - out_row, - threshold=threshold, - ) - - if threshold > 0.0 and outlier_cols is not None: - # Build a COO tensor including all of the outlier columns. - outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32) - outliers = A[:, outlier_cols] - coo_tensor = COOSparseTensor( - A.shape[0], - A.shape[1], - outliers.numel(), - outlier_rows.repeat_interleave(outliers.size(1)), - outlier_cols.repeat(outliers.size(0)).int(), - outliers, - ) - - return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor - - def int8_double_quant( A: torch.Tensor, col_stats: Optional[torch.Tensor] = None, @@ -2716,23 +2215,16 @@ def int8_double_quant( - `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features. """ - # TODO: Optimize/write CUDA kernel for this? - - # Use CUDA kernel for rowwise and COO tensor - quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold) - - # PyTorch impl for colwise - _, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold) - if threshold > 0.0 and outlier_mask is not None: - A = A.masked_fill(outlier_mask, 0.0) - quant_col = torch.round(A.mul(C) / col_stats.unsqueeze(0)).to(torch.int8) - - if out_row is not None: - quant_row = out_row.copy_(quant_row) + if row_stats is not None: + raise ValueError("row_stats must be None. int8_double_quant() does not support pre-allocated row_stats.") + if col_stats is not None: + raise ValueError("col_stats must be None. int8_double_quant() does not support pre-allocated col_stats.") if out_col is not None: - quant_col = out_col.copy_(quant_col) + raise ValueError("out_col must be None. int8_double_quant() does not support pre-allocated out_col.") + if out_row is not None: + raise ValueError("out_row must be None. int8_double_quant() does not support pre-allocated out_row.") - return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols + return torch.ops.bitsandbytes.int8_double_quant.default(A, threshold=threshold) def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor): @@ -2746,7 +2238,7 @@ def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor): `torch.Tensor` with dtype `torch.float32`: The dequantized tensor. """ # To dequantize we divide by 127, or multiply by the reciprocal. - return A * stats.view(-1, 1) * 7.874015718698502e-3 + return torch.ops.bitsandbytes.int8_vectorwise_dequant.default(A, stats) def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): @@ -2767,94 +2259,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): - `torch.Tensor` with dtype `torch.float32`: The quantization scales. 
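A short usage sketch of the simplified API (illustrative shapes, CUDA device assumed); pre-allocated output buffers now raise instead of being filled, and the row-wise data can be recovered approximately with the vectorwise dequant helper.

import torch

import bitsandbytes.functional as F

A = torch.randn(16, 64, dtype=torch.float16, device="cuda")

q_row, q_col, row_stats, col_stats, outlier_cols = F.int8_double_quant(A, threshold=6.0)

# Recover the row-wise values (float32, equal to A up to quantization error).
A_approx = F.int8_vectorwise_dequant(q_row, row_stats)
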
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features. """ - - assert A.dtype == torch.half - is_on_gpu([A]) - - rows = prod(A.shape[:-1]) - cols = A.shape[-1] - - row_stats = torch.empty(rows, device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) - - outlier_cols = None - - if threshold > 0.0: - # TODO we could improve perf of this - outliers = A.abs() >= threshold - - if outliers.any(): - outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - - with _cuda_device_of(A): - lib.cint8_vector_quant( - get_ptr(A), - get_ptr(out_row), - get_ptr(row_stats), - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - _get_tensor_stream(A), - ) - - # Zero out values from outlier columns across all rows. - # The kernel will handle this for outliers themselves, so we can optimize for rows=1. - if rows > 1 and outlier_cols is not None: - out_row[:, outlier_cols] = 0 - - return out_row, row_stats, outlier_cols - - -@deprecated( - "The layout transformation operations will be removed in a future release. Please use row-major layout only.", - category=FutureWarning, -) -def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): - prev_device = pre_call(A.device) - if state is None: - state = (A.shape, from_order) - else: - from_order = state[1] - if out is None: - out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: - new_state = (state[0], to_order) # (shape, order) - - shape = state[0] - if len(shape) == 2: - dim1 = ct.c_int32(shape[0]) - dim2 = ct.c_int32(shape[1]) - else: - dim1 = ct.c_int32(shape[0] * shape[1]) - dim2 = ct.c_int32(shape[2]) - - is_on_gpu([A, out]) - if to_order == "col32": - if transpose: - lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "col_turing": - if transpose: - lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "col_ampere": - if transpose: - lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "row": - if from_order == "col_turing": - lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) - elif from_order == "col_ampere": - lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) - else: - raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}") - - post_call(prev_device) - - return out, new_state + return torch.ops.bitsandbytes.int8_vectorwise_quant.default(A, threshold) def spmm_coo( @@ -3059,7 +2464,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): @deprecated( - "This function is deprecated and will be removed in a future release. Consider using `int8_vectorwise_dequant` instead.", + "This function is deprecated and will be removed in a future release.", category=FutureWarning, ) def vectorwise_dequant(xq, max1, quant_type="vector"): @@ -3071,7 +2476,7 @@ def vectorwise_dequant(xq, max1, quant_type="vector"): @deprecated( - "This function is deprecated and will be removed in a future release. 
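Taken together, a minimal sketch of the int8 path these wrappers now route through the registered torch ops. This is illustrative only (outlier decomposition with `threshold > 0` is omitted) and is not the exact code used by the nn modules.

import torch

import bitsandbytes.functional as F

x = torch.randn(8, 128, dtype=torch.float16, device="cuda")    # activations
W = torch.randn(256, 128, dtype=torch.float16, device="cuda")  # weights (out_features, in_features)
bias = torch.randn(256, dtype=torch.float16, device="cuda")

# Quantize both operands row-wise to int8; the weight's row scales act as the
# output-column scales of x @ W.T.
x_i8, x_stats, _ = F.int8_vectorwise_quant(x)
w_i8, w_stats, _ = F.int8_vectorwise_quant(W)

acc_i32 = F.int8_linear_matmul(x_i8, w_i8)                   # int32, shape (8, 256)
y = F.int8_mm_dequant(acc_i32, x_stats, w_stats, bias=bias)  # float16, approx. x @ W.T + bias
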
Consider using `int8_mm_dequant` instead.", + "This function is deprecated and will be removed in a future release.", category=FutureWarning, ) def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): @@ -3131,51 +2536,3 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): return x.to(dtype) else: return None - - -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): - offset = B.float().t().sum(0) * (SA[0] + SA[1]) - x = xq.float() - if len(xq.shape) == 2 and len(SB.shape) == 3: - SB = SB.squeeze(0) - if len(SB.shape) == 2: - x *= SB.t() / 127 - else: - x *= SB / 127 - x *= SA[1] / 127 - x += offset - return x.to(dtype) - - -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def extract_outliers(A, SA, idx): - shapeA = SA[0] - formatA = SA[1] - assert formatA in ["col_turing", "col_ampere"] - assert A.device.type == "cuda" - - out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) - - idx_size = ct.c_int32(idx.numel()) - rows = ct.c_int32(shapeA[0]) - cols = ct.c_int32(shapeA[1]) - ptrA = get_ptr(A) - ptrIdx = get_ptr(idx) - ptrOut = get_ptr(out) - - prev_device = pre_call(A.device) - if formatA == "col_turing": - lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - elif formatA == "col_ampere": - lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - post_call(prev_device) - - return out - - -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def pipeline_test(A, batch_size): - out = torch.zeros_like(A) - lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) - return out diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index e63cd8db9..f4d838d48 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -11,7 +11,6 @@ import torch.nn.functional as F import bitsandbytes as bnb -from bitsandbytes.autograd._functions import get_tile_inds, undo_layout from bitsandbytes.functional import QuantState from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.utils import ( @@ -654,8 +653,7 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format] if weight_format != "row": - tile_indices = get_tile_inds(weight_format, weight.device) - state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices) + raise ValueError(f"Only 'row' weight format is supported, got {weight_format}") class Embedding8bit(nn.Embedding): diff --git a/csrc/kernels.cu b/csrc/kernels.cu index fdf1d02c0..22ee756d9 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2205,333 +2205,6 @@ __global__ void kdequant_mm_int32_fp16( } } -template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols) -{ - - // 0. Load data into 32*32 shared memory tiles - // 1. transpose / reorder in shared memory - // 2. 
store - - // COL32 FORMAT: - // rows*32 tiles - - // TURING FORMAT: - // 8*32 tiles with 4*4 subtiles - // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) - // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero - // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) - // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column - // index increases by 32 - - // AMPERE FORMAT: - // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: - // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... - // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] - - - // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values - // As such we need: - // at least 32*4 shared memory tiles for col32; preferably 32*32 - // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32 - // at least 32*8 shared memory tiles for col4_turing: preferably 32*32 - // for efficient loading of row major we need to load 128 elements and repeat this 32 items - // this would imply a 32x128 shared memory tile -> 4kb - // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb - // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy - // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough - // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM - // - // to make the shared memory work with that occupancy we might need to union the block loads/stores - - // each block loads TILE_COLs columns and TILE_ROW rows - // after reading a tile the row counter increase by TILE_ROWS - // the col counter reset after reading TILE_COL elements - const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; - // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached - const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; - const int base_idx = (base_row*cols) + base_col; - - // we load 128 bytes per warp with - // 32 rows for transposes that fill col32 types - // so that we can have contiguous stores - __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; - char local_data[ITEMS_PER_THREAD]; - typedef cub::BlockExchange BlockExchange; - - // we load row after row from the base_position - // Load data row by row - int warps = blockDim.x/32; - int warp_id = threadIdx.x/32; - int warp_lane = threadIdx.x % 32; - int offset = 0; - - int smem_row = 0; - // each warp loads one row of 128 bytes - for(int row = warp_id; row < TILE_ROWS; row+=warps) - { - int i = base_idx + (row*cols); - // we load up to 128 bytes/items per load - int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col; - - // 0. 
Load data into 32*32 shared memory tiles - if(base_row + row < rows) - { - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - int col_idx = warp_lane+(j*32); - if(col_idx < valid_items) - local_data[j] = A[i+col_idx]; - else - local_data[j] = 0; - } - } - else - { - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - local_data[j] = 0; - } - - if(TRANSPOSE) - { - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - int local_col = (32*j)+warp_lane; - //int local_row = row; - // store as 256x32 - smem_data[(local_col*33) + row] = local_data[j]; - } - } - else - { - // treat smem as 32x256, that is 32 rows and 256 columns - #pragma unroll ITEMS_PER_THREAD - for(int j = 0; j < ITEMS_PER_THREAD; j++) - smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j]; - } - - - - smem_row += warps; - - // 1. transpose / reorder in shared memory - if(smem_row % 32 == 0) - { - smem_row = 0; - __syncthreads(); - - for(int subrow = warp_id; subrow < 32; subrow+=warps) - { - for(int j = 0; j < ITEMS_PER_THREAD; j++) - { - - switch(FORMAT) - { - case COL32: - if(TRANSPOSE) - { - // data lies in shared memory in the following way: - // row0 [col0 col1 ... col31] - // row1 [col0 col1 ... col31] - // ... - // - // As such we read consecutive entries with 256 threads (8rows x 32 columns) - // as j increase, the row increase by a factor of 8 - // We load 8 rows per subrow loop, and subrow increase by 8 per loop - // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8 - const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j - const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) - //const int local_row = warp_id; // each warp_id is one row - //const int block_row = base_col; // block offset for row - //const int local_col = warp_lane - //const int global_col = base_row; // block offset for col - if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) - { - // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem - char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; - - // each 32 columns we have new tile - // each tile has size outRows*32 and base_row is done in increments of 32 - offset = base_row*outRows; - out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data; - } - } - else - { - if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) - { - offset = (base_col/32)*(32*rows); - char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; - out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data; - } - } - break; - case COL_TURING: - // TURING FORMAT: - // 8*32 tiles with 4*4 subtiles - // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) - // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero - // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) - // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column - // index increases by 32 - // - // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] 
- if(TRANSPOSE) - { - const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j - const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) - //const int local_row = warp_id; // each warp_id is one row - //const int block_row = base_col; // block offset for row - //const int local_col = warp_lane - //const int global_col = base_row; // block offset for col - if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) - { - // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem - char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; - - // each 32 columns we have new tile - // each tile has size 8*32 = 256 elements offset - // for each row offset of 8 we increaes the tile first - // after all rows are exhausted, we increase the col - int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows - - // we increase by row_tile_column every 32 columns - // base_row increase in increments of 32 - //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements - //int col_offset = (base_row/32)*row_tile_column; - // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 - // 256*outRows/8*base_row/32 = outRows*base_row - int col_offset = outRows*base_row; - - offset = row_offset+col_offset; - - // since we process even number of rows with each j (8) and with each subrow (8j) we can determine - // odd or even rows with the warp_id (each warp processes one row) - // the col is warp_lane (max 32 columns per row) and the row warp_id - if(warp_id % 2 == 1) - // odd - offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2); - else - // even - offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2); - - out[offset] = data; - } - } - else - { - if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) - { - char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; - // set offset designates the tile offset among the 8*32 tiles - // we first increase rows and then columns. Since we load 128 columns at once - // we increase the offset by outRows*32 every 32 columns - // additionally, we increase the offset by 8*32=256 every 8 rows - offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile) - // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd - // each of these has 32 values in total for 32*4 = 128 as offset if odd - // every set of 4 columns increases the total offset by 16 - // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2 - // this happens every 8 rows anew (subrow % 8) - // one writes 4 columns at once that is (col % 4) for the particular index in the subtile - int subcol = warp_lane; - - // add local offset (4x4 sub-tile) - if(subrow % 2 == 1) - // odd - offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2); - else - // even - offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2); - - out[offset] = data; - } - } - break; - case COL_AMPERE: - // AMPERE FORMAT: - // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: - // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... 
- // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] - if(TRANSPOSE) - { - const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j - const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) - //const int local_row = warp_id; // each warp_id is one row - //const int block_row = base_col; // block offset for row - //const int local_col = warp_lane - //const int global_col = base_row; // block offset for col - if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) - { - // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem - char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; - - // each 32 columns we have new tile - // each tile has size 32*32 = 1024 elements offset - // for each row offset of 32 we increaes the tile first - // after all rows are exhausted, we increase the col - int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows - - // we increase by row_tile_column every 32 columns - // base_row increase in increments of 32 - //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements - //int col_offset = (base_row/32)*row_tile_column; - // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 - // 1024*outRows/32*base_row/32 = outRows*base_row - int col_offset = outRows*base_row; - - offset = row_offset+col_offset; - - - // same as in the non-transpose case (see below) - // the difference is that now rows = cols - // in this case warp_id = subrow - - // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... - // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc - // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row - // every 2 rows, the offset increases by two [0, 1, 8, 9...] - // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] - int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset - int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2); - - // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane - out[offset + (ampere_row*32) + warp_lane] = data; - } - } - else - { - if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) - { - char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; - - // set offset designates the tile offset among the 32*32 tiles - // we first increase rows and then columns. Since we load 128 columns at once - // we increase the offset by outRows*32 every 32 columns - // additionally, we increase the offset by 32*32=1024 every 32 rows - offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile) - - // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... - // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc - // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row - // every 2 rows, the offset increases by two [0, 1, 8, 9...] - // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] 
- int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2); - - // global offset + row with 32 cols each + 32 cols per j + col_idx - out[offset + (local_row*32) + warp_lane] = data; - } - } - break; - } - } - } - } - } -} - #define DENORM 1.0f/127.0f #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8*256 @@ -2679,69 +2352,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o } } -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA) -{ - int local_colidx = idx[blockIdx.x]; - - if(FORMAT==COL_TURING) - { - // TURING FORMAT: - // 8*32 tiles with 4*4 subtiles - // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) - // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero - // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) - // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column - // index increases by 32 - // columns are grouped in increments of 4, meaning that one has the following rows and columns - // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] - // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] - - // each thread reads 1 element = 1 row - for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) - { - int offset_per_col_tile = ((rowsA+7)/8)*32*8; - int tile_offset_rows = (row/8)*32*8; - int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; - int offset = 0; - int subtile_col_idx = local_colidx%32; - int subtile_row_idx = row % 8; - if(row % 2 == 1) - offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); - else - // even - offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); - - offset += tile_offset_rows + tile_offset_cols; - - char val = A[offset]; - - int out_idx = (row*idx_size) + blockIdx.x; - out[out_idx] = val; - } - } - else if(FORMAT == COL_AMPERE) - { - - for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) - { - // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element - // within each tile. 
- int offset_per_col_tile = ((rowsA+31)/32)*32*32; - int tile_offset_rows = (row/32)*32*32; - int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; - int subtile_col_idx = local_colidx%32; - int subtile_row_idx = row % 32; - // this magic is taken from the cublasLt doc (search for COL32) - int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; - offset += tile_offset_cols + tile_offset_rows; - - char val = A[offset]; - int out_idx = (row*idx_size) + blockIdx.x; - out[out_idx] = val; - } - } -} - #define WARPS 3 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) { @@ -3376,9 +2986,6 @@ template __global__ void kgemm_4bit_inference_naive(int M, int N, template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); - template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); @@ -3386,13 +2993,6 @@ template __global__ void kspmm_coo_very_sparse_naive(int *max template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int 
cols, int tiledCols, int outRows, int outCols); -template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); - template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n); template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 18017c4d2..a701481d3 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -121,8 +121,6 @@ template __global__ void kInt8Vector template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); - template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/ops.cu b/csrc/ops.cu index e6c2bb443..775984553 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -374,48 +374,6 @@ template int get_leading_dim(int dim1, int dim2) } } -template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) -{ - cublasLtOrder_t orderA = get_order(); - cublasLtOrder_t orderOut = get_order(); - int ldA = get_leading_dim(dim1, dim2); - int ldOut = get_leading_dim(dim1, dim2); - - cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL; - cublasLtMatrixTransformDesc_t A2Out_desc = NULL; - cublasOperation_t opTranspose = CUBLAS_OP_T; - float transformAlpha = 1.0f, transformBeta = 0.0f; - - - if(DTYPE == 8) - { - checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA)); - checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut)); - } - else if(DTYPE == 32) - { - checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA)); - checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut)); - } - else - { - printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); - } - - checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); - checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); - - checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F)); - - if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } - - checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0)); - - if (A_desc) 
checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); - if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); - if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); -} - template int igemmlt( cublasLtHandle_t ltHandle, int m, int n, int k, @@ -542,50 +500,6 @@ void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void transformRowToFormat(char * A, char *out, int rows, int cols) -{ - int threads = 256; - int items_per_thread = 8; - // we load 128 column values per warp - int tile_cols = 32*items_per_thread; - int tile_rows = 32; - int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); - int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); - int row_tiles = (tiledRows/tile_rows); - int col_tiles = (tiledCols/tile_cols); - row_tiles = row_tiles > 0 ? row_tiles : 1; - col_tiles = col_tiles > 0 ? col_tiles : 1; - int num_blocks = row_tiles * col_tiles; - - int outCols = fill_up_to_nearest_multiple(cols, 32); - int outRows = fill_up_to_nearest_multiple(rows, 32); - if(FORMAT == COL_TURING) - { - if(TRANSPOSE) - outRows = fill_up_to_nearest_multiple(cols, 8); - else - outRows = fill_up_to_nearest_multiple(rows, 8); - } - else if(FORMAT == COL_AMPERE) - { - if(TRANSPOSE) - outRows = fill_up_to_nearest_multiple(cols, 32); - else - outRows = fill_up_to_nearest_multiple(rows, 32); - } - else - { - if(TRANSPOSE) - { - outCols = fill_up_to_nearest_multiple(rows, 32); - outRows = cols; - } - } - - kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<>>(A, out, rows, cols, tiledCols, outRows, outCols); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} - void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { cusparseSpMatDescr_t descA; @@ -643,32 +557,6 @@ template void spmm_coo_very_sparse_naive(int *max_count, CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - -template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols) -{ - int threads = 256; - // we load 128 column values per warp - int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32); - int tiledRows = 0; - - int num_blocks = idx_size; - - if(FORMAT == COL_TURING) - { - tiledRows = fill_up_to_nearest_multiple(rows, 8); - } - else if(FORMAT == COL_AMPERE) - { - tiledRows = fill_up_to_nearest_multiple(rows, 32); - } - - kExtractOutliers<<>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} - - - - template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) { @@ -722,8 +610,6 @@ template void gemm_4bit_inference_naive(int m, int n, int k, float * //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); -template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); -template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive(int 
*max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); @@ -732,13 +618,6 @@ template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, cons template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); -template void transformRowToFormat(char * A, char *out, int rows, int cols); -template void transformRowToFormat(char * A, char *out, int rows, int cols); -template void transformRowToFormat(char * A, char *out, int rows, int cols); -template void transformRowToFormat(char * A, char *out, int rows, int cols); -template void transformRowToFormat(char * A, char *out, int rows, int cols); -template void transformRowToFormat(char * A, char *out, int rows, int cols); - template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); @@ -840,15 +719,6 @@ MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); -template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); - template int get_leading_dim(int dim1, int dim2); template int get_leading_dim(int dim1, int dim2); template int get_leading_dim(int dim1, int dim2); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 1170237e1..48a6a3c74 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -173,20 +173,15 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); -template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream); void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream); void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream); -template void transformRowToFormat(char * A, char *out, int rows, int cols); - void 
spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); -template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); - void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 0ced0394c..56bec82e8 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -149,32 +149,6 @@ void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); } void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); } - -#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ -void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ -{ \ - transform(ltHandle, A, out, dim1, dim2); \ -} \ - -MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8); -MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8); -MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8); -MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32); -MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8); -MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8); -MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8); -MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32); - -void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } -void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } -void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } -void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } -void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } -void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } - -void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } -void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } - int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, 
ldc, stream); } @@ -317,22 +291,6 @@ extern "C" int cigemmlt_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) { return igemmlt_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream); } - - #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ - void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ - { \ - transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \ - } \ - - MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8) - MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8) - MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8) - MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32) - MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8) - MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8) - MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) - MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) - void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half* bias, int numRows, int numCols, cudaStream_t stream) { dequant_mm_int32_fp16(A, rowStats, colStats, out, bias, numRows, numCols, stream); } void cget_row_stats(half *A, float *rowStats, float threshold, int rows, int cols, cudaStream_t stream) { @@ -342,24 +300,6 @@ extern "C" int8VectorQuant(A, out, rowStats, threshold, rows, cols, stream); } - void ctransform_row2col32(char * A, char *out, int rows, int cols) - { transform_row2col32(A, out, rows, cols); } - - void ctransform_row2col32T(char * A, char *out, int rows, int cols) - { transform_row2col32T(A, out, rows, cols); } - - void ctransform_row2turing(char * A, char *out, int rows, int cols) - { transform_row2turing(A, out, rows, cols); } - - void ctransform_row2turingT(char * A, char *out, int rows, int cols) - { transform_row2turingT(A, out, rows, cols); } - - void ctransform_row2ampere(char * A, char *out, int rows, int cols) - { transform_row2ampere(A, out, rows, cols); } - - void ctransform_row2ampereT(char * A, char *out, int rows, int cols) - { transform_row2ampereT(A, out, rows, cols); } - void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); } @@ -369,9 +309,6 @@ extern "C" void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } - void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); } - void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } - //void cgemm_host_fp32(int M, int N, int K, 
float * A, float* B, float * out, int lda, int ldb, int ldc) //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 1eea3247b..6fa8c3b29 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -174,7 +174,7 @@ export BNB_CUDA_VERSION=126 export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/home/YOUR_USERNAME/local/cuda-12.6 ``` -3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 11.7) and a different bitsandbytes library is loaded. +3. Now when you launch bitsandbytes with these environment variables, the PyTorch CUDA version is overridden by the new CUDA version (in this example, version 12.6) and a different bitsandbytes library is loaded. ## Multi-backend Support (Alpha Release)[[multi-backend]] diff --git a/pyproject.toml b/pyproject.toml index 6e5c6dde3..f4ae66a8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence" ] dependencies = [ - "torch>=2.0,<3", + "torch>=2.2,<3", "numpy>=1.17" ] diff --git a/tests/helpers.py b/tests/helpers.py index 02cb881a3..de11f4f66 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -36,6 +36,8 @@ def format_with_label(label: str, value: Any) -> str: formatted = "T" if value else "F" elif isinstance(value, (list, tuple)) and all(isinstance(v, bool) for v in value): formatted = "".join("T" if b else "F" for b in value) + elif isinstance(value, torch.dtype): + formatted = describe_dtype(value) else: formatted = str(value) return f"{label}={formatted}" diff --git a/tests/test_autograd.py b/tests/test_autograd.py index ae2529542..4b93ebcbe 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,12 +1,9 @@ -from typing import Tuple - import pytest import torch import bitsandbytes as bnb from tests.helpers import ( BOOLEAN_TRIPLES, - BOOLEAN_TUPLES, TRUE_FALSE, describe_dtype, get_test_dims, @@ -16,189 +13,6 @@ TRANSPOSE_VALS = [(False, True), (False, False)] -@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(32, 96, n=1), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) -@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) -@pytest.mark.parametrize( - "funcs", - [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)], - ids=["func=bmm", "func=matmul"], -) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) -@pytest.mark.parametrize("req_grad", BOOLEAN_TUPLES, ids=id_formatter("req_grad")) -@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) -@pytest.mark.deprecated -def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool], transpose: Tuple[bool, bool]): - if dim2 > 0: - dim2 = dim2 - (dim2 % 16) - dim3 = dim3 - (dim3 % 16) - dim4 = dim4 - (dim4 % 16) - for i in range(25): - # normal multiply - if funcs[0] in [torch.mm, torch.matmul]: - dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) - dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) - A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0]) - B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) - target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]) 
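A small illustration of the test-helper change above (hedged: the exact short dtype names depend on `describe_dtype` in `tests.helpers`).

import torch

from tests.helpers import format_with_label

# torch.dtype values now get the same compact formatting used for test IDs,
# instead of the verbose str(torch.float16) form.
print(format_with_label("dtype", torch.float16))     # e.g. "dtype=fp16", per describe_dtype
print(format_with_label("signed", True))             # "signed=T"
print(format_with_label("req_grad", (True, False)))  # "req_grad=TF"
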
- torch.nn.init.xavier_uniform_(B) - - if not transpose[0] and not transpose[1]: - out_torch = funcs[0](A, B) - out_bnb = funcs[1](A, B) - elif not transpose[0] and transpose[1]: - out_torch = funcs[0](A, B.t()) - out_bnb = funcs[1](A, B.t()) - elif transpose[0] and not transpose[1]: - out_torch = funcs[0](A.t(), B) - out_bnb = funcs[1](A.t(), B) - elif transpose[0] and transpose[1]: - out_torch = funcs[0](A.t(), B.t()) - out_bnb = funcs[1](A.t(), B.t()) - - n = out_bnb.numel() - idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx == 0).sum().item() < n * 0.0175 - idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) - assert (idx == 0).sum().item() < n * 0.001 - - if any(req_grad): - out_bnb.data.copy_(out_torch) - torch.cuda.synchronize() - loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() - loss_bnb.backward() - gradA1 = A.grad - gradB1 = B.grad - A.grad = None - B.grad = None - - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() - loss_torch.backward() - gradA2 = A.grad - gradB2 = B.grad - A.grad = None - B.grad = None - - if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) - if req_grad[1]: - n = gradB1.numel() - idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - assert (idx == 0).sum().item() < n * 0.1 - idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) - - # batched matrix multiply - if funcs[0] in [torch.bmm, torch.matmul]: - A = torch.randn( - size=(dim1, dim2, dim3), - device="cuda", - requires_grad=req_grad[0], - ) - B = torch.randn( - size=(dim1, dim3, dim4), - device="cuda", - requires_grad=req_grad[1], - ) - target = torch.randn( - size=(dim1, dim2, dim4), - device="cuda", - requires_grad=req_grad[1], - ) - torch.nn.init.xavier_uniform_(B) - - out_torch = funcs[0](A, B) - out_bnb = funcs[1](A, B) - - n = out_bnb.numel() - idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx == 0).sum().item() < n * 0.01 - torch.testing.assert_close(out_bnb, out_torch, atol=0.027, rtol=0.2) - - if any(req_grad): - out_bnb.data.copy_(out_torch) - torch.cuda.synchronize() - loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() - loss_bnb.backward() - gradA1 = A.grad - gradB1 = B.grad - A.grad = None - B.grad = None - - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() - loss_torch.backward() - gradA2 = A.grad - gradB2 = B.grad - A.grad = None - B.grad = None - - if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) - if req_grad[1]: - n = gradB1.numel() - idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - assert (idx == 0).sum().item() < n * 0.1 - idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx == 0).sum().item() < n * 0.02 - - if funcs[0] in [torch.matmul]: - dim1 = dim1 - (dim1 % 16) - A = torch.randn( - size=(dim1, dim2, dim3), - device="cuda", - requires_grad=req_grad[0], - ) - dimB = (dim4, dim3) if transpose[1] else (dim3, dim4) - B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) - target = torch.randn( - size=(dim1, dim2, dim4), - device="cuda", - requires_grad=req_grad[1], - ) - torch.nn.init.xavier_uniform_(B) - - if transpose[1]: - out_torch = funcs[0](A, B.t()) - out_bnb = funcs[1](A, B.t()) - else: - out_torch = funcs[0](A, B) - out_bnb = funcs[1](A, B) - - n = out_bnb.numel() - idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - 
assert (idx == 0).sum().item() < n * 0.0175 - idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) - assert (idx == 0).sum().item() < n * 0.001 - - if any(req_grad): - out_bnb.data.copy_(out_torch) - torch.cuda.synchronize() - loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() - loss_bnb.backward() - gradA1 = A.grad - gradB1 = B.grad - A.grad = None - B.grad = None - - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() - loss_torch.backward() - gradA2 = A.grad - gradB2 = B.grad - A.grad = None - B.grad = None - - if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) - if req_grad[1]: - n = gradB1.numel() - idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - assert (idx == 0).sum().item() < n * 0.1 - idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx == 0).sum().item() < n * 0.02 - - @pytest.mark.parametrize("dim1", [40], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py new file mode 100644 index 000000000..9872cdfca --- /dev/null +++ b/tests/test_deprecated.py @@ -0,0 +1,123 @@ +import numpy as np +import pytest +from scipy.stats import norm +import torch + +from bitsandbytes import functional as F + + +@pytest.mark.deprecated +def test_kbit_quantile_estimation(): + for i in range(100): + data = torch.randn(1024, 1024, device="cuda") + for bits in range(2, 9): + p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits) + val1 = torch.Tensor(norm.ppf(p)).cuda() + val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) + err = torch.abs(val1 - val2).mean() + assert err < 0.038 + + for i in range(100): + data = torch.randn(1024, 1024, device="cuda") + for bits in range(2, 4): + total_values = 2**bits - 1 + p = np.linspace(0, 1, 2 * total_values + 1) + idx = np.arange(1, 2 * total_values + 1, 2) + p = p[idx] + offset = 1 / (2 * total_values) + p = np.linspace(offset, 1 - offset, total_values) + val1 = torch.Tensor(norm.ppf(p)).cuda() + val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1) + err = torch.abs(val1 - val2).mean() + assert err < 0.035 + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) +@pytest.mark.deprecated +def test_estimate_quantiles(dtype): + A = torch.rand(1024, 1024, device="cuda") + A = A.to(dtype) + code = F.estimate_quantiles(A) + + percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device) + torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2) + + A = torch.randn(1024, 1024, device="cuda") + A = A.to(dtype) + code = F.estimate_quantiles(A) + + quantiles = torch.quantile(A.float(), percs) + diff = torch.abs(code - quantiles) + assert (diff > 5e-02).sum().item() == 0 + + +@pytest.mark.deprecated +def test_quantile_quantization(): + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + code = F.estimate_quantiles(A1) + C = F.quantize_no_absmax(A1, code) + A2 = F.dequantize_no_absmax(C, code) + diff = torch.abs(A1 - A2).mean().item() + assert diff < 0.0075 + + A1 = torch.rand(1024, 1024, device="cuda") + code = F.estimate_quantiles(A1) + C = F.quantize_no_absmax(A1, code) + A2 = F.dequantize_no_absmax(C, code) + diff = torch.abs(A1 - A2).mean().item() + torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0) + assert diff < 0.001 + + +@pytest.mark.deprecated +def test_dynamic_quantization(): + diffs = [] + reldiffs = [] + for i in 
range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, S = F.quantize(A1) + A2 = F.dequantize(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diff.mean().item() < 0.0135 + print(sum(diffs) / len(diffs)) + print(sum(reldiffs) / len(reldiffs)) + + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, S = F.quantize(A1) + A2 = F.dequantize(C, S) + diff = torch.abs(A1 - A2).mean().item() + torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + assert diff < 0.004 + + +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) +@pytest.mark.deprecated +def test_percentile_clipping(gtype): + gnorm_vec1 = torch.zeros(100, device="cuda") + gnorm_vec2 = torch.zeros(100, device="cuda") + n = 4 + step = 0 + percentile = 5 + for i in range(20): + step += 1 + g = torch.randn(n, n, dtype=gtype, device="cuda") + gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) + assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 + + gnorm2 = torch.norm(g.float()) + if step == 1: + gnorm_vec1[:] = gnorm2 + else: + gnorm_vec1[step % 100] = gnorm2 + + vals, idx = torch.sort(gnorm_vec1) + clip1 = vals[percentile] + + torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2)) + torch.testing.assert_close(clip1, clip2) + torch.testing.assert_close(gnorm1, gnorm2) diff --git a/tests/test_functional.py b/tests/test_functional.py index c8ac20896..95d5cd6dc 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,4 +1,3 @@ -from itertools import product import math import random import time @@ -6,7 +5,6 @@ import einops import numpy as np import pytest -from scipy.stats import norm import torch import bitsandbytes as bnb @@ -88,77 +86,194 @@ def reset(self): print("Resetting benchmark data") -def setup(): - pass - - -def teardown(): - pass - - -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) -def test_estimate_quantiles(dtype): - A = torch.rand(1024, 1024, device="cuda") - A = A.to(dtype) - code = F.estimate_quantiles(A) - - percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device) - torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2) - - A = torch.randn(1024, 1024, device="cuda") - A = A.to(dtype) - code = F.estimate_quantiles(A) - - quantiles = torch.quantile(A.float(), percs) - diff = torch.abs(code - quantiles) - assert (diff > 5e-02).sum().item() == 0 +class Test8BitBlockwiseQuantizeFunctional: + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) + @pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) + @pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) + @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) + def test_dynamic_blockwise_quantization(self, dtype, nested, blocksize, signed): + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2).float() + reldiff = diff / torch.abs(A1.float() + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + abserr = sum(diffs) / len(diffs) + relerr = sum(reldiffs) / len(reldiffs) + # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, 
sum(diffs)/len(diffs)) + # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) + assert abserr < 0.011 + assert relerr < 0.018 + assert A2.dtype == dtype + + diffs = [] + code = F.create_dynamic_map(signed=signed) + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype) + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2).float() + reldiff = diff / torch.abs(A1.float() + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs) / len(diffs) + relerr = sum(reldiffs) / len(reldiffs) + if signed: + assert abserr < 0.0035 + assert relerr < 0.015 + else: + assert abserr < 0.00175 + assert relerr < 0.012 + assert A2.dtype == dtype + # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) + + def test_blockwise_cpu_large(self): + diffs = [] + reldiffs = [] + batch = 128 + seq = 128 + for hidden in [128]: # , 14336]: + for blocksize in [4096, 16384]: + for i in range(2): + A1 = torch.randn(batch, seq, hidden, device="cpu") + t0 = time.time() + C, S = F.quantize_blockwise(A1, blocksize=blocksize) + A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) + print(time.time() - t0) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diffs[-1] < 0.011 + # print(sum(diffs)/len(diffs)) + # print(sum(reldiffs)/len(reldiffs)) + + @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) + @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"]) + def test_few_bit_quant(self, bits, method): + abserrs = [] + relerrs = [] + code = None + if method == "linear": + code = F.create_linear_map(True, total_bits=bits).cuda() + elif method == "fp8": + ebits = math.ceil(bits / 2) + pbits = bits - ebits - 1 + code = F.create_fp8_map(True, ebits, pbits, bits).cuda() + elif method == "dynamic": + code = F.create_dynamic_map(True, bits - 0, bits).cuda() + elif method == "quantile": + values = torch.randn(2048, 2048, device="cuda") + code = F.create_quantile_map(values, bits).cuda() + # for some data types we have no zero + # for some data types we have one zero + # for some data types we have two zeros + assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}" + # print(method, (code==0).sum()) + assert code.numel() == 256 + for i in range(10): + values = torch.randn(1, 32, device="cuda") + values /= values.abs().max() + # values[values.abs() < 1e-6] += 1e-5 + + q1 = [] + v1 = [] + for v in values[0]: + idx = torch.abs(v - code).argmin() + q1.append(idx.item()) + v1.append(code[idx].item()) + + q1 = torch.Tensor(q1).cuda() + v1 = torch.Tensor(v1).cuda() + + q2, S2 = F.quantize_blockwise(values, code=code) + v2 = F.dequantize_blockwise(q2, S2) + + idx = torch.isclose(q1.int(), q2.int()) + err2 = torch.abs(v2 - values) + abserrs.append(err2.mean().item()) + relerrs.append((err2 / (1e-10 + values).abs()).mean().item()) + if idx.sum(): + # some weird cases + err1 = torch.abs(v1 - values).mean() + # assert err2.mean() <= err1 + else: + torch.testing.assert_close(q1, q2) + + def test_fp8_quant(self): + for e_bits in range(1, 7): + p_bits = 7 - e_bits + code = 
F.create_fp8_map(True, e_bits, p_bits).cuda() + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1, code=code) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + # assert diff < 0.0075 + # print(sum(abserr)/len(abserr)) + # print(sum(relerr)/len(relerr)) + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1, code=code) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + # assert diff < 0.0075 + # print(sum(abserr)/len(abserr)) + # print(sum(relerr)/len(relerr)) + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + # assert diff < 0.0075 + # print(3, sum(abserr)/len(abserr)) + # print(3, sum(relerr)/len(relerr)) + + @pytest.mark.benchmark + def test_bench_dequantization(self): + a = torch.rand(1024, 1024, device="cuda").half() + code = F.create_fp8_map(True, 3, 0, 4).cuda() + qa, SA = F.quantize_blockwise(a, code=code) + print(qa.max()) + + max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000 + # print(max_theoretical_mu) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + qa, SA = F.quantize_blockwise(a) + torch.cuda.synchronize() + # print((time.time()-t0)/1e6) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) -@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested")) -@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) -@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) -def test_dynamic_blockwise_quantization(dtype, nested, blocksize, signed): - diffs = [] - reldiffs = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) - C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) - A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1 - A2).float() - reldiff = diff / torch.abs(A1.float() + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - abserr = sum(diffs) / len(diffs) - relerr = sum(reldiffs) / len(reldiffs) - # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(diffs)/len(diffs)) - # print('nested=', nested, 'randn', blocksize, 'dtype', dtype, sum(reldiffs)/len(reldiffs)) - assert abserr < 0.011 - assert relerr < 0.018 - assert A2.dtype == dtype - - diffs = [] - code = F.create_dynamic_map(signed=signed) - for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda", dtype=dtype) - C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested, code=code) - A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1 - A2).float() - reldiff = diff / torch.abs(A1.float() + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - # torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) - abserr = sum(diffs) / len(diffs) - relerr = sum(reldiffs) / len(reldiffs) - if signed: - assert abserr < 0.0035 - assert relerr < 0.015 - else: - assert abserr < 
0.00175 - assert relerr < 0.012 - assert A2.dtype == dtype - # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) - # print('signed=', signed, 'nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) +def test_stable_embedding(): + layer = bnb.nn.StableEmbedding(1024, 1024) + layer.reset_parameters() def quant(x): @@ -198,11 +313,6 @@ def quant_multi_chunk(x, dim, chunk_size=32): return max1, x.to(torch.int8) -def quant_minmax(A): - minA = A.min() - maxA = A.max() - - def mean(xx): return sum(xx) / float(len(xx)) @@ -219,531 +329,617 @@ def mean(xx): } -@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2")) -@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys()) -@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched")) -def test_approx_igemm(dim1, dim2, quant_methods, batched): - dim1 = dim1 - (dim1 % 32) - dim2 = dim2 - (dim2 % 32) - errors = [] - relerrors = [] - # print("") - for i in range(5): - if batched: - A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") - B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda") - maxA, Ac = quant_methods[0](A, 2) - maxB, Bc = quant_methods[1](B, 1) - else: - A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda") - B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") - maxA, Ac = quant_methods[0](A, 1) - maxB, Bc = quant_methods[1](B, 0) - torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) - if batched: - out2 = torch.bmm(A, B) - C = torch.bmm(Ac.float(), Bc.float()) - else: - out2 = torch.mm(A, B) - C = F.igemm(Ac, Bc) - out = quant_methods[4](maxA, maxB, C) - std = out2.std() - out /= std - out2 /= std - err = torch.abs(out - out2) - relerr = err / torch.abs(out2) - errors.append(err.mean().item()) - relerrors.append(relerr.mean().item()) - # print(mean(errors)) - # print(mean(relerrors)) - - -def test_stable_embedding(): - layer = bnb.nn.StableEmbedding(1024, 1024) - layer.reset_parameters() - - -@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim")) -@pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim")) -@pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim")) -@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) -def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): - hidden_dim = hidden_dim - (hidden_dim % 32) - batch_dim = batch_dim - (batch_dim % 16) - seq_dim = seq_dim - (seq_dim % 16) - for i in range(k): - shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) - shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) - A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) - B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) - if not transpose[0] and not transpose[1]: - out2 = torch.matmul(A.float(), B.float()) - out = F.igemm(A, B) - elif not transpose[0] and transpose[1]: - out2 = torch.matmul(A.float(), B.t().float()) - out = F.igemm(A, B.t()) - elif transpose[0] and not transpose[1]: - out2 = torch.matmul(A.t().float(), B.float()) - out = F.igemm(A.t(), B) - elif transpose[0] and transpose[1]: - out2 = torch.matmul(A.t().float(), B.t().float()) - out = F.igemm(A.t(), B.t()) - - 
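For readers of these blockwise tests, a minimal self-contained sketch of the round-trip they measure, using the same F.quantize_blockwise / F.dequantize_blockwise / F.create_dynamic_map calls; the shapes, blocksize, and printed error here are illustrative, and the asserted thresholds above (for example abserr < 0.011) remain the authoritative numbers.

```python
# Illustrative sketch of the 8-bit blockwise quantization round-trip.
import torch

from bitsandbytes import functional as F

A = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# Quantize to int8 in blocks of 256 values, then reconstruct.
C, state = F.quantize_blockwise(A, blocksize=256)
A_hat = F.dequantize_blockwise(C, state)
print((A - A_hat).abs().float().mean())  # mean absolute round-trip error

# Same round-trip with an explicit signed dynamic code book, as in the
# second half of the test above.
code = F.create_dynamic_map(signed=True)
C, state = F.quantize_blockwise(A, blocksize=256, code=code)
A_hat = F.dequantize_blockwise(C, state)
```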
torch.testing.assert_close(out.float(), out2) - - for i in range(k): - shapeA = (batch_dim, seq_dim, hidden_dim) - shapeB = (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) - A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) - B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) - if not transpose[0] and not transpose[1]: - out2 = torch.matmul(A.float(), B.float()) - out = F.igemm(A, B) - elif not transpose[0] and transpose[1]: - out2 = torch.matmul(A.float(), B.t().float()) - out = F.igemm(A, B.t()) - - torch.testing.assert_close(out.float(), out2) - - -@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim")) -@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim")) -@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim")) -def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): - seq_dim = seq_dim - (seq_dim % 32) - hidden_dim = hidden_dim - (hidden_dim % 32) - batch_dim = batch_dim - (batch_dim % 2) - for i in range(25): - A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8) - B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8) - out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) - iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) - out = F.igemm(A, B, out=iout) - - torch.testing.assert_close(out.float(), out2) - - -@pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim")) -@pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim")) -@pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim")) -@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) -def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): - def min_max(x): - maxA = torch.amax(x, dim=2, keepdim=True) - minA = torch.amin(x, dim=2, keepdim=True) - scale = (maxA - minA) / 2.0 - return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale - - seq_dim = seq_dim - (seq_dim % 16) - hidden_dim = hidden_dim - (hidden_dim % 16) - batch_dim = batch_dim - (batch_dim % 2) - errs = [] - relerrs = [] - errs2 = [] - relerrs2 = [] - for i in range(k): - A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") - if transpose: - B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") - else: - B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda") - Ac, minA, scale = min_max(A) - if transpose: - maxB, Bc = quant_multi(B, dim=(1 if transpose else 0)) - out = F.igemm(Ac, Bc.t()) - out2 = torch.matmul(A, B.t()) - offset = B.t().sum(0) * (minA + scale) - out = out.float() - out = (out * maxB.t() * scale / (127 * 127)) + offset - - maxA, Ac = quant_multi(A, dim=2) - out3 = F.igemm(Ac, Bc.t()) - out3 = mm_dequant(maxA, maxB.t(), out3) - else: - maxB, Bc = quant_multi(B, dim=0) - offset = B.sum(0) * (minA + scale) - out = F.igemm(Ac, Bc) - out2 = torch.matmul(A, B) - out = out.float() - out = (out * maxB * scale / (127 * 127)) + offset - - maxA, Ac = quant_multi(A, dim=2) - out3 = F.igemm(Ac, Bc) - out3 = mm_dequant(maxA, maxB, out3) - - std = out2.std() - out2 /= std - out /= std - out3 /= std - - err = torch.abs(out - out2) - relerr = err / (torch.abs(out2) + 1e-7) - - err2 = torch.abs(out3 - out2) - relerr2 = err2 / 
(torch.abs(out2) + 1e-7) - - errs.append(err.mean().item()) - relerrs.append(relerr.mean().item()) - errs2.append(err2.mean().item()) - relerrs2.append(relerr2.mean().item()) - # print(mean(errs)) - # print(mean(relerrs)) - # print(mean(errs2)) - # print(mean(relerrs2)) - assert mean(errs) < 0.015 - assert mean(relerrs) < 0.3 - - -@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3")) -@pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4")) -@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) -def test_ibmm(dim1, dim2, dim3, dim4, transpose): - dim2 = dim2 - (dim2 % 16) - dim3 = dim3 - (dim3 % 16) - dim4 = dim4 - (dim4 % 16) - for i in range(k): - shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3) - shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4) - A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) - B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) - - if not transpose[0] and not transpose[1]: - out2 = torch.bmm(A.float(), B.float()) - out = F.igemm(A, B) - elif not transpose[0] and transpose[1]: - out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float()) - out = F.igemm(A, B.permute([0, 2, 1])) - elif transpose[0] and not transpose[1]: - out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) - out = F.igemm(A.permute([0, 2, 1]), B) - elif transpose[0] and transpose[1]: - out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) - out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) - torch.testing.assert_close(out.float(), out2.float()) - - -@pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) -@pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) -@pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) -@pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) -def test_int8_linear_matmul(dim1, dim2, dim3, dim4, dims, ldb): - for i in range(k): - if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) - elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) - C1 = torch.matmul(A.float(), B.t().float()) - - C2 = F.int8_linear_matmul(A, B) - torch.testing.assert_close(C1, C2.float()) - - -@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) -@pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) -@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) -def test_int8_linear_matmul_half(dim1, dim2, dim3, dim4, dims): - for i in range(k): - if dims == 2: - A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() - elif dims == 3: - A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() - B = torch.randn((dim4, dim3), device="cuda").half() - torch.nn.init.xavier_uniform_(B) - C1 = torch.matmul(A, B.t()) - - A = A.view(-1, A.shape[-1]) - - CA, _, statsA, _, _ = F.int8_double_quant(A) - CB, 
statsB, _ = F.int8_vectorwise_quant(B) - output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB) - - torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) - - -@pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) -@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) -@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) -def test_dequant_mm(dim1, dim4, dims, has_bias): - inner = 128 - bias = None - if has_bias: - bias = torch.randn(dim4, device="cuda", dtype=torch.float16) - - for i in range(1): - A = torch.randn(dim1, inner, device="cuda") - B = torch.randn(dim4, inner, device="cuda") - C1 = torch.matmul(A.half(), B.t().half()) - if has_bias: - C1 += bias - - A1, maxA = F.vectorwise_quant(A, dim=1) - B1, maxB = F.vectorwise_quant(B, dim=1) - - C2 = F.int8_linear_matmul(A1, B1) - - C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) +class TestIGEMMFunctional: + @pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2")) + @pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys()) + @pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched")) + def test_approx_igemm(self, dim1, dim2, quant_methods, batched): + dim1 = dim1 - (dim1 % 32) + dim2 = dim2 - (dim2 % 32) + errors = [] + relerrors = [] + # print("") + for i in range(5): + if batched: + A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") + B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda") + maxA, Ac = quant_methods[0](A, 2) + maxB, Bc = quant_methods[1](B, 1) + else: + A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda") + B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") + maxA, Ac = quant_methods[0](A, 1) + maxB, Bc = quant_methods[1](B, 0) + torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) + if batched: + out2 = torch.bmm(A, B) + C = torch.bmm(Ac.float(), Bc.float()) + else: + out2 = torch.mm(A, B) + C = F.igemm(Ac, Bc) + out = quant_methods[4](maxA, maxB, C) + std = out2.std() + out /= std + out2 /= std + err = torch.abs(out - out2) + relerr = err / torch.abs(out2) + errors.append(err.mean().item()) + relerrors.append(relerr.mean().item()) + # print(mean(errors)) + # print(mean(relerrors)) + + @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim")) + @pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim")) + @pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim")) + @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) + def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim): + hidden_dim = hidden_dim - (hidden_dim % 32) + batch_dim = batch_dim - (batch_dim % 16) + seq_dim = seq_dim - (seq_dim % 16) + for i in range(k): + shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) + shapeB = ( + (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) + ) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) + if not transpose[0] and not transpose[1]: + out2 = torch.matmul(A.float(), B.float()) + out = F.igemm(A, B) + elif not 
transpose[0] and transpose[1]: + out2 = torch.matmul(A.float(), B.t().float()) + out = F.igemm(A, B.t()) + elif transpose[0] and not transpose[1]: + out2 = torch.matmul(A.t().float(), B.float()) + out = F.igemm(A.t(), B) + elif transpose[0] and transpose[1]: + out2 = torch.matmul(A.t().float(), B.t().float()) + out = F.igemm(A.t(), B.t()) + + torch.testing.assert_close(out.float(), out2) + + for i in range(k): + shapeA = (batch_dim, seq_dim, hidden_dim) + shapeB = ( + (32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4)) + ) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) + if not transpose[0] and not transpose[1]: + out2 = torch.matmul(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.matmul(A.float(), B.t().float()) + out = F.igemm(A, B.t()) + + torch.testing.assert_close(out.float(), out2) + + @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim")) + @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim")) + @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim")) + def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim): + seq_dim = seq_dim - (seq_dim % 32) + hidden_dim = hidden_dim - (hidden_dim % 32) + batch_dim = batch_dim - (batch_dim % 2) + for i in range(25): + A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8) + out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) + iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) + out = F.igemm(A, B, out=iout) + + torch.testing.assert_close(out.float(), out2) + + @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim")) + @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim")) + @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim")) + @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) + def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose): + def min_max(x): + maxA = torch.amax(x, dim=2, keepdim=True) + minA = torch.amin(x, dim=2, keepdim=True) + scale = (maxA - minA) / 2.0 + return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale + + seq_dim = seq_dim - (seq_dim % 16) + hidden_dim = hidden_dim - (hidden_dim % 16) + batch_dim = batch_dim - (batch_dim % 2) + errs = [] + relerrs = [] + errs2 = [] + relerrs2 = [] + for i in range(k): + A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") + if transpose: + B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") + else: + B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda") + Ac, minA, scale = min_max(A) + if transpose: + maxB, Bc = quant_multi(B, dim=(1 if transpose else 0)) + out = F.igemm(Ac, Bc.t()) + out2 = torch.matmul(A, B.t()) + offset = B.t().sum(0) * (minA + scale) + out = out.float() + out = (out * maxB.t() * scale / (127 * 127)) + offset + + maxA, Ac = quant_multi(A, dim=2) + out3 = F.igemm(Ac, Bc.t()) + out3 = mm_dequant(maxA, maxB.t(), out3) + else: + maxB, Bc = quant_multi(B, dim=0) + offset = B.sum(0) * (minA + scale) + out = F.igemm(Ac, Bc) + out2 = 
torch.matmul(A, B) + out = out.float() + out = (out * maxB * scale / (127 * 127)) + offset + + maxA, Ac = quant_multi(A, dim=2) + out3 = F.igemm(Ac, Bc) + out3 = mm_dequant(maxA, maxB, out3) + + std = out2.std() + out2 /= std + out /= std + out3 /= std + + err = torch.abs(out - out2) + relerr = err / (torch.abs(out2) + 1e-7) + + err2 = torch.abs(out3 - out2) + relerr2 = err2 / (torch.abs(out2) + 1e-7) + + errs.append(err.mean().item()) + relerrs.append(relerr.mean().item()) + errs2.append(err2.mean().item()) + relerrs2.append(relerr2.mean().item()) + # print(mean(errs)) + # print(mean(relerrs)) + # print(mean(errs2)) + # print(mean(relerrs2)) + assert mean(errs) < 0.015 + assert mean(relerrs) < 0.3 + + @pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2")) + @pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3")) + @pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4")) + @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) + def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + dim4 = dim4 - (dim4 % 16) + for i in range(k): + shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3) + shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) + + if not transpose[0] and not transpose[1]: + out2 = torch.bmm(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float()) + out = F.igemm(A, B.permute([0, 2, 1])) + elif transpose[0] and not transpose[1]: + out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) + out = F.igemm(A.permute([0, 2, 1]), B) + elif transpose[0] and transpose[1]: + out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) + out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) + torch.testing.assert_close(out.float(), out2.float()) + + +class TestLLMInt8Functional: + @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) + @pytest.mark.parametrize("dim3", [499, 512], ids=id_formatter("dim3")) + @pytest.mark.parametrize("dim4", [512], ids=id_formatter("dim4")) + @pytest.mark.parametrize("dims", (2, 3), ids=id_formatter("dims")) + @pytest.mark.parametrize("ldb", (0,), ids=id_formatter("ldb")) + def test_int8_linear_matmul(self, dim1, dim2, dim3, dim4, dims, ldb): + for i in range(k): + if dims == 2: + A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(torch.int8) + elif dims == 3: + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) + C1 = torch.matmul(A.float(), B.t().float()) + + C2 = F.int8_linear_matmul(A, B) + torch.testing.assert_close(C1, C2.float()) + + @pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [32], ids=id_formatter("dim2")) + @pytest.mark.parametrize("dim3", [32], ids=id_formatter("dim3")) + @pytest.mark.parametrize("dim4", [32], ids=id_formatter("dim4")) + @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) + def 
test_int8_linear_matmul_half(self, dim1, dim2, dim3, dim4, dims): + for i in range(k): + if dims == 2: + A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() + elif dims == 3: + A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() + B = torch.randn((dim4, dim3), device="cuda").half() + torch.nn.init.xavier_uniform_(B) + C1 = torch.matmul(A, B.t()) + + A = A.view(-1, A.shape[-1]) + + CA, _, statsA, _, _ = F.int8_double_quant(A) + CB, statsB, _ = F.int8_vectorwise_quant(B) + output = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), statsA, statsB) + + torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + + @pytest.mark.parametrize("dim1", (64, 256), ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim4", (64, 1024), ids=id_formatter("dim4")) + @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) + @pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias")) + def test_dequant_mm(self, dim1, dim4, dims, has_bias): + inner = 128 + bias = None if has_bias: - C4 += bias - - # TODO: is something wrong here? If so, the problem goes deeper - # n = C1.numel() - # p = 0.06 - std = C1.std(0).view(1, -1) - C1 /= std - C4 /= std - # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) - # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" - - C5 = F.int8_mm_dequant(C2, maxA, maxB, bias=bias) - C5 /= std - torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) - n = C5.numel() - assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) - - -@pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) -@pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp")) -def test_colrow_absmax(dim1, dim2, dims, threshold): - for i in range(k): - A = torch.randn(dim1, dim2, device="cuda").half() - - assert dims == 2 - - row_stats1, _ = torch.abs(A.float()).max(1) - col_stats1, _ = torch.abs(A.float()).max(0) - - if threshold > 0.0: - A_truncated = A.clone() - A_truncated[torch.abs(A_truncated) >= threshold] = 0.0 - row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1) - col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0) - - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) - - nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten() - nnz_block_ptr1 = torch.zeros( - nnz_rows1_counts.shape[0] + 1, - dtype=nnz_rows1_counts.dtype, - device=nnz_rows1_counts.device, + bias = torch.randn(dim4, device="cuda", dtype=torch.float16) + + for i in range(1): + A = torch.randn(dim1, inner, device="cuda") + B = torch.randn(dim4, inner, device="cuda") + C1 = torch.matmul(A.half(), B.t().half()) + if has_bias: + C1 += bias + + A1, maxA = F.vectorwise_quant(A, dim=1) + B1, maxB = F.vectorwise_quant(B, dim=1) + + C2 = F.int8_linear_matmul(A1, B1) + + C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) + if has_bias: + C4 += bias + + # TODO: is something wrong here? 
If so, the problem goes deeper + # n = C1.numel() + # p = 0.06 + std = C1.std(0).view(1, -1) + C1 /= std + C4 /= std + # assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) + # assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" + + C5 = F.int8_mm_dequant(C2, maxA, maxB, bias=bias) + C5 /= std + torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) + n = C5.numel() + assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n)) + + @pytest.mark.parametrize("dim1", [1 * 1024], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [1 * 1024], ids=id_formatter("dim2")) + @pytest.mark.parametrize("dims", (2,), ids=id_formatter("dims")) + @pytest.mark.parametrize("threshold", [0.0, 3.0], ids=id_formatter("decomp")) + def test_colrow_absmax(self, dim1, dim2, dims, threshold): + for i in range(k): + A = torch.randn(dim1, dim2, device="cuda").half() + + assert dims == 2 + + row_stats1, _ = torch.abs(A.float()).max(1) + col_stats1, _ = torch.abs(A.float()).max(0) + + if threshold > 0.0: + A_truncated = A.clone() + A_truncated[torch.abs(A_truncated) >= threshold] = 0.0 + row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1) + col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0) + + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) + + nnz_rows1_counts = (torch.abs(A) >= threshold).sum(1).flatten() + nnz_block_ptr1 = torch.zeros( + nnz_rows1_counts.shape[0] + 1, + dtype=nnz_rows1_counts.dtype, + device=nnz_rows1_counts.device, + ) + nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) + + torch.testing.assert_close(col_stats1_trunc, col_stats2) + torch.testing.assert_close(row_stats1_trunc, row_stats2) + # torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2) + else: + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) + assert nnz_block_ptr2 is None + torch.testing.assert_close(col_stats1, col_stats2) + torch.testing.assert_close(row_stats1, row_stats2) + + @pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2")) + def test_int8_double_quant(self, dim1, dim2): + for i in range(k): + A = torch.randn(dim1, dim2, device="cuda").half() + out_col1, Scol = F.vectorwise_quant(A, dim=0) + out_row1, Srow = F.vectorwise_quant(A, dim=1) + + CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A) + + # max difference is 1 due to rounding differences + torch.testing.assert_close(CA, out_row1, atol=1, rtol=0) + torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0) + + n = CAt.numel() + num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() + num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() + + # allow for 1:500 error due to rounding differences + min_error = 1 / 500 + if num_not_close_cols > (min_error * n): + print( + f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}" + ) + assert False + if num_not_close_rows > (min_error * n): + print( + f"Min error exceeded {num_not_close_rows} elements are different. 
Error: {num_not_close_rows/n:.4f}" + ) + assert False + + torch.testing.assert_close(Srow.flatten().float(), statsA) + torch.testing.assert_close(Scol.flatten().float(), statsAt) + + @pytest.mark.parametrize( + ("dim1", "dim4", "inner"), + ( + pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") + for (dim1, dim4, inner) in zip( + (1, 8, 2048, 4096), + (2, 128, 2048, 4096), + (4, 256, 512, 4096), ) - nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) - - torch.testing.assert_close(col_stats1_trunc, col_stats2) - torch.testing.assert_close(row_stats1_trunc, row_stats2) - # torch.testing.assert_close(nnz_block_ptr1, nnz_block_ptr2) - else: - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) - assert nnz_block_ptr2 is None - torch.testing.assert_close(col_stats1, col_stats2) - torch.testing.assert_close(row_stats1, row_stats2) - - -@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2")) -def test_int8_double_quant(dim1, dim2): - for i in range(k): + ), + ) + def test_integrated_int8_linear_matmul(self, dim1, dim4, inner): + for i in range(k): + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() + + out1 = torch.matmul(A.half(), B.t().half()) + + C1a, stats1a, _ = F.int8_vectorwise_quant(A) + C2a, stats2a, _ = F.int8_vectorwise_quant(B) + A1, maxA = F.vectorwise_quant(A, dim=1) + B1, maxB = F.vectorwise_quant(B, dim=1) + + torch.testing.assert_close(maxA.flatten().float(), stats1a) + torch.testing.assert_close(maxB.flatten().float(), stats2a) + torch.testing.assert_close(C1a, A1, rtol=0, atol=1) + torch.testing.assert_close(C2a, B1, rtol=0, atol=1) + + out2 = F.int8_linear_matmul(A1, B1) + + C2 = F.int8_linear_matmul(A1, B1) + + out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) + + err1 = torch.abs(out1 - out2).mean().item() + err2 = torch.abs(out1 - out3).mean().item() + assert err2 <= err1 * 1.025 + + @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) + def test_coo_double_quant(self, dim1, dim2): + threshold = 2.00 + for i in range(k): + A = torch.randn(dim1, dim2, device="cuda").half() + + idx = torch.abs(A) >= threshold + CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) + + if outlier_cols is not None: + A1 = A * idx + A2 = torch.zeros_like(A) + A1 + torch.testing.assert_close(A1, A2) + + A[:, outlier_cols] = 0 + A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) + + @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) + def test_coo_int8_vectorwise_quant(self, dim1, dim2): + threshold = 3.00 + for i in range(k): + A = torch.randn(dim1, dim2, device="cuda").half() + + idx = torch.abs(A) >= threshold + CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) + + if outlier_cols is not None: + A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + A[:, outlier_cols] = 0 + torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + + +class TestSpMMFunctional: + @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2")) + @pytest.mark.parametrize("transposed_B", TRUE_FALSE, 
ids=id_formatter("transposed_B")) + def test_spmm_coo(self, dim1, dim2, transposed_B): + threshold = 1.5 + dim3 = torch.randint(32, 128, size=(1,)).item() + # dim3 = 17 + for i in range(k): + A = torch.randn(dim1, dim2).cuda().half() + if transposed_B: + B = torch.randn(dim3, dim2).cuda().half() + else: + B = torch.randn(dim2, dim3).cuda().half() + + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A * idx + + if transposed_B: + out2 = F.spmm_coo(cooA, B.t()) + out1 = torch.matmul(A2, B.t()) + else: + out2 = F.spmm_coo(cooA, B) + out1 = torch.matmul(A2, B) + + assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) + + @pytest.mark.benchmark + def test_spmm_bench(self): + batch = 2 + model = 1024 * 1 + hidden = model * 4 + seq = 1024 + dim1 = batch * seq + dim2 = model + dim3 = hidden + threshold = 4 A = torch.randn(dim1, dim2, device="cuda").half() - out_col1, Scol = F.vectorwise_quant(A, dim=0) - out_row1, Srow = F.vectorwise_quant(A, dim=1) - - CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A) - - # max difference is 1 due to rounding differences - torch.testing.assert_close(CA, out_row1, atol=1, rtol=0) - torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0) - - n = CAt.numel() - num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() - num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() - - # allow for 1:500 error due to rounding differences - min_error = 1 / 500 - if num_not_close_cols > (min_error * n): - print(f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}") - assert False - if num_not_close_rows > (min_error * n): - print(f"Min error exceeded {num_not_close_rows} elements are different. 
Error: {num_not_close_rows/n:.4f}") - assert False - - torch.testing.assert_close(Srow.flatten().float(), statsA) - torch.testing.assert_close(Scol.flatten().float(), statsAt) - - -@pytest.mark.parametrize( - ("dim1", "dim4", "inner"), - ( - pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") - for (dim1, dim4, inner) in zip( - (1, 8, 2048, 4096), - (2, 128, 2048, 4096), - (4, 256, 512, 4096), - ) - ), -) -def test_integrated_int8_linear_matmul(dim1, dim4, inner): - for i in range(k): - A = torch.randn(dim1, inner, device="cuda").half() - B = torch.randn(dim4, inner, device="cuda").half() - - out1 = torch.matmul(A.half(), B.t().half()) - - C1a, stats1a, _ = F.int8_vectorwise_quant(A) - C2a, stats2a, _ = F.int8_vectorwise_quant(B) - A1, maxA = F.vectorwise_quant(A, dim=1) - B1, maxB = F.vectorwise_quant(B, dim=1) - - torch.testing.assert_close(maxA.flatten().float(), stats1a) - torch.testing.assert_close(maxB.flatten().float(), stats2a) - torch.testing.assert_close(C1a, A1, rtol=0, atol=1) - torch.testing.assert_close(C2a, B1, rtol=0, atol=1) - - out2 = F.int8_linear_matmul(A1, B1) - - C2 = F.int8_linear_matmul(A1, B1) + B = torch.randn(dim2, dim3, device="cuda").half() + for i in range(10): + C1 = bnb.matmul(A, B.t()) - out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + C1 = bnb.matmul(A, B.t()) + torch.cuda.synchronize() + t8 = time.time() - t0 - err1 = torch.abs(out1 - out2).mean().item() - err2 = torch.abs(out1 - out3).mean().item() - assert err2 <= err1 * 1.025 + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + print(nnz / idx.numel()) + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + for i in range(10): + out2 = F.spmm_coo(cooA, B) -@pytest.mark.parametrize( - ("dim1", "dim4", "inner"), - ( - pytest.param(dim1, dim4, inner, id=f"{dim1=},{dim4=},{inner=}") - for (dim1, dim4, inner) in zip( - get_test_dims(1, 4 * 1024, n=6), - get_test_dims(1, 4 * 1024, n=6), - get_test_dims(1, 4 * 1024, n=6), - ) - ), -) -@pytest.mark.skip("Row scale has some bugs for ampere") -def test_igemmlt_row_scale(dim1, dim4, inner): - formatB = F.get_special_format_str() - err1, err2, err3 = [], [], [] - relerr1, relerr2 = [], [] - scale = 1 - for i in range(k): - A = torch.randn(dim1, inner, device="cuda").half() - B = torch.randn(dim4, inner, device="cuda").half() - torch.nn.init.xavier_uniform_(B) - C1 = torch.matmul(A, B.t()) - - out1 = torch.matmul(A.half(), B.t().half()) - - C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A) - CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") - A2, SA = F.nvidia_transform(C1a, "col32") - B2, SB = F.nvidia_transform(CB, formatB) - A1, maxA = F.vectorwise_quant(A, dim=1) - - c = 10.0 * inner * scale - row_scale = torch.ones_like(maxA) / c - outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale) - # C3, S = F.nvidia_transform(outC32, "row", state=SC) - C3 = outC32 - maxval = torch.abs(C3).max() - if maxval == 127: - scale = 1.5 + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + out2 = F.spmm_coo(cooA, B) + torch.cuda.synchronize() + tsp = time.time() - t0 + print(tsp, t8) + print(tsp / t8) + + @pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2")) + @pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) + 
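A short sketch (shapes arbitrary, threshold taken from test_spmm_coo above) of the COO sparse path these tests cover: keep only the entries of A whose magnitude exceeds the threshold, wrap them in F.COOSparseTensor, and multiply against a dense fp16 matrix with F.spmm_coo.

```python
# Illustrative sketch of the COO sparse matmul path.
import torch

from bitsandbytes import functional as F

threshold = 1.5
A = torch.randn(1024, 2048, device="cuda").half()
B = torch.randn(2048, 128, device="cuda").half()

idx = torch.abs(A) >= threshold  # outlier mask
rows, cols = torch.where(idx)
values = A[idx]
cooA = F.COOSparseTensor(A.shape[0], A.shape[1], values.numel(), rows.int(), cols.int(), values)

out_sparse = F.spmm_coo(cooA, B)      # sparse(A) @ B
out_dense = torch.matmul(A * idx, B)  # dense reference on the same entries
print((out_sparse - out_dense).abs().mean())
```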
@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func")) + def test_spmm_coo_very_sparse(self, dim1, dim2, dtype, out_func): + out_func = getattr(torch, out_func) + + threshold = 3.3 + # threshold = 2.8 + # threshold = 0.0 + A = torch.randn(dim1, dim2, device="cuda").half() + if dtype == torch.float16: + B = torch.randn(dim2, dim2 * 4, device="cuda").half() + torch.nn.init.xavier_uniform_(B) else: - scale = maxval / 120 - out3 = C3 * maxA * absmaxB * c / (127 * 127) - - C4 = torch.matmul(C1a.float(), CB.float().t()) - - C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) - B2, SB = F.nvidia_transform(C2a, formatB) - outC32 = F.int8_linear_matmul(A2, B2) - out2 = F.int8_mm_dequant(outC32, stats1a, stats2a) - - CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector") - CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear") - - C = torch.matmul(CA.float(), CB.t().float()) - out4 = C * SA * SB / (127 * 127) - # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127) - - # print('='*80) - # print(out1) - # print(out2) - # print(out3) + B = torch.randn(dim2, dim2 * 4, device="cuda").half() + torch.nn.init.xavier_uniform_(B) + B, SB = F.vectorwise_quant(B, quant_type="linear") + # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8) + print("") + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A * idx + out1 = torch.matmul(A2.half(), B.half()) + out = out_func(out1.shape, dtype=torch.float16, device=out1.device) + out1 += out.clone() + out2 = F.spmm_coo_very_sparse(cooA, B, out=out) + # print(B) # print(out1) # print(out2) - # print(out3) - err1.append(torch.abs(out1 - out2).mean().item()) - err2.append(torch.abs(out1 - out3).mean().item()) - err3.append(torch.abs(out1 - out4).mean().item()) - - # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10) - print("") - print(sum(err1) / len(err1)) - print(sum(err2) / len(err2)) - print(sum(err3) / len(err3)) - - -@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) -def test_coo_double_quant(dim1, dim2): - threshold = 2.00 - for i in range(k): - A = torch.randn(dim1, dim2, device="cuda").half() - - idx = torch.abs(A) >= threshold - CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) - - if outlier_cols is not None: - A1 = A * idx - A2 = torch.zeros_like(A) + A1 - torch.testing.assert_close(A1, A2) - - A[:, outlier_cols] = 0 - A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_close(A, A2, rtol=0.05, atol=1.5e-2) - - -@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) -def test_coo_int8_vectorwise_quant(dim1, dim2): - threshold = 3.00 - for i in range(k): + p = 200 / (2048 * 12288 * 4) + n = out1.numel() + count = math.ceil(p * n) + std = out1.std() + out1 /= std + out2 /= std + assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) + # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) + + idx_col = torch.randint(0, A2.shape[-1], size=(15,)) + + # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001) + + # Bt = torch.randn(dim2*4, dim2, device='cuda').half() + # 
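Before the sparse-decomposition test below, a minimal sketch of the dense int8 path it builds on: row-wise int8 quantization of both operands with F.int8_vectorwise_quant, an int8 matmul, then dequantization against the stored row scales. Names and shapes here are illustrative.

```python
# Illustrative sketch of the dense int8 matmul + dequant path.
import torch

from bitsandbytes import functional as F

A = torch.randn(256, 1024, device="cuda").half()  # activations, [m, k]
w = torch.randn(512, 1024, device="cuda").half()  # weight, [n, k]

CA, statsA, _ = F.int8_vectorwise_quant(A)  # int8 rows + per-row absmax of A
Cw, statsw, _ = F.int8_vectorwise_quant(w)  # int8 rows + per-row absmax of w

out_i32 = F.int8_linear_matmul(CA, Cw)            # int32 accumulator, [m, n]
out = F.int8_mm_dequant(out_i32, statsA, statsw)  # rescale back to fp16

ref = torch.matmul(A, w.t())
print((out - ref).abs().mean())  # small quantization error expected
```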
torch.cuda.synchronize() + # t0 = time.time() + # print(A2.shape, B.shape) + # for i in range(100): + # #out3 = F.spmm_coo(cooA, Bt.t()) + # #out2 = F.spmm_coo(cooA, B) + # #out2 = F.spmm_coo_very_sparse(cooA, B) + # #out1 = torch.matmul(A, Bt.t()) + + # torch.cuda.synchronize() + # print(time.time() - t0) + + @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2")) + def test_integrated_sparse_decomp(self, dim1, dim2): + threshold = 3.0 + for _ in range(k): + A = torch.randn(dim1, dim2).cuda().half() + w1 = torch.randn(dim1, dim2).cuda().half() + out1 = torch.matmul(A, w1.t()) + + Cw1, statsw1, _ = F.int8_vectorwise_quant(w1) + CA, statsA, _ = F.int8_vectorwise_quant(A) + + out1_32 = F.int8_linear_matmul(CA, Cw1) + out2 = F.int8_mm_dequant(out1_32, statsA, statsw1) + + # CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) + CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold) + + out1_32 = F.int8_linear_matmul(CA, Cw1) + out3 = F.int8_mm_dequant(out1_32, statsA, statsw1) + + assert coo_tensor is not None + + out4 = F.spmm_coo(coo_tensor, w1.t()) + # idx = torch.unique(coo_tensor._indices()[1]).long() + # out4 = torch.matmul(A, w1.t()) + out5 = out3 + out4 + + err1 = torch.abs(out1 - out2).mean().item() + err2 = torch.abs(out1 - out5).mean().item() + assert err2 < err1 + + @pytest.mark.parametrize("dim1", [1 * 2048]) + @pytest.mark.parametrize("dim2", [2048]) + @pytest.mark.parametrize("dtype", [torch.int8]) + def test_spmm_coo_dequant(self, dim1, dim2, dtype): + threshold = 6.0 + # threshold = 2.8 + # threshold = 0.0 A = torch.randn(dim1, dim2, device="cuda").half() + B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16) + torch.nn.init.xavier_uniform_(B) + Bt = B.t().contiguous() - idx = torch.abs(A) >= threshold - CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) - - if outlier_cols is not None: - A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - A[:, outlier_cols] = 0 - torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B) + rowidx = torch.randint(0, A.shape[-1], size=(15,)) -@pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2")) -@pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B")) -def test_spmm_coo(dim1, dim2, transposed_B): - threshold = 1.5 - dim3 = torch.randint(32, 128, size=(1,)).item() - # dim3 = 17 - for i in range(k): - A = torch.randn(dim1, dim2).cuda().half() - if transposed_B: - B = torch.randn(dim3, dim2).cuda().half() - else: - B = torch.randn(dim2, dim3).cuda().half() + A[:, rowidx] = 8.0 idx = torch.abs(A) >= threshold nnz = (idx == 1).sum().item() @@ -751,712 +947,381 @@ def test_spmm_coo(dim1, dim2, transposed_B): values = A[idx] cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) A2 = A * idx - - if transposed_B: - out2 = F.spmm_coo(cooA, B.t()) - out1 = torch.matmul(A2, B.t()) - else: - out2 = F.spmm_coo(cooA, B) - out1 = torch.matmul(A2, B) - - assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) - - -@pytest.mark.benchmark -def test_spmm_bench(): - batch = 2 - model = 1024 * 1 - hidden = model * 4 - seq = 1024 - dim1 = batch * seq - dim2 = model - dim3 = hidden - threshold = 4 - A = torch.randn(dim1, 
dim2, device="cuda").half() - B = torch.randn(dim2, dim3, device="cuda").half() - for i in range(10): - C1 = bnb.matmul(A, B.t()) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(k): - C1 = bnb.matmul(A, B.t()) - torch.cuda.synchronize() - t8 = time.time() - t0 - - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - print(nnz / idx.numel()) - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - - for i in range(10): - out2 = F.spmm_coo(cooA, B) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(k): - out2 = F.spmm_coo(cooA, B) - torch.cuda.synchronize() - tsp = time.time() - t0 - print(tsp, t8) - print(tsp / t8) - - -@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2")) -def test_integrated_sparse_decomp(dim1, dim2): - threshold = 3.0 - for _ in range(k): - A = torch.randn(dim1, dim2).cuda().half() - w1 = torch.randn(dim1, dim2).cuda().half() - out1 = torch.matmul(A, w1.t()) - - Cw1, statsw1, _ = F.int8_vectorwise_quant(w1) - CA, statsA, _ = F.int8_vectorwise_quant(A) - - out1_32 = F.int8_linear_matmul(CA, Cw1) - out2 = F.int8_mm_dequant(out1_32, statsA, statsw1) - - # CA, statsA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold) - CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold) - - out1_32 = F.int8_linear_matmul(CA, Cw1) - out3 = F.int8_mm_dequant(out1_32, statsA, statsw1) - - assert coo_tensor is not None - - out4 = F.spmm_coo(coo_tensor, w1.t()) - # idx = torch.unique(coo_tensor._indices()[1]).long() - # out4 = torch.matmul(A, w1.t()) - out5 = out3 + out4 - - err1 = torch.abs(out1 - out2).mean().item() - err2 = torch.abs(out1 - out5).mean().item() - assert err2 < err1 - - -def test_matmuls(): - a = torch.randn(256, 512).half().cuda() - b = torch.randn(256, 512).half().cuda() - c1 = torch.matmul(a, b.t()) - c2 = bnb.matmul(a, b) - c3 = bnb.matmul_cublas(a, b.t()) - - err1 = torch.abs(c1 - c2).mean().item() - err2 = torch.abs(c1 - c3).mean().item() - assert err1 < 0.2 - assert err2 < 0.2 - print(err1, err2) - - -@pytest.mark.parametrize("dim1", [1 * 2048], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [12288], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dtype", [torch.float16], ids=describe_dtype) -@pytest.mark.parametrize("out_func", ["zeros", "ones"], ids=id_formatter("out_func")) -def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): - out_func = getattr(torch, out_func) - - threshold = 3.3 - # threshold = 2.8 - # threshold = 0.0 - A = torch.randn(dim1, dim2, device="cuda").half() - if dtype == torch.float16: - B = torch.randn(dim2, dim2 * 4, device="cuda").half() - torch.nn.init.xavier_uniform_(B) - else: - B = torch.randn(dim2, dim2 * 4, device="cuda").half() - torch.nn.init.xavier_uniform_(B) - B, SB = F.vectorwise_quant(B, quant_type="linear") - # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8) - - print("") - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A * idx - out1 = torch.matmul(A2.half(), B.half()) - out = out_func(out1.shape, dtype=torch.float16, device=out1.device) - out1 += out.clone() - out2 = F.spmm_coo_very_sparse(cooA, B, out=out) - # print(B) - # print(out1) - # print(out2) - p = 200 / (2048 * 
12288 * 4) - n = out1.numel() - count = math.ceil(p * n) - std = out1.std() - out1 /= std - out2 /= std - assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) - # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) - - idx_col = torch.randint(0, A2.shape[-1], size=(15,)) - - # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001) - - # Bt = torch.randn(dim2*4, dim2, device='cuda').half() - # torch.cuda.synchronize() - # t0 = time.time() - # print(A2.shape, B.shape) - # for i in range(100): - # #out3 = F.spmm_coo(cooA, Bt.t()) - # #out2 = F.spmm_coo(cooA, B) - # #out2 = F.spmm_coo_very_sparse(cooA, B) - # #out1 = torch.matmul(A, Bt.t()) - - # torch.cuda.synchronize() - # print(time.time() - t0) - - -def test_coo2csr(): - threshold = 1 - A = torch.randn(128, 128).half().cuda() - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A * idx - csrA = F.coo2csr(cooA) - counts = csrA.rowptr[1:] - csrA.rowptr[:-1] - assert counts.numel() == A.shape[0] - - torch.testing.assert_close(counts.long(), (A2 != 0).sum(1)) - idx = A2 != 0 - torch.testing.assert_close(A2[idx], csrA.values) - - -def test_coo2csc(): - threshold = 1 - A = torch.randn(128, 128).half().cuda() - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A * idx - cscA = F.coo2csc(cooA) - counts = cscA.colptr[1:] - cscA.colptr[:-1] - assert counts.numel() == A.shape[1] - - torch.testing.assert_close(counts.long(), (A2 != 0).sum(0)) - # torch uses row-major -> use transpose to transfer to col-major - idx = A2.t() != 0 - torch.testing.assert_close(A2.t()[idx], cscA.values) - - -@pytest.mark.parametrize("dim1", [1 * 2048]) -@pytest.mark.parametrize("dim2", [2048]) -@pytest.mark.parametrize("dtype", [torch.int8]) -def test_spmm_coo_dequant(dim1, dim2, dtype): - threshold = 6.0 - # threshold = 2.8 - # threshold = 0.0 - A = torch.randn(dim1, dim2, device="cuda").half() - B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16) - torch.nn.init.xavier_uniform_(B) - Bt = B.t().contiguous() - - CB, CBt, statsB, statsBt, coo_tensor = F.int8_double_quant(B) - - rowidx = torch.randint(0, A.shape[-1], size=(15,)) - - A[:, rowidx] = 8.0 - - idx = torch.abs(A) >= threshold - nnz = (idx == 1).sum().item() - rows, cols = torch.where(idx) - values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A * idx - out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) - out1 = torch.matmul(A2, B.half()) - out3 = F.spmm_coo_very_sparse(cooA, CBt.half()) - out3 = out3 * statsBt.half() / 127 - - values, counts = torch.unique(cooA.rowidx, return_counts=True) - offset = counts.cumsum(0).int() - max_count, max_idx = torch.sort(counts, descending=True) - print(torch.median(max_count.float())) - - torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001) - - p = 200 / (2048 * 12288 * 4) - n = out1.numel() - count = math.ceil(p * n) - assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count) - - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(100): - # out2 = F.spmm_coo_very_sparse(cooA, B) - # torch.cuda.synchronize() - # print('fp16', time.time() - t0) - - torch.cuda.synchronize() - t0 = 
time.time() - for i in range(100): - out2 = F.spmm_coo(cooA, B) - torch.cuda.synchronize() - print("cusparse fp16", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out2 = F.spmm_coo_very_sparse(cooA, CBt) - torch.cuda.synchronize() - print("int8", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) - torch.cuda.synchronize() - print("int8+dequant", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out2 = torch.matmul(A, B) - torch.cuda.synchronize() - print("matmul", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out1 = bnb.matmul(A, Bt) - out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) - out = out1 + out2 - torch.cuda.synchronize() - print("sparse+ matmul", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out1 = bnb.matmul(A, Bt) - torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1) - torch.cuda.synchronize() - print("partial matmul", time.time() - t0) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - out1 = bnb.matmul(A, Bt) - torch.cuda.synchronize() - print("partial matmul", time.time() - t0) - - -def test_zeropoint(): - def quant_zp(x): - dtype = x.dtype - x = x.float() - dyna = x.max() - x.min() - if dyna == 0: - dyna = 1 - qx = 254.0 / dyna - minx = x.min() - # zpx = torch.round(minx* qx) - # zpx = 127 - torch.round(x.max()* qx) - zpx = torch.round(x.min() * qx) - 127 - x = (qx * x) + zpx - return x, qx, zpx - - batch = 2 - seq = 512 - model = 1024 - hidden = 4 * model - A = torch.randn(batch * seq, model, device="cuda").half() * 0.1 - B = torch.randn(model, hidden, device="cuda").half() * 0.1 - - C0 = torch.matmul(A, B) - - # A, SA = F.vectorwise_quant(A, quant_type='linear') - # B, SB = F.vectorwise_quant(B, quant_type='linear') - A = A.float() - B = B.float() - - C1 = torch.matmul(A, B) - C3 = bnb.matmul(A.half(), B.t().contiguous().half()) - - zp = 1 - # C2 = torch.matmul(A-zp, B) - # C2 += B.sum(0).view(1, -1)*zp - C2 = torch.matmul(A, B - zp) - C2 -= A.sum(1).view(-1, 1) * zp - - ca, cqa, cza = quant_zp(A) - # print(ca.min(), ca.max()) - # print((ca - cza).min(), (ca - cza).max()) - - zp = 1 - scale = 2.0 - C5 = torch.matmul((A * scale) - zp, B) - C5 += B.sum(0) * zp - C5 /= scale - - CA, qa, zpa = quant_zp(A) - C4 = torch.matmul(CA, B) - C4 -= B.sum(0) * zpa - C4 /= qa - - zpb = 1 - zpa = 1 - qa = 2 - qb = 2 - C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb) - C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb) - C6 -= zpa * zpb * A.shape[1] - C6 /= qa * qb - - CA, qa, zpa = quant_zp(A) - CB, qb, zpb = quant_zp(B) - C7 = torch.matmul(CA, CB) - C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb) - C7 -= zpa * zpb * A.shape[1] - C7 /= qa * qb - - # print("") - # print(C0.flatten()[:10]) - # print(C1.flatten()[:10]) - # print(C2.flatten()[:10]) - # print(C3.flatten()[:10]) - # print(C5.flatten()[:10]) - # print(C6.flatten()[:10]) - # print(C7.flatten()[:10]) - err1 = torch.abs(C1 - C2).mean().item() - err2 = torch.abs(C1 - C3).mean().item() - err3 = torch.abs(C1 - C4).mean().item() - err4 = torch.abs(C1 - C5).mean().item() - err5 = torch.abs(C1 - C6).mean().item() - err6 = torch.abs(C1 - C7).mean().item() - print(err1, err2, err3, err4, err5, err6) - - -@pytest.mark.deprecated -def test_extract_outliers(): - for i in 
range(k): - shapeA = (4096, 4096 * 4) - idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() - # idx = torch.Tensor([0]).int().cuda() - A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) - outliers1 = A[:, idx.long()] - - CA, SA = F.transform(A, "col_turing") - - outliers2 = F.extract_outliers(CA, SA, idx) + out1 = torch.matmul(A2, B.half()) + out3 = F.spmm_coo_very_sparse(cooA, CBt.half()) + out3 = out3 * statsBt.half() / 127 + + values, counts = torch.unique(cooA.rowidx, return_counts=True) + offset = counts.cumsum(0).int() + max_count, max_idx = torch.sort(counts, descending=True) + print(torch.median(max_count.float())) + + torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001) + + p = 200 / (2048 * 12288 * 4) + n = out1.numel() + count = math.ceil(p * n) + assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count) + + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(100): + # out2 = F.spmm_coo_very_sparse(cooA, B) + # torch.cuda.synchronize() + # print('fp16', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo(cooA, B) + torch.cuda.synchronize() + print("cusparse fp16", time.time() - t0) - assert outliers2.shape[0] == shapeA[0] - assert outliers2.shape[1] == idx.numel() + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo_very_sparse(cooA, CBt) + torch.cuda.synchronize() + print("int8", time.time() - t0) - torch.testing.assert_close(outliers1, outliers2) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + torch.cuda.synchronize() + print("int8+dequant", time.time() - t0) - CA, SA = F.transform(A, "col_ampere") + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = torch.matmul(A, B) + torch.cuda.synchronize() + print("matmul", time.time() - t0) - outliers2 = F.extract_outliers(CA, SA, idx) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + out = out1 + out2 + torch.cuda.synchronize() + print("sparse+ matmul", time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1) + torch.cuda.synchronize() + print("partial matmul", time.time() - t0) - assert outliers2.shape[0] == shapeA[0] - assert outliers2.shape[1] == idx.numel() + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + torch.cuda.synchronize() + print("partial matmul", time.time() - t0) - torch.testing.assert_close(outliers1, outliers2) +class TestSparseTensorFunctional: + def test_coo2csr(self): + threshold = 1 + A = torch.randn(128, 128).half().cuda() + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A * idx + csrA = F.coo2csr(cooA) + counts = csrA.rowptr[1:] - csrA.rowptr[:-1] + assert counts.numel() == A.shape[0] -def test_blockwise_cpu_large(): - diffs = [] - reldiffs = [] - batch = 128 - seq = 128 - for hidden in [128]: # , 14336]: - for blocksize in [4096, 16384]: - for i in range(2): - A1 = torch.randn(batch, seq, hidden, device="cpu") - t0 = time.time() - C, S = F.quantize_blockwise(A1, blocksize=blocksize) - A2 = 
F.dequantize_blockwise(C, S, blocksize=blocksize) - print(time.time() - t0) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - assert diffs[-1] < 0.011 - # print(sum(diffs)/len(diffs)) - # print(sum(reldiffs)/len(reldiffs)) + torch.testing.assert_close(counts.long(), (A2 != 0).sum(1)) + idx = A2 != 0 + torch.testing.assert_close(A2[idx], csrA.values) + def test_coo2csc(self): + threshold = 1 + A = torch.randn(128, 128).half().cuda() + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A * idx + cscA = F.coo2csc(cooA) + counts = cscA.colptr[1:] - cscA.colptr[:-1] + assert counts.numel() == A.shape[1] -def test_fp8_quant(): - for e_bits in range(1, 7): - p_bits = 7 - e_bits - code = F.create_fp8_map(True, e_bits, p_bits).cuda() + torch.testing.assert_close(counts.long(), (A2 != 0).sum(0)) + # torch uses row-major -> use transpose to transfer to col-major + idx = A2.t() != 0 + torch.testing.assert_close(A2.t()[idx], cscA.values) - abserr = [] - relerr = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - C, SC = F.quantize_blockwise(A1, code=code) - A2 = F.dequantize_blockwise(C, SC) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - abserr.append(diff.mean().item()) - relerr.append(reldiff.mean().item()) - # assert diff < 0.0075 - # print(sum(abserr)/len(abserr)) - # print(sum(relerr)/len(relerr)) - - abserr = [] - relerr = [] - for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda") - C, SC = F.quantize_blockwise(A1, code=code) - A2 = F.dequantize_blockwise(C, SC) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - abserr.append(diff.mean().item()) - relerr.append(reldiff.mean().item()) - # assert diff < 0.0075 - # print(sum(abserr)/len(abserr)) - # print(sum(relerr)/len(relerr)) - - abserr = [] - relerr = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - C, SC = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, SC) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - abserr.append(diff.mean().item()) - relerr.append(reldiff.mean().item()) - # assert diff < 0.0075 - # print(3, sum(abserr)/len(abserr)) - # print(3, sum(relerr)/len(relerr)) - - -def test_few_bit_quant(): - # print('') - for bits in range(2, 9): - # print('='*30, bits, '='*30) - for method in ["linear", "fp8", "dynamic", "quantile"]: - abserrs = [] - relerrs = [] - code = None - if method == "linear": - code = F.create_linear_map(True, total_bits=bits).cuda() - elif method == "fp8": - ebits = math.ceil(bits / 2) - pbits = bits - ebits - 1 - code = F.create_fp8_map(True, ebits, pbits, bits).cuda() - elif method == "dynamic": - code = F.create_dynamic_map(True, bits - 0, bits).cuda() - elif method == "quantile": - values = torch.randn(2048, 2048, device="cuda") - code = F.create_quantile_map(values, bits).cuda() - # for some data types we have no zero - # for some data types we have one zero - # for some data types we have two zeros - assert torch.unique(code).numel() in [2**bits, 2**bits - 1], f"bits: {bits}, method: {method}" - # print(method, (code==0).sum()) - assert code.numel() == 256 - for i in range(10): - values = torch.randn(1, 32, device="cuda") - values /= values.abs().max() - # values[values.abs() < 1e-6] += 1e-5 - - q1 = [] - v1 = [] - for v in 
values[0]: - idx = torch.abs(v - code).argmin() - q1.append(idx.item()) - v1.append(code[idx].item()) - - q1 = torch.Tensor(q1).cuda() - v1 = torch.Tensor(v1).cuda() - - q2, S2 = F.quantize_blockwise(values, code=code) - v2 = F.dequantize_blockwise(q2, S2) - - idx = torch.isclose(q1.int(), q2.int()) - err2 = torch.abs(v2 - values) - abserrs.append(err2.mean().item()) - relerrs.append((err2 / (1e-10 + values).abs()).mean().item()) - if idx.sum(): - # some weird cases - err1 = torch.abs(v1 - values).mean() - # assert err2.mean() <= err1 - else: - torch.testing.assert_close(q1, q2) - # print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) - # assert False - - -def test_kbit_quantile_estimation(): - for i in range(100): - data = torch.randn(1024, 1024, device="cuda") - for bits in range(2, 9): - p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits) - val1 = torch.Tensor(norm.ppf(p)).cuda() - val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) - err = torch.abs(val1 - val2).mean() - assert err < 0.038 - - for i in range(100): - data = torch.randn(1024, 1024, device="cuda") - for bits in range(2, 4): - total_values = 2**bits - 1 - p = np.linspace(0, 1, 2 * total_values + 1) - idx = np.arange(1, 2 * total_values + 1, 2) - p = p[idx] - offset = 1 / (2 * total_values) - p = np.linspace(offset, 1 - offset, total_values) - val1 = torch.Tensor(norm.ppf(p)).cuda() - val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1) - err = torch.abs(val1 - val2).mean() - assert err < 0.035 - - -@pytest.mark.benchmark -def test_bench_dequantization(): - a = torch.rand(1024, 1024, device="cuda").half() - code = F.create_fp8_map(True, 3, 0, 4).cuda() - qa, SA = F.quantize_blockwise(a, code=code) - print(qa.max()) - - max_theoretical_mu = 1024 * 1024 * 2 / 1024**3 / 672 * 1000 * 1000 - # print(max_theoretical_mu) - - torch.cuda.synchronize() - t0 = time.time() - for i in range(100): - qa, SA = F.quantize_blockwise(a) - torch.cuda.synchronize() - # print((time.time()-t0)/1e6) - - -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) -@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) -@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) -def test_4bit_quant(dtype, quant_type, blocksize): - vals = list(product([0, 1], repeat=4)) - - code = {} - for bits in vals: - result = 0 - bias = 3 - sign, e1, e2, p1 = bits - idx = sign * 8 + e1 * 4 + e2 * 2 + p1 * 1 - sign = -1.0 if sign else 1.0 - exp = e1 * 2 + e2 * 1 - if exp == 0: - # sub-normal - if p1 == 0: - result = 0 - else: - result = sign * 0.0625 - else: - # normal - exp = 2 ** (-exp + bias + 1) - frac = 1.5 if p1 else 1.0 - result = sign * exp * frac - code[idx] = result - - A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) - qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) - A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) - - err = (A1 - A2).abs().float() - relerr = (err / (A1.abs().float() + 1e-8)).mean() - idx = err > 1.0 - err = err.mean() - - assert A2.dtype == dtype - - # With larger block sizes, we can expect this to blow up. - # At blocksize>=1024, don't even bother looking at relerr. 
- if blocksize <= 64: - assert err.item() < 0.1 - assert relerr.item() < 0.28 - elif blocksize <= 256: - assert err.item() < 0.11 - assert relerr.item() < 0.30 - elif blocksize <= 512: - assert err.item() < 0.12 - assert relerr.item() < 0.31 - elif quant_type == "fp4": - # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56 - assert err.item() < 0.08 + math.log2(blocksize) * 4e-2 - else: - # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 - assert err.item() < math.log2(blocksize) * 8e-2 - - -@pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) -def test_4bit_compressed_stats(quant_type): - for blocksize in [128, 64]: - errs1 = [] - errs2 = [] - for i in range(10): - A1 = torch.randn(1024, 1024, device="cuda").half() - q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) - q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) - A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) - A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) +class TestQuantize4BitFunctional: + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) + @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096]) + def test_4bit_quant(self, dtype, quant_type, blocksize): + A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype) + qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) + A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) - err = (A1 - A2).abs().float() - relerr = (err / (A1.abs().float() + 1e-15)).mean() - err = err.mean() + err = (A1 - A2).abs().float() + relerr = (err / (A1.abs().float() + 1e-8)).mean() + err = err.mean() - errs1.append(err.item()) + assert A2.dtype == dtype - assert err.item() < 0.11 + # With larger block sizes, we can expect this to blow up. + # At blocksize>=1024, don't even bother looking at relerr. 
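            # Context for the blocksize-dependent thresholds below: quantize_4bit keeps one
            # fp32 absmax per block (the tests later shape absmax as `blocks = -(n // -blocksize)`),
            # so every value in a block shares a single scale. Larger blocks have larger absmax
            # values, and the absolute roundtrip error grows with them, which is roughly the
            # log2(blocksize)-shaped bound used for blocksize >= 1024.
            # Worked example for fp4 at 4096: 0.08 + log2(4096) * 4e-2 = 0.08 + 12 * 0.04 = 0.56.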
+ if blocksize <= 64: + assert err.item() < 0.1 assert relerr.item() < 0.28 - - err = (A1 - A3).abs().float() - relerr = (err / (A1.abs().float() + 1e-15)).mean() - err = err.mean() - - errs2.append(err.item()) - + elif blocksize <= 256: assert err.item() < 0.11 - assert relerr.item() < 0.28 + assert relerr.item() < 0.30 + elif blocksize <= 512: + assert err.item() < 0.12 + assert relerr.item() < 0.31 + elif quant_type == "fp4": + # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56 + assert err.item() < 0.08 + math.log2(blocksize) * 4e-2 + else: + # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96 + assert err.item() < math.log2(blocksize) * 8e-2 + + @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) + def test_4bit_compressed_stats(self, quant_type): + for blocksize in [128, 64]: + errs1 = [] + errs2 = [] + for i in range(10): + A1 = torch.randn(1024, 1024, device="cuda").half() + q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) + q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) + A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) + + err = (A1 - A2).abs().float() + relerr = (err / (A1.abs().float() + 1e-15)).mean() + err = err.mean() + + errs1.append(err.item()) + + assert err.item() < 0.11 + assert relerr.item() < 0.28 + + err = (A1 - A3).abs().float() + relerr = (err / (A1.abs().float() + 1e-15)).mean() + err = err.mean() + + errs2.append(err.item()) + + assert err.item() < 0.11 + assert relerr.item() < 0.28 + + # print(sum(errs1)/len(errs1), blocksize, quant_type) + # print(sum(errs2)/len(errs2), blocksize, quant_type) + + # @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) + @pytest.mark.parametrize("quant_type", ["nf4"]) + @pytest.mark.benchmark + def test_bench_4bit_dequant(self, quant_type): + blocksize = 256 + a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half() + qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) + + input_size = a.numel() / 2 + output_size = a.numel() * 2 + num_bytes = input_size + output_size + GB = num_bytes / 1e9 + max_theoretical_s = GB / 768 + # print(max_theoretical_s*1e6) + b = torch.randn(128, 1024 * 12, device="cuda").half() + + iters = 100 + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) + # b.copy_(a) + torch.cuda.synchronize() + # print((time.time()-t0)/iters*1e6) + + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(iters): + # torch.matmul(b, a.t()) + # torch.cuda.synchronize() + # print((time.time()-t0)/iters*1e6) + + @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") + @pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) + @pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) + @pytest.mark.parametrize( + "quant_storage", + [torch.uint8, torch.float16, torch.bfloat16, torch.float32], + ids=describe_dtype, + ) + def test_gemv_4bit(self, dtype, storage_type, quant_storage, double_quant, kind): + for dim in [128, 256, 512, 1024]: + # for dim in [4*1024]: + # for dim in [1*16]: + errs1 = [] + errs2 = [] + errs3 = [] + relerrs1 = [] + relerrs2 = [] + relerrs3 = [] + max_errs1 = [] + max_errs2 = [] + max_errs3 = [] + + for i in range(100): + if kind == "fc1": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + 
B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "fc2": + A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda") + B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "attn": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "attn_packed": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + + qB, state = F.quantize_4bit( + B, + quant_type=storage_type, + compress_statistics=double_quant, + quant_storage=quant_storage, + ) + C3 = torch.matmul(A, B.t()) + C2 = F.gemv_4bit(A, qB.t(), state=state) + A.requires_grad = True + C1 = bnb.matmul_4bit(A, qB.t(), state) + + err1 = (C1 - C2).abs().float() + err2 = (C3 - C2).abs().float() + err3 = (C3 - C1).abs().float() + + mag1 = torch.abs(C1).float() + 1e-5 + mag2 = torch.abs(C3).float() + 1e-5 + mag3 = torch.abs(C3).float() + 1e-5 + + relerr1 = err1 / mag1 + relerr2 = err2 / mag2 + relerr3 = err3 / mag3 + + max_err1 = err1.max() + max_err2 = err2.max() + max_err3 = err3.max() + + errs1.append(err1.mean().item()) + errs2.append(err2.mean().item()) + errs3.append(err3.mean().item()) + + relerrs1.append(relerr1.mean().item()) + relerrs2.append(relerr2.mean().item()) + relerrs3.append(relerr3.mean().item()) + + max_errs1.append(max_err1.item()) + max_errs2.append(max_err2.item()) + max_errs3.append(max_err3.item()) + + c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 + + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False) + err1 = sum(errs1) / len(errs1) / math.sqrt(dim) + err2 = sum(errs2) / len(errs2) / math.sqrt(dim) + err3 = sum(errs3) / len(errs3) / math.sqrt(dim) + relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim) + relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim) + relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim) + maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim) + maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim) + maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim) + absratio = err2 / err3 + relratio = relerr2 / relerr3 + maxratio = relerr2 / relerr3 + + # for debugging if the tests fails + # + # print('='*80) + # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') + # print(C1.flatten()[-20:]) + # print(C2.flatten()[-20:]) + # print(f'inference vs training abs: {err1}') + # print(f'inference vs training rel: {relerr1}') + # print(f'inference vs training max: {maxerr1}') + # print(f'inference vs training vs torch err ratio abs: {absratio}') + # print(f'inference vs training vs torch err ratio rel: {relratio}') + # print(f'inference vs training vs torch err ratio max: {maxratio}') + if dtype == torch.float16: + if dim <= 512: + assert err1 < 7e-5 + assert relerr1 < 0.0008 + else: + assert err1 < 6e-5 + assert relerr1 < 2e-4 + assert absratio < 1.005 and absratio > 0.995 + assert relratio < 1.005 and relratio > 0.995 + assert maxratio < 1.005 and maxratio > 0.995 + elif dtype == torch.float32: + if dim <= 512: + assert err1 < 5e-8 + assert relerr1 < 1e-6 + assert maxerr1 < 1e-7 + else: + assert err1 < 5e-8 + assert relerr1 < 8e-6 + assert maxerr1 < 1e-7 + assert absratio < 1.005 and absratio > 0.995 + assert relratio < 1.005 and relratio > 0.995 + assert maxratio < 1.005 and maxratio > 0.995 + elif dtype == torch.bfloat16: + if dim <= 512: + assert err1 < 6e-4 + assert relerr1 < 
0.007 + assert maxerr1 < 0.015 + else: + assert err1 < 2e-4 + assert relerr1 < 0.002 + assert maxerr1 < 0.0012 + assert absratio < 1.005 and absratio > 0.995 + assert relratio < 1.04 and relratio > 0.96 + assert maxratio < 1.02 and maxratio > 0.98 + + @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) + @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) + def test_gemv_eye_4bit(self, storage_type, dtype, double_quant): + dims = 10 + torch.random.manual_seed(np.random.randint(0, 412424242)) + dims = get_test_dims(0, 8192, n=dims) + dims = [dim + (64 - (dim % 64)) for dim in dims] + # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: + for dim in dims: + A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda") + B = torch.eye(dim, dtype=dtype, device="cuda") + + qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) + C3 = torch.matmul(A, B.t()) + C2 = bnb.matmul_4bit(A, qB.t(), state) + A.requires_grad = True + C1 = bnb.matmul_4bit(A, qB.t(), state) - # print(sum(errs1)/len(errs1), blocksize, quant_type) - # print(sum(errs2)/len(errs2), blocksize, quant_type) - - -# @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) -@pytest.mark.parametrize("quant_type", ["nf4"]) -@pytest.mark.benchmark -def test_bench_4bit_dequant(quant_type): - blocksize = 256 - a = torch.rand(1024 * 12 * 4, 1024 * 12, device="cuda").half() - qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) - - input_size = a.numel() / 2 - output_size = a.numel() * 2 - num_bytes = input_size + output_size - GB = num_bytes / 1e9 - max_theoretical_s = GB / 768 - # print(max_theoretical_s*1e6) - b = torch.randn(128, 1024 * 12, device="cuda").half() - - iters = 100 - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) - # b.copy_(a) - torch.cuda.synchronize() - # print((time.time()-t0)/iters*1e6) - - # torch.cuda.synchronize() - # t0 = time.time() - # for i in range(iters): - # torch.matmul(b, a.t()) - # torch.cuda.synchronize() - # print((time.time()-t0)/iters*1e6) + torch.testing.assert_close(A, C3) + torch.testing.assert_close(A, C1) + torch.testing.assert_close(A, C2) + # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) + # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) def test_normal_map_tree(): @@ -1474,146 +1339,6 @@ def test_normal_map_tree(): # print(pivots) -@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") -@pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("kind", ["fc1", "fc2", "attn", "attn_packed"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) -@pytest.mark.parametrize( - "quant_storage", - [torch.uint8, torch.float16, torch.bfloat16, torch.float32], - ids=describe_dtype, -) -def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind): - for dim in [128, 256, 512, 1024]: - # for dim in [4*1024]: - # for dim in [1*16]: - errs1 = [] - errs2 = [] - errs3 = [] - relerrs1 = [] - relerrs2 = [] - relerrs3 = [] - max_errs1 = [] - max_errs2 = [] - max_errs3 = [] - - for i in range(100): - if kind == "fc1": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim) - elif kind 
== "fc2": - A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda") - B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim) - elif kind == "attn": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim) - elif kind == "attn_packed": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim) - - qB, state = F.quantize_4bit( - B, - quant_type=storage_type, - compress_statistics=double_quant, - quant_storage=quant_storage, - ) - C3 = torch.matmul(A, B.t()) - C2 = F.gemv_4bit(A, qB.t(), state=state) - A.requires_grad = True - C1 = bnb.matmul_4bit(A, qB.t(), state) - - err1 = (C1 - C2).abs().float() - err2 = (C3 - C2).abs().float() - err3 = (C3 - C1).abs().float() - - mag1 = torch.abs(C1).float() + 1e-5 - mag2 = torch.abs(C3).float() + 1e-5 - mag3 = torch.abs(C3).float() + 1e-5 - - relerr1 = err1 / mag1 - relerr2 = err2 / mag2 - relerr3 = err3 / mag3 - - max_err1 = err1.max() - max_err2 = err2.max() - max_err3 = err3.max() - - errs1.append(err1.mean().item()) - errs2.append(err2.mean().item()) - errs3.append(err3.mean().item()) - - relerrs1.append(relerr1.mean().item()) - relerrs2.append(relerr2.mean().item()) - relerrs3.append(relerr3.mean().item()) - - max_errs1.append(max_err1.item()) - max_errs2.append(max_err2.item()) - max_errs3.append(max_err3.item()) - - c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 - - c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False) - err1 = sum(errs1) / len(errs1) / math.sqrt(dim) - err2 = sum(errs2) / len(errs2) / math.sqrt(dim) - err3 = sum(errs3) / len(errs3) / math.sqrt(dim) - relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim) - relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim) - relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim) - maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim) - maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim) - maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim) - absratio = err2 / err3 - relratio = relerr2 / relerr3 - maxratio = relerr2 / relerr3 - - # for debugging if the tests fails - # - # print('='*80) - # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') - # print(C1.flatten()[-20:]) - # print(C2.flatten()[-20:]) - # print(f'inference vs training abs: {err1}') - # print(f'inference vs training rel: {relerr1}') - # print(f'inference vs training max: {maxerr1}') - # print(f'inference vs training vs torch err ratio abs: {absratio}') - # print(f'inference vs training vs torch err ratio rel: {relratio}') - # print(f'inference vs training vs torch err ratio max: {maxratio}') - if dtype == torch.float16: - if dim <= 512: - assert err1 < 7e-5 - assert relerr1 < 0.0008 - else: - assert err1 < 6e-5 - assert relerr1 < 2e-4 - assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.005 and relratio > 0.995 - assert maxratio < 1.005 and maxratio > 0.995 - elif dtype == torch.float32: - if dim <= 512: - assert err1 < 5e-8 - assert relerr1 < 1e-6 - assert maxerr1 < 1e-7 - else: - assert err1 < 5e-8 - assert relerr1 < 8e-6 - assert maxerr1 < 1e-7 - assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.005 and relratio > 0.995 - assert maxratio < 1.005 and maxratio > 0.995 - elif dtype == torch.bfloat16: - if dim <= 512: - assert err1 < 6e-4 - assert relerr1 < 0.007 - assert maxerr1 < 0.015 - else: - assert err1 < 2e-4 - assert relerr1 < 0.002 - 
assert maxerr1 < 0.0012 - assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.04 and relratio > 0.96 - assert maxratio < 1.02 and maxratio > 0.98 - - @pytest.mark.skip("Row scale has some bugs for ampere") def test_managed(): n = 32 * 10 @@ -1637,32 +1362,6 @@ def test_managed(): assert (A == 17 * (2**3)).sum().item() == n * n -@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) -@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) -def test_gemv_eye_4bit(storage_type, dtype, double_quant): - dims = 10 - torch.random.manual_seed(np.random.randint(0, 412424242)) - dims = get_test_dims(0, 8192, n=dims) - dims = [dim + (64 - (dim % 64)) for dim in dims] - # for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]: - for dim in dims: - A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device="cuda") - B = torch.eye(dim, dtype=dtype, device="cuda") - - qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant) - C3 = torch.matmul(A, B.t()) - C2 = bnb.matmul_4bit(A, qB.t(), state) - A.requires_grad = True - C1 = bnb.matmul_4bit(A, qB.t(), state) - - torch.testing.assert_close(A, C3) - torch.testing.assert_close(A, C1) - torch.testing.assert_close(A, C2) - # torch.testing.assert_close(A, C1, rtol=1e-5, atol=0.00001) - # torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) - - @pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3")) @@ -1676,169 +1375,3 @@ def test_vector_quant(dim1, dim2, dim3): A1 = F.vectorwise_dequant(qA, SA) n = A1.numel() assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) - - -@pytest.mark.deprecated -def test_quantile_quantization(): - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - code = F.estimate_quantiles(A1) - C = F.quantize_no_absmax(A1, code) - A2 = F.dequantize_no_absmax(C, code) - diff = torch.abs(A1 - A2).mean().item() - assert diff < 0.0075 - - A1 = torch.rand(1024, 1024, device="cuda") - code = F.estimate_quantiles(A1) - C = F.quantize_no_absmax(A1, code) - A2 = F.dequantize_no_absmax(C, code) - diff = torch.abs(A1 - A2).mean().item() - torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0) - assert diff < 0.001 - - -@pytest.mark.deprecated -def test_dynamic_quantization(): - diffs = [] - reldiffs = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - C, S = F.quantize(A1) - A2 = F.dequantize(C, S) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - assert diff.mean().item() < 0.0135 - print(sum(diffs) / len(diffs)) - print(sum(reldiffs) / len(reldiffs)) - - for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda") - C, S = F.quantize(A1) - A2 = F.dequantize(C, S) - diff = torch.abs(A1 - A2).mean().item() - torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) - assert diff < 0.004 - - -@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) -@pytest.mark.deprecated -def test_percentile_clipping(gtype): - gnorm_vec1 = torch.zeros(100, device="cuda") - gnorm_vec2 = torch.zeros(100, device="cuda") - n = 4 - step = 0 - percentile = 5 - for i in range(k): - step += 1 - g = 
torch.randn(n, n, dtype=gtype, device="cuda") - gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) - assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 - - gnorm2 = torch.norm(g.float()) - if step == 1: - gnorm_vec1[:] = gnorm2 - else: - gnorm_vec1[step % 100] = gnorm2 - - vals, idx = torch.sort(gnorm_vec1) - clip1 = vals[percentile] - - torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2)) - torch.testing.assert_close(clip1, clip2) - torch.testing.assert_close(gnorm1, gnorm2) - - -@pytest.mark.parametrize("dim1", get_test_dims(2, 1024, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(2, 1024, n=2), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", [0], ids=id_formatter("dim3")) -@pytest.mark.parametrize("dims", [2], ids=id_formatter("dims")) -@pytest.mark.parametrize("dtype", [torch.int8], ids=describe_dtype) -@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) -@pytest.mark.parametrize("orderOut", ["col32", "col_turing", "col_ampere"], ids=id_formatter("orderOut")) -@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) -@pytest.mark.deprecated -def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): - for i in range(k): - if dims == 2: - A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) - elif dims == 3: - A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) - - A.view(-1)[-1] = -1 - if transpose: - At = A.t().contiguous() - out1, S1 = F.nvidia_transform(At, to_order=orderOut) - else: - out1, S1 = F.nvidia_transform(A, to_order=orderOut) - out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose) - - assert S1[0][0] == S2[0][0] - assert S1[0][1] == S2[0][1] - # print(out1) - # print(out2) - - torch.testing.assert_close(out1, out2) - - -@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3")) -@pytest.mark.parametrize("dtype", [torch.int8, torch.int32], ids=describe_dtype) -@pytest.mark.parametrize("orderA", ["row"], ids=id_formatter("orderA")) -@pytest.mark.parametrize("orderOut", ["col", "row", "col32"], ids=id_formatter("orderOut")) -@pytest.mark.parametrize("transpose", [False], ids=id_formatter("transpose")) -@pytest.mark.parametrize("dims", [2, 3], ids=id_formatter("dims")) -@pytest.mark.deprecated -def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): - if dims == 3 and orderOut != "col32": - return - if dtype == torch.int32 and orderOut != "col32": - return - try: - func = F.get_transform_func(dtype, orderA, orderOut, transpose) - except ValueError as ve: - pytest.skip(str(ve)) # skip if not supported - - if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) - elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) - - out, S = F.nvidia_transform(A, to_order=orderOut) - - if orderOut == "row": - torch.testing.assert_close(A.flatten(), out.flatten()) - elif orderOut == "col": - torch.testing.assert_close(A.t().flatten(), out.flatten()) - elif orderOut == "col32": - if dims == 2: - n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) - elif dims == 3: - n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) - assert out.numel() == 
n - elif orderOut == "col_turing": - # 32 col 8 row tiles - n = (A.shape[0] + (8 - A.shape[0] % 8)) * (A.shape[1] + (32 - (A.shape[1] % 32))) - assert out.numel() == n - total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) - for row in range(A.shape[0]): - for col in range(A.shape[1]): - i = row * A.shape[1] - j = col - - coltile = (col // 32) + (1 if col % 32 != 0 else 0) - rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile - offset = 32 * 8 * (rowtile + coltile) - col2 = col % 32 - row2 = (row % 8) * 32 - - assert A.flatten()[i + j] == A[row, col] - # assert A.flatten()[i+j] == out.flatten()[row2+col2] - # torch.testing.assert_close(A.flatten()[i+j], A[row, col]) - # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) - - if orderOut == "col32": - out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) - torch.testing.assert_close(A, out2) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index bc9e2600f..a9efa796f 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -8,8 +8,6 @@ import torch import bitsandbytes as bnb -from bitsandbytes import functional as F -from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout from bitsandbytes.nn.modules import Linear8bitLt from tests.helpers import ( TRUE_FALSE, @@ -18,28 +16,9 @@ torch_save_to_buffer, ) + # contributed by Alex Borzunov, see: # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py - - -@pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), - reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", -) -def test_layout_exact_match(): - x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda() - for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"): - transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) - tile_indices = get_inverse_transform_indices(transform, tile_size) - cxb = transform(x) - - torch.cuda.synchronize() - restored_x = undo_layout(cxb, tile_indices) - torch.cuda.synchronize() - assert restored_x.is_contiguous() - assert torch.all(torch.eq(restored_x, x)) - - def test_linear_no_igemmlt(): linear = torch.nn.Linear(1024, 3072) x = torch.randn(3, 1024, dtype=torch.half) diff --git a/tests/test_ops.py b/tests/test_ops.py new file mode 100644 index 000000000..0e46ff9e0 --- /dev/null +++ b/tests/test_ops.py @@ -0,0 +1,221 @@ +from math import prod + +import pytest +import torch + +import bitsandbytes +from tests.helpers import TRUE_FALSE, id_formatter + + +class TestLLMInt8Ops: + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + def test_int8_linear_matmul(self, device): + A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) + B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) + out = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B) + + assert out.shape == (10, 30) + assert out.dtype == torch.int32 + assert out.device == A.device + + torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.default, (A, B)) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + def test_int8_linear_matmul_out(self, device): + A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) + B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) + + out = torch.empty((10, 30), dtype=torch.int32, device=device) + 
torch.ops.bitsandbytes.int8_linear_matmul.out(A, B, out) + + assert out.shape == (10, 30) + assert out.dtype == torch.int32 + assert out.device == A.device + + torch.library.opcheck(torch.ops.bitsandbytes.int8_linear_matmul.out, (A, B, out)) + + @pytest.mark.parametrize("threshold", [0.0, 6.0]) + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + def test_int8_vectorwise_quant(self, threshold, device): + if device == "cpu": + pytest.skip("CPU implementation is not available") + + A = torch.randn(10, 20, dtype=torch.float16, device=device) + A[1][0] = 1000.0 + + out_row, row_stats, outlier_cols = torch.ops.bitsandbytes.int8_vectorwise_quant(A, threshold=threshold) + + assert out_row.shape == (10, 20) + assert out_row.dtype == torch.int8 + assert out_row.device == A.device + assert row_stats.shape == (10,) + assert row_stats.dtype == torch.float32 + assert row_stats.device == A.device + + if threshold > 0.0: + assert outlier_cols is not None + assert outlier_cols.dim() == 1 + assert outlier_cols.shape[0] <= A.shape[1] + assert outlier_cols.device == A.device + else: + assert outlier_cols is None + + torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A,)) + + torch.library.opcheck(torch.ops.bitsandbytes.int8_vectorwise_quant, (A, threshold)) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + def test_int8_mm_dequant(self, device): + A = torch.randint(-128, 127, (256, 256), dtype=torch.int32, device=device) + row_stats = torch.randn(256, dtype=torch.float32, device=device) + col_stats = torch.randn(256, dtype=torch.float32, device=device) + out = torch.ops.bitsandbytes.int8_mm_dequant(A, row_stats, col_stats) + + assert out.shape == A.shape + assert out.dtype == torch.float16 + assert out.device == A.device + + torch.library.opcheck(torch.ops.bitsandbytes.int8_mm_dequant, (A, row_stats, col_stats)) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) + @pytest.mark.parametrize("has_bias", TRUE_FALSE) + def test_int8_scaled_mm(self, device, dtype, has_bias): + A = torch.randint(-128, 127, (10, 20), dtype=torch.int8, device=device) + B = torch.randint(-128, 127, (30, 20), dtype=torch.int8, device=device) + row_stats = torch.randn(10, dtype=torch.float32, device=device) + col_stats = torch.randn(30, dtype=torch.float32, device=device) + bias = torch.randn(30, dtype=dtype, device=device) if has_bias else None + out = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats, bias=bias, dtype=dtype) + + assert out.shape == (10, 30) + assert out.dtype == dtype + assert out.device == A.device + + torch.library.opcheck(torch.ops.bitsandbytes.int8_scaled_mm, (A, B, row_stats, col_stats, bias, dtype)) + + +class TestInt8BlockwiseQuantOps: + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + def test_quantize_blockwise(self, device, dtype, blocksize): + if device == "cpu" and dtype != torch.float32: + pytest.skip("CPU implementation is only available for float32") + + code = bitsandbytes.functional.create_dynamic_map().to(device) + A = torch.randn(1024, 1024, dtype=dtype, device=device) + out, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize) + + assert out.shape == A.shape + assert out.dtype == torch.uint8 + assert out.device == A.device + + assert absmax.device == 
A.device + assert absmax.dtype == torch.float32 + + torch.library.opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize)) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + def test_dequantize_blockwise(self, device, dtype, blocksize): + if device == "cpu" and dtype != torch.float32: + pytest.skip("CPU implementation is only available for float32") + + A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device) + code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32) + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.randn((blocks,), device=device, dtype=torch.float32) + + out = torch.ops.bitsandbytes.dequantize_blockwise.default(A, absmax, code, blocksize, dtype) + + assert out.shape == A.shape + assert out.dtype == dtype + assert out.device == A.device + + torch.library.opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype)) + + +class Test4bitBlockwiseQuantOps: + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) + @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) + @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): + if device == "cpu": + pytest.skip("CPU implementation is not available") + + A = torch.randn(1024, 1024, dtype=dtype, device=device) + + out, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype) + + assert out.device == A.device + assert out.dtype == storage_dtype + + assert absmax.device == A.device + assert absmax.dtype == torch.float32 + + torch.library.opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) + @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) + @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): + if device == "cpu": + pytest.skip("CPU implementation is not available") + + shape = (128, 128) + + n = prod(shape) + blocks = -(n // -blocksize) + quantized_shape = ((n + 1) // (storage_dtype.itemsize * 2), 1) + + A = ( + torch.randint(0, 255, ((n + 1) // 2,), dtype=torch.uint8, device=device) + .view(storage_dtype) + .reshape(quantized_shape) + .contiguous() + ) + + absmax = torch.randn((blocks,), dtype=torch.float32, device=device) + + out = torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype) + + assert out.device == A.device + assert out.shape == shape + + torch.library.opcheck( + torch.ops.bitsandbytes.dequantize_4bit.default, (A, absmax, blocksize, quant_type, shape, dtype) + ) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) + 
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) + @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) + @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) + def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): + if device == "cpu": + pytest.skip("CPU implementation is not available") + + out_features = 1024 + in_features = 256 + + A = torch.randn((1, 1, in_features), dtype=dtype, device=device) + B = torch.randn((out_features, in_features), dtype=dtype, device=A.device) + B_q, absmax = torch.ops.bitsandbytes.quantize_4bit(B, blocksize, quant_type, storage_dtype) + code = bitsandbytes.functional.get_4bit_type(quant_type, device=A.device, blocksize=blocksize) + + out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize) + + assert out.device == A.device + assert out.dtype == dtype + assert out.shape == (1, 1, out_features) + assert out.isreal().all() + + torch.library.opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))