PyTorch Custom Operator Integration #1544
Open: matthewdouglas wants to merge 19 commits into main from customop-refactoring (base: main)
Commits
6268912  Sketch out first custom op registration (matthewdouglas)
04e1bc6  Add note (matthewdouglas)
d5df4c6  Merge branch 'main' into customop-refactoring (matthewdouglas)
04482ff  Initial int8 op registration (matthewdouglas)
2813571  Cleanup some deprecated functions. (matthewdouglas)
4ad1d9e  Int8 ops updates; tests (matthewdouglas)
e9c79cf  Implement 4bit quant/dequant ops (matthewdouglas)
9d0f459  Fix nested quant (matthewdouglas)
f360a08  cleanup (matthewdouglas)
45ead33  Test improvements (matthewdouglas)
6aeea81  Clean up and improve tests (matthewdouglas)
cbd1670  Add higher level custom op for int8 matmul + dequant + bias (matthewdouglas)
db07f4e  Add gemv 4bit custom op (matthewdouglas)
23eba7a  Cleanup (matthewdouglas)
2d5b2cc  Implement out kwarg overloads for custom ops (matthewdouglas)
6172770  Update PyTorch minimum to 2.1 (matthewdouglas)
242c602  Deprecation updates (matthewdouglas)
25368bc  Deprecation updates (matthewdouglas)
32345e4  merge main (Titus-von-Koeller)
@@ -0,0 +1,231 @@
from math import prod
from typing import Optional, Sequence, Tuple

import torch

_IS_TORCH_GTE_24 = False

if hasattr(torch.library, "register_fake"):
    _IS_TORCH_GTE_24 = True
    register_fake = torch.library.register_fake
    register_kernel = torch.library.register_kernel
else:
    # PyTorch <= 2.3
    register_fake = torch.library.impl_abstract
    register_kernel = torch.library.impl


# Higher level op: int8 matmul + dequant + bias
torch.library.define(
    "bitsandbytes::int8_linear_dequant",
    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
)


@register_fake("bitsandbytes::int8_linear_dequant")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dtype=torch.float16,
) -> torch.Tensor:
    shapeC = (*A.shape[:-1], B.shape[0])
    return torch.empty(shapeC, device=A.device, dtype=dtype)


@register_kernel("bitsandbytes::int8_linear_dequant", None)
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    dtype=torch.float16,
) -> torch.Tensor:
    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
    out = torch.ops.bitsandbytes.int8_mm_dequant(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
    return out


# Define op
# TODO: mutable output arg as alias of return can be challenging;
# consider a separate op without aliased return:
# int8_linear_matmul_out(
#     Tensor A, Tensor B, Tensor out, ScalarType dtype=int32
# ) -> ()
# return () instead of `None` for compatibility, see here: https://github.com/pytorch/pytorch/issues/125044
torch.library.define(
    "bitsandbytes::int8_linear_matmul",
    "(Tensor A, Tensor B, Tensor? out=None, ScalarType dtype=int32) -> Tensor",
)


@register_fake("bitsandbytes::int8_linear_matmul")
def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
    shapeC = (*A.shape[:-1], B.shape[0])
    if out is None:
        return torch.empty(shapeC, device=A.device, dtype=dtype)
    return out


torch.library.define(
    "bitsandbytes::int8_vectorwise_quant",
    "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_vectorwise_quant")
def _(A: torch.Tensor, threshold=0.0):
    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)
    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)

    if threshold == 0.0:
        return out_row, row_stats, None

    outlier_cols = torch.library.get_ctx().new_dynamic_size()

    return out_row, row_stats, A.new_empty(outlier_cols, dtype=torch.int64)


torch.library.define("bitsandbytes::int8_vectorwise_dequant", "(Tensor A, Tensor stats) -> Tensor")


@register_fake("bitsandbytes::int8_vectorwise_dequant")
def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
    torch._check(A.dtype == torch.int8, lambda: "A must be int8")
    return torch.empty_like(A, dtype=torch.float32)


# Default PyTorch-native implementation
@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
def _(A: torch.Tensor, stats: torch.Tensor):
    # To dequantize we divide by 127, or multiply by the reciprocal.
    return A * stats.view(-1, 1) * 7.874015718698502e-3


torch.library.define(
    "bitsandbytes::int8_mm_dequant",
    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? out=None, Tensor? bias=None) -> Tensor",
)


@register_fake("bitsandbytes::int8_mm_dequant")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype=torch.float16,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    torch._check(A.dtype == torch.int32, lambda: "A must be int32")
    return torch.empty_like(A, dtype=dtype)


torch.library.define(
    "bitsandbytes::int8_double_quant",
    "(Tensor A, Tensor? col_stats, Tensor? row_stats, Tensor? out_col, Tensor? out_row, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)",
)


@register_fake("bitsandbytes::int8_double_quant")
def _(
    A: torch.Tensor,
    col_stats: Optional[torch.Tensor] = None,
    row_stats: Optional[torch.Tensor] = None,
    out_col: Optional[torch.Tensor] = None,
    out_row: Optional[torch.Tensor] = None,
    threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    out_row = torch.empty_like(A, dtype=torch.int8)
    out_col = torch.empty_like(A, dtype=torch.int8)
    row_stats = torch.empty(prod(A.shape[:-1]), device=A.device, dtype=torch.float32)
    col_stats = torch.empty(A.shape[-1], device=A.device, dtype=torch.float32)
    outlier_n = torch.library.get_ctx().new_dynamic_size()
    outlier_cols = A.new_empty(outlier_n, dtype=torch.int64)
    return out_row, out_col, row_stats, col_stats, outlier_cols


torch.library.define(
    "bitsandbytes::dequantize_4bit",
    "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_4bit")
def _(
    A: torch.Tensor, absmax: torch.Tensor, blocksize: int, quant_type: str, shape: Sequence[int], dtype: torch.dtype
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    return torch.empty(shape, dtype=dtype, device=A.device)


torch.library.define(
    "bitsandbytes::quantize_4bit",
    "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::quantize_4bit")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)

    n = A.numel()
    blocks = -(n // -blocksize)
    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
    return out, absmax


torch.library.define(
    "bitsandbytes::dequantize_blockwise",
    "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype) -> Tensor",
)


@register_fake("bitsandbytes::dequantize_blockwise")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    return torch.empty_like(A, dtype=dtype)


torch.library.define("bitsandbytes::quantize_blockwise", "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)")


@register_fake("bitsandbytes::quantize_blockwise")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    n = A.numel()
    blocks = -(n // -blocksize)
    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    out = torch.empty_like(A, dtype=torch.uint8)
    return out, absmax


torch.library.define(
    "bitsandbytes::gemv_4bit",
    "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize) -> Tensor",
)


@register_fake("bitsandbytes::gemv_4bit")
def _(
    A: torch.Tensor, B: torch.Tensor, shapeB: Sequence[int], absmax: torch.Tensor, code: torch.Tensor, blocksize: int
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.numel() == A.size(-1), lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}")
    torch._check(
        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
    )
    torch._check(
        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
    )
    shape = (*A.shape[:-1], shapeB[0])
    return torch.empty(shape, device=A.device, dtype=A.dtype)
@@ -0,0 +1,61 @@
import ctypes as ct
from typing import Optional, Tuple

import torch

from bitsandbytes.functional import get_ptr

from ..._ops import register_kernel
from ...cextension import lib


@register_kernel("bitsandbytes::int8_linear_matmul", "cpu")
def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
    # Naive implementation: perform matmul in fp32
    result = torch.matmul(A.float(), B.float().t()).to(torch.int32)
    if out is not None:
        result = out.copy_(result)
    return result


@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")

    n = A.numel()
    blocks = -(n // -blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    out = torch.empty_like(A, dtype=torch.uint8)

    lib.cquantize_blockwise_cpu_fp32(
        get_ptr(code),
        get_ptr(A),
        get_ptr(absmax),
        get_ptr(out),
        ct.c_longlong(blocksize),
        ct.c_longlong(n),
    )

    return out, absmax


@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")

    out = torch.empty_like(A, dtype=dtype)

    lib.cdequantize_blockwise_cpu_fp32(
        get_ptr(code),
        get_ptr(A),
        get_ptr(absmax),
        get_ptr(out),
        ct.c_longlong(blocksize),
        ct.c_longlong(A.numel()),
    )

    return out
Hey! I'm the main maintainer of custom operators in PyTorch. I'm curious -- why not use the torch.library.custom_op API instead of torch.library.define?
It would look something like:
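A rough sketch, not the reviewer's original snippet, of how the int8_linear_dequant op from this PR might instead be declared with torch.library.custom_op (torch >= 2.4); the dtype argument is omitted here to keep the sketch minimal:

from typing import Optional

import torch


@torch.library.custom_op("bitsandbytes::int8_linear_dequant", mutates_args=())
def int8_linear_dequant(
    A: torch.Tensor,
    B: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Composite implementation: int8 matmul followed by dequant + bias.
    out_i32 = torch.ops.bitsandbytes.int8_linear_matmul(A, B)
    return torch.ops.bitsandbytes.int8_mm_dequant(out_i32, row_stats, col_stats, bias=bias)


@int8_linear_dequant.register_fake
def _(A, B, row_stats, col_stats, bias=None):
    # Shape propagation only; matches the float16 default of int8_mm_dequant.
    return torch.empty((*A.shape[:-1], B.shape[0]), device=A.device, dtype=torch.float16)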
We generally encourage people to use torch.library.custom_op because the custom ops produced from it are guarded from various footguns when compared to torch.library.Library.define
Hey! Thanks for the feedback :)
While the custom_op API does look to be convenient, there are two main reasons it was avoided:
1. Supporting older PyTorch versions, where custom_op is not available and ops have to be registered through torch.Library.{define, impl}.
2. Performance overhead concerns with custom_op: pytorch/pytorch#139500

I am curious, is it still reasonable to make use of infer_schema, and is that API available in torch < 2.4?
Thanks for the feedback! It's not clear to me if we have fully fixed the performance issues, but I will check.
torch.library.infer_schema is only available in 2.5+. So if your goal is to support older PyTorch versions, you are doing the right thing.
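For reference, a minimal sketch of what combining infer_schema with the existing define-based registration could look like on torch >= 2.5 (the _v2 op name is purely illustrative):

import torch
from torch.library import infer_schema


def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
    # Prototype function: only its type hints are inspected to derive the schema.
    return A * stats.view(-1, 1) * 7.874015718698502e-3


# Yields "(Tensor A, Tensor stats) -> Tensor", matching the hand-written schema in this PR.
schema = infer_schema(int8_vectorwise_dequant, mutates_args=())
torch.library.define("bitsandbytes::int8_vectorwise_dequant_v2", schema)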