PyTorch Custom Operator Integration #1544

Open · wants to merge 19 commits into main
Changes from 14 commits
5 changes: 4 additions & 1 deletion bitsandbytes/__init__.py
@@ -3,7 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from . import research, utils

+from . import _ops, research, utils
 from .autograd._functions import (
     MatmulLtState,
     bmm_cublas,
@@ -12,6 +13,8 @@
     matmul_cublas,
     mm_cublas,
 )
+from .backends.cpu import ops as cpu_ops
+from .backends.cuda import ops as cuda_ops  ## TODO: We would guard this for CUDA only
 from .nn import modules
 from .optim import adam

231 changes: 231 additions & 0 deletions bitsandbytes/_ops.py
@@ -0,0 +1,231 @@
from math import prod
from typing import Optional, Sequence, Tuple

import torch

_IS_TORCH_GTE_24 = False

if hasattr(torch.library, "register_fake"):
    _IS_TORCH_GTE_24 = True
    register_fake = torch.library.register_fake
    register_kernel = torch.library.register_kernel
else:
    # PyTorch <= 2.3
    register_fake = torch.library.impl_abstract
    register_kernel = torch.library.impl


# Higher level op: int8 matmul + dequant + bias
torch.library.define(
    "bitsandbytes::int8_linear_dequant",
    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
)
Comment on lines +19 to +22

Hey! I'm the main maintainer of custom operators in PyTorch. I'm curious -- why not use the torch.library.custom_op API instead of torch.library.define?

It would look something like:

@torch.library.custom_op("bitsandbytes::int8_linear_dequant", mutates_args=())
def int8_linear_dequant(A: Tensor, B: Tensor, row_stats: Tensor, col_stats: Tensor, bias: Optional[Tensor], dtype: torch.dtype) -> Tensor:
    raise NotImplementedError("")
 
@int8_linear_dequant.register_fake
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dtype=torch.float16,
) -> torch.Tensor:
    shapeC = (*A.shape[:-1], B.shape[0])
    return torch.empty(shapeC, device=A.device, dtype=dtype)   

We generally encourage people to use torch.library.custom_op because the custom ops produced from it are guarded from various footguns when compared to torch.library.Library.define

Member Author


Hey! Thanks for the feedback :)

While the custom_op API does look convenient, there are two main reasons it was avoided:

  1. I'm not sure we're ready to bump our minimum PyTorch requirement to 2.4.0+. That said, we're not strictly opposed to it.
  2. I've heard from others that custom_op introduced significant overhead.

I am curious: is it still reasonable to make use of infer_schema, and is that API available in torch < 2.4?


Thanks for the feedback! It's not clear to me if we have fully fixed the performance issues, but I will check.
torch.library.infer_schema is only available in 2.5+, so if your goal is to support older PyTorch versions, you are doing the right thing.
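
For reference, a minimal sketch of how schema inference might look once PyTorch >= 2.5 can be assumed (illustrative only, not part of this PR; the prototype function name is hypothetical):

from typing import Optional

import torch

def _int8_linear_dequant_prototype(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor],
    dtype: torch.dtype,
) -> torch.Tensor:
    # The body is never executed; infer_schema only inspects the signature.
    raise NotImplementedError

# Produces a schema string along the lines of
# "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias, ScalarType dtype) -> Tensor"
schema = torch.library.infer_schema(_int8_linear_dequant_prototype, mutates_args=())
torch.library.define("bitsandbytes::int8_linear_dequant", schema)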



@register_fake("bitsandbytes::int8_linear_dequant")
def _(
A: torch.Tensor,
B: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dtype=torch.float16,
) -> torch.Tensor:
shapeC = (*A.shape[:-1], B.shape[0])
return torch.empty(shapeC, device=A.device, dtype=dtype)


@register_kernel("bitsandbytes::int8_linear_dequant", None)
def _(
A: torch.Tensor,
B: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
bias: Optional[torch.Tensor] = None,
dtype=torch.float16,
) -> torch.Tensor:
out_i32 = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
out = torch.ops.bitsandbytes.int8_mm_dequant(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
return out


# Define op
# TODO: mutable output arg as alias of return can be challenging;
# consider a separate op without aliased return:
# int8_linear_matmul_out(
# Tensor A, Tensor B, Tensor out, ScalarType dtype=int32
# ) -> ()
# return () instead of `None` for compatibility, see here: https://github.com/pytorch/pytorch/issues/125044
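# A possible shape for that separate out variant (illustrative sketch only, not part of this change;
# `Tensor(a!)` marks `out` as mutated without aliasing the return value):
#
#   torch.library.define(
#       "bitsandbytes::int8_linear_matmul_out",
#       "(Tensor A, Tensor B, Tensor(a!) out) -> ()",
#   )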
torch.library.define(
    "bitsandbytes::int8_linear_matmul",
    "(Tensor A, Tensor B, Tensor? out=None, ScalarType dtype=int32) -> Tensor",
)


@register_fake("bitsandbytes::int8_linear_matmul")
def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
    shapeC = (*A.shape[:-1], B.shape[0])
    if out is None:
        return torch.empty(shapeC, device=A.device, dtype=dtype)
    return out


torch.library.define(
    "bitsandbytes::int8_vectorwise_quant",
    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_vectorwise_quant")
def _(A: torch.Tensor, threshold=0.0):
    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)

    if threshold == 0.0:
        return out_row, row_stats, None

    outlier_cols = torch.library.get_ctx().new_dynamic_size()

    return out_row, row_stats, A.new_empty(outlier_cols, dtype=torch.int64)


torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")


@register_fake("bitsandbytes::int8_vectorwise_dequant")
def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
torch._check(A.dtype == torch.int8, lambda: "A must be int8")
return torch.empty_like(A, dtype=torch.float32)


# Default PyTorch-native implementation
@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
def _(A: torch.Tensor, stats: torch.Tensor):
# To dequantize we divide by 127, or multiply by the reciprocal.
return A * stats.view(-1, 1) * 7.874015718698502e-3


torch.library.define(
    "bitsandbytes::int8_mm_dequant",
    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? out=None, Tensor? bias=None) -> Tensor",
)


@register_fake("bitsandbytes::int8_mm_dequant")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype=torch.float16,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    torch._check(A.dtype == torch.int32, lambda: "A must be int32")
    return torch.empty_like(A, dtype=dtype)


torch.library.define(
    "bitsandbytes::int8_double_quant",
    "(Tensor A, Tensor? col_stats, Tensor? row_stats, Tensor? out_col, Tensor? out_row, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_double_quant")
def _(
    A: torch.Tensor,
    col_stats: Optional[torch.Tensor] = None,
    row_stats: Optional[torch.Tensor] = None,
    out_col: Optional[torch.Tensor] = None,
    out_row: Optional[torch.Tensor] = None,
    threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    out_row = torch.empty_like(A, dtype=torch.int8)
    out_col = torch.empty_like(A, dtype=torch.int8)
    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
    col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32)
    outlier_n = torch.library.get_ctx().new_dynamic_size()
    outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
    return out_row, out_col, row_stats, col_stats, outlier_cols


torch.library.define(
    "bitsandbytes::dequantize_4bit",
    "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_4bit")
def _(
    A: torch.Tensor, absmax: torch.Tensor, blocksize: int, quant_type: str, shape: Sequence[int], dtype: torch.dtype
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    return torch.empty(shape, dtype=dtype, device=A.device)


torch.library.define(
    "bitsandbytes::quantize_4bit",
    "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::quantize_4bit")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)

    n = A.numel()
    blocks = -(n // -blocksize)
    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
    return out, absmax
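
# Note on the fake implementation above (explanatory comment, not part of the original change):
# two 4-bit values are packed per byte, so n elements need ceil(n / 2) bytes, which is
# (n + 1) // (quant_storage.itemsize * 2) elements of the quant_storage dtype; absmax holds one
# float32 scale per block of blocksize values.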


torch.library.define(
    "bitsandbytes::dequantize_blockwise",
    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_blockwise")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    return torch.empty_like(A, dtype=dtype)


torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")


@register_fake("bitsandbytes::quantize_blockwise")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)
return out, absmax


torch.library.define(
    "bitsandbytes::gemv_4bit",
    "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize) -> Tensor",
)


@register_fake("bitsandbytes::gemv_4bit")
def _(
    A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
    torch._check(
        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
    )
    torch._check(
        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
    )
    shape = (*A.shape[:-1], shapeB[0])
    return torch.empty(shape, device=A.device, dtype=A.dtype)
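
Aside (not part of the diff): a minimal sketch of how the register_fake implementations above can be exercised. Under FakeTensorMode (the mechanism torch.compile relies on for tracing), the ops return correctly shaped fake tensors without running any real kernel; this assumes bitsandbytes has been imported so the ops are registered.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import bitsandbytes  # noqa: F401  -- assumed to register the bitsandbytes::* ops on import

with FakeTensorMode():
    A = torch.empty(16, 64, dtype=torch.int8)
    B = torch.empty(32, 64, dtype=torch.int8)
    out = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
    print(out.shape, out.dtype)  # torch.Size([16, 32]) torch.int32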
24 changes: 11 additions & 13 deletions bitsandbytes/autograd/_functions.py
@@ -360,20 +360,12 @@ def forward(
                 # we want to divide by 127. It's however more performant to multiply
                 # by the reciprocal.
                 outliers = state.CB[:, state.idx]
-                state.subB = (outliers.t() * state.SCB * 7.874015718698502e-3).to(A.dtype)
+                state.subB = F.int8_vectorwise_dequant(outliers, state.SCB).to(A.dtype).t()
         else:
             subA = None

-        # 3. Int8 Matmul
-        out32 = F.int8_linear_matmul(CA, state.CB)
-
-        # Dequantize matmul result
-        if bias is None or bias.dtype == torch.float16:
-            # we apply the fused bias here
-            output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
-        else: # apply bias separately
-            # TODO: Fused bias for fp32/bf16?
-            output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias)
+        # 3. Int8 Matmul + Dequant + Bias
+        output = torch.ops.bitsandbytes.int8_linear_dequant(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)

         # 4. Mixed-precision decomposition matmul
         if subA is not None and state.subB is not None:
@@ -423,8 +415,14 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
         if req_gradB:
             Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))

-            gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
-            grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)
+            grad_B = torch.ops.bitsandbytes.int8_linear_dequant(
+                Cgrad.t().contiguous(),
+                CAt.t(),
+                SCgradt,
+                SCAt,
+                dtype=torch.float16,
+            )

             if state.threshold > 0.0 and subA is not None:
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

61 changes: 61 additions & 0 deletions bitsandbytes/backends/cpu/ops.py
@@ -0,0 +1,61 @@
import ctypes as ct
from typing import Optional, Tuple

import torch

from bitsandbytes.functional import get_ptr

from ..._ops import register_kernel
from ...cextension import lib


@register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
# Naive implementation: perform matmul in fp32
result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
if out is not None:
result = out.copy_(result)
return result


@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")

n = A.numel()
blocks = -(n // -blocksize)

absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)

lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(n),
)

return out, absmax


@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")

out = torch.empty_like(A, dtype=dtype)

lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)

return out
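
Aside (not part of the diff): a minimal usage sketch for the CPU path, assuming that importing bitsandbytes registers the kernels above and that bitsandbytes.functional.create_dynamic_map() provides the 256-entry codebook expected as code:

import torch
import bitsandbytes  # noqa: F401  -- assumed to register the CPU kernels on import
from bitsandbytes.functional import create_dynamic_map

A = torch.randn(4096, dtype=torch.float32)
code = torch.as_tensor(create_dynamic_map(), dtype=torch.float32)

# Blockwise 8-bit round trip through the registered CPU kernels.
packed, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, 4096)
restored = torch.ops.bitsandbytes.dequantize_blockwise(packed, absmax, code, 4096, torch.float32)
print((A - restored).abs().max())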