Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c98c3b3

Browse files
authoredFeb 11, 2025··
Add Int8GEMM (#2)
* add GEMMInt8 * cache context * update benchmark
1 parent 387d9c0 commit c98c3b3

File tree

7 files changed

+503
-47
lines changed

7 files changed

+503
-47
lines changed
 

‎README.md

+23-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# TorchSharp.BitsAndBytes
1+
# TorchSharp.BitsAndBytes
22
`TorchSharp.BitsAndBytes` is a C# binding for the [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) library from Hugging Face. It provides 4-bit and 8-bit quantization for TorchSharp models.
33

44
## Usage
@@ -17,4 +17,25 @@ int blockSize = 64; // can be [64, 128, 256, 512, 1024]
1717
var dequantizedTensor = BitsAndByteUtils.Dequantize4Bit(quantiedTensor, absMax, input.dtype, quantizedDType, n, input.shape, blockSize);
1818
```
1919

20-
For more examples, please refer to the *incoming benchmark* project.
20+
For more examples, please refer to the [Benchmark](#benchmark) section.
21+
22+
## Benchmark
23+
```
24+
25+
BenchmarkDotNet v0.14.0, Windows 11 (10.0.26100.3037)
26+
Intel Core i9-14900K, 1 CPU, 32 logical and 24 physical cores
27+
.NET SDK 9.0.102
28+
[Host] : .NET 8.0.12 (8.0.1224.60305), X64 RyuJIT AVX2
29+
DefaultJob : .NET 8.0.12 (8.0.1224.60305), X64 RyuJIT AVX2
30+
31+
32+
```
33+
| Method | Mean | Error | StdDev |
34+
|--------------- |------------:|----------:|----------:|
35+
| Quantize4Bit | 536.35 μs | 12.164 μs | 35.290 μs |
36+
| Dequantize4Bit | 2,257.89 μs | 44.542 μs | 51.294 μs |
37+
| GEMV_4Bit_FP4 | 84.16 μs | 1.673 μs | 3.223 μs |
38+
| GEMV_4Bit_NF4 | 82.69 μs | 4.329 μs | 12.629 μs |
39+
| GEMV_FP32 | 49.59 μs | 0.975 μs | 2.035 μs |
40+
| GEMM_INT8 | 2,994.86 μs | 12.144 μs | 11.360 μs |
41+
| GEMM_FP32 | 4,495.49 μs | 35.264 μs | 32.986 μs |

‎TorchSharp.BitsAndBytes.Benchmark/CudaBenchmark.cs

+28-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace TorchSharp.BitsAndBytes.Benchmark;
1111

12-
public class CudaBenchmark : IDisposable
12+
public class CudaBenchmark
1313
{
1414
private Tensor a1;
1515
private Tensor b;
@@ -19,7 +19,7 @@ public class CudaBenchmark : IDisposable
1919

2020
/// <summary>
/// Allocates the random FP32 source matrix <c>a1</c> of shape [dim * 4, dim].
/// </summary>
public CudaBenchmark()
{
    // Create the tensor directly on the CUDA device. Calling .cuda() on a CPU tensor
    // returns a second tensor and leaves the CPU one undisposed.
    a1 = torch.rand([dim * 4, dim], dtype: ScalarType.Float32, device: torch.CUDA);
}
2424

2525
private torch.Tensor quantizedTensor;
@@ -53,13 +53,37 @@ public void GEMV_4Bit_FP4()
5353
}
5454

5555
/// <summary>
/// Benchmarks <c>BitsAndByteUtils.Gemv4Bit</c> with the "nf4" quantized type on a
/// [1, dim] FP32 input against the pre-quantized weight.
/// </summary>
[Benchmark]
public void GEMV_4Bit_NF4()
{
    // Allocate on the device directly so no intermediate CPU tensor is left undisposed.
    using var input = torch.rand([1, dim], dtype: ScalarType.Float32, device: torch.CUDA);
    using var result = BitsAndByteUtils.Gemv4Bit(input, quantizedTensor, [4 * dim, dim], absMax, blockSize, "nf4");
}
61+
62+
/// <summary>
/// FP32 reference for the 4-bit GEMV benchmarks: multiplies a [1, dim] input with
/// the transpose of <c>b</c> using <c>torch.matmul</c>.
/// </summary>
[Benchmark]
public void GEMV_FP32()
{
    // Allocate on the device directly so no intermediate CPU tensor is left undisposed.
    using var input = torch.rand([1, dim], dtype: ScalarType.Float32, device: torch.CUDA);
    // The transposed view is a separate Tensor object and needs disposing as well.
    using var weightT = b.T;
    using var result = torch.matmul(input, weightT);
}
6168

62-
public void Dispose()
69+
/// <summary>
/// Benchmarks <c>Function.Int8GEMM</c> on a [1, dim] x [dim, dim] int8 product.
/// </summary>
[Benchmark]
public void GEMM_INT8()
{
    // randint's upper bound is exclusive, so 128 is needed to cover the whole int8 range.
    // Tensors are created on the device directly so no CPU intermediate is left undisposed.
    using var input = torch.randint(-128, 128, new long[] { 1, dim }, dtype: ScalarType.Int8, device: torch.CUDA);
    using var weight = torch.randint(-128, 128, new long[] { dim, dim }, dtype: ScalarType.Int8, device: torch.CUDA);
    using var result = Function.Int8GEMM(input, weight);
}
76+
77+
/// <summary>
/// FP32 reference for <c>GEMM_INT8</c>: the same [1, dim] x [dim, dim] product with
/// integer-valued FP32 operands, computed by <c>torch.matmul</c>.
/// </summary>
[Benchmark]
public void GEMM_FP32()
{
    // Same value range and device-side allocation as GEMM_INT8, so both benchmarks
    // pay a comparable setup cost and no CPU intermediate is left undisposed.
    using var input = torch.randint(-128, 128, new long[] { 1, dim }, dtype: ScalarType.Float32, device: torch.CUDA);
    using var weight = torch.randint(-128, 128, new long[] { dim, dim }, dtype: ScalarType.Float32, device: torch.CUDA);
    using var result = torch.matmul(input, weight);
}
84+
85+
[GlobalCleanup]
86+
public void Cleanup()
6387
{
6488
a1.Dispose();
6589
b.Dispose();
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
using BenchmarkDotNet.Running;
22
using TorchSharp.BitsAndBytes.Benchmark;
3-
43
BenchmarkRunner.Run<CudaBenchmark>();

‎TorchSharp.BitsAndBytes.Tests/BitsAndBytes4BitTests.cs

+124
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,88 @@ public void Test4BitQuant(ScalarType inputDType, string quantizedDType, int bloc
6969
Assert.True(avg.First() <= 0.2);
7070
}
7171

72+
/// <summary>
/// Checks <c>Function.Int8GEMM</c> against an FP32 <c>torch.matmul</c> reference for
/// 2-D inputs (all four transpose combinations) and 3-D inputs (non-transposed input only).
/// The output width is sampled as a multiple of 32 on each of the 20 rounds.
/// </summary>
[CudaTheory]
[InlineData(32, 1, false, false, 16)]
[InlineData(32, 1, false, true, 16)]
[InlineData(32, 1, true, false, 16)]
[InlineData(32, 1, true, true, 16)]
[InlineData(64, 1, true, true, 16)]
[InlineData(128, 1, true, true, 16)]
[InlineData(512, 1, true, true, 16)]
[InlineData(32, 1, true, true, 512)]
[InlineData(32, 16, false, false, 16)]
[InlineData(32, 16, false, true, 16)]
[InlineData(32, 8, true, false, 16)]
[InlineData(32, 4, true, true, 16)]
[InlineData(128, 32, true, true, 16)]
[InlineData(512, 32, true, true, 16)]
[InlineData(32, 4, true, true, 512)]
public void TestInt8GEMM(int hiddenDim, int batchDim, bool transposeInput, bool transposeWeight, int seqDim)
{
    const int Rounds = 20;

    // One seeded generator for the whole test: a fresh `new Random()` per iteration
    // makes failures impossible to reproduce.
    var rng = new Random(0);

    // 2-D input
    for (var round = 0; round < Rounds; round++)
    {
        long[] inputShape = transposeInput ? [hiddenDim, batchDim] : [batchDim, hiddenDim];
        AssertMatchesFloatReference(inputShape, NextWeightShape(), transposeInput, transposeWeight);
    }

    // 3-D input; a transposed 3-D input is not covered by this test.
    if (!transposeInput)
    {
        for (var round = 0; round < Rounds; round++)
        {
            long[] inputShape = [batchDim, seqDim, hiddenDim];
            AssertMatchesFloatReference(inputShape, NextWeightShape(), false, transposeWeight);
        }
    }

    // Weight is stored as [out, hidden] when it will be transposed, [hidden, out] otherwise.
    long[] NextWeightShape()
    {
        var outputChannel = 32 * rng.Next(1, 10);
        return transposeWeight ? [outputChannel, hiddenDim] : [hiddenDim, outputChannel];
    }

    static void AssertMatchesFloatReference(long[] inputShape, long[] weightShape, bool transposeA, bool transposeB)
    {
        // The scope disposes every tensor created below, including CPU intermediates,
        // float copies, transposed views and the difference tensor.
        using var scope = torch.NewDisposeScope();

        // randint's upper bound is exclusive; 128 covers the full int8 range.
        var input = torch.randint(-128, 128, inputShape, ScalarType.Int8).cuda();
        var weight = torch.randint(-128, 128, weightShape, ScalarType.Int8).cuda();

        // Int8GEMM receives the transposed *views*, so its stride-based layout
        // detection is what is being exercised here.
        var lhs = transposeA ? input.t() : input;
        var rhs = transposeB ? weight.t() : weight;

        var baseline = torch.matmul(lhs.to_type(ScalarType.Float32), rhs.to_type(ScalarType.Float32));
        var result = Function.Int8GEMM(lhs, rhs);

        var diff = baseline - result.to_type(ScalarType.Float32);
        var avg = diff.abs().mean().data<float>();

        Assert.True(avg[0] <= 1e-5);
    }
}
153+
72154
[CudaTheory]
73155
[InlineData(ScalarType.Float32, "fp4", 64, 1024)]
74156
[InlineData(ScalarType.Float32, "nf4", 64, 1024)]
@@ -174,4 +256,46 @@ public void TestGemv4Bit3D128(ScalarType dtype, string quantizedDType, int block
174256
Assert.Equal(1, avg.Count);
175257
Assert.True(avg.First() == 0);
176258
}
259+
260+
/// <summary>
/// A [2, 3] x [3, 2] int8 product without transposes yields an output shape of [2, 2].
/// </summary>
[Fact]
public void TestCheckMatmul_ValidInputs()
{
    // Tensors wrap native memory; dispose them deterministically.
    using var A = torch.randint(0, 10, new long[] { 2, 3 }, ScalarType.Int8);
    using var B = torch.randint(0, 10, new long[] { 3, 2 }, ScalarType.Int8);

    var result = BitsAndByteUtils.CheckMatmul(A, B, false, false, ScalarType.Int8);

    Assert.Equal([2, 2], result);
}
270+
271+
/// <summary>
/// A [2, 3] x [2, 2] product has mismatched inner dimensions and must be rejected
/// with an <see cref="ArgumentException"/>.
/// </summary>
[Fact]
public void TestCheckMatmul_InvalidInputs()
{
    // Tensors wrap native memory; dispose them deterministically.
    using var A = torch.randint(0, 10, new long[] { 2, 3 }, ScalarType.Int8);
    using var B = torch.randint(0, 10, new long[] { 2, 2 }, ScalarType.Int8);

    Assert.Throws<ArgumentException>(() => BitsAndByteUtils.CheckMatmul(A, B, false, false, ScalarType.Int8));
}
279+
280+
/// <summary>
/// With A marked as transposed, a [3, 2] x [3, 2] product is valid (A is treated as
/// [2, 3]) and yields an output shape of [2, 2].
/// </summary>
[Fact]
public void TestCheckMatmul_TransposedInputs()
{
    // Tensors wrap native memory; dispose them deterministically.
    using var A = torch.randint(0, 10, new long[] { 3, 2 }, ScalarType.Int8);
    using var B = torch.randint(0, 10, new long[] { 3, 2 }, ScalarType.Int8);

    var result = BitsAndByteUtils.CheckMatmul(A, B, true, false, ScalarType.Int8);

    Assert.Equal([2, 2], result);
}
290+
291+
/// <summary>
/// <c>CheckMatmul</c> takes no output tensor, so this case exercises the same
/// [2, 3] x [3, 2] -> [2, 2] path as the valid-inputs test.
/// </summary>
[Fact]
public void TestCheckMatmul_NullOutput()
{
    // Tensors wrap native memory; dispose them deterministically.
    using var A = torch.randint(0, 10, new long[] { 2, 3 }, ScalarType.Int8);
    using var B = torch.randint(0, 10, new long[] { 3, 2 }, ScalarType.Int8);

    var result = BitsAndByteUtils.CheckMatmul(A, B, false, false, ScalarType.Int8);

    Assert.Equal([2, 2], result);
}
177301
}

‎TorchSharp.BitsAndBytes/BitsAndByteUtils.cs

+143-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ public static Tensor Dequantize4Bit(
195195
return dequantizedTensor;
196196
}
197197

198-
198+
199199

200200
public static Tensor Get4BitType(string typename, string device = "cuda", int blocksize = 64)
201201
{
@@ -421,4 +421,146 @@ public static torch.Tensor CreateDynamicMap(bool signed = true, int maxExponentB
421421
data.Sort();
422422
return torch.tensor(data.ToArray());
423423
}
424+
425+
/// <summary>
/// Validates the operands of a matrix multiplication <c>A x B</c> and computes the shape of the result.
/// </summary>
/// <remarks>
/// Supported rank combinations are 2-D x 2-D, 3-D x 2-D and 3-D x 3-D. The matrix axes are
/// always the last two axes of each operand; a transpose flag swaps those two axes. For 3-D
/// <paramref name="A"/> the leading axis of A is carried over to the output unchanged.
/// </remarks>
/// <param name="A">Left operand.</param>
/// <param name="B">Right operand.</param>
/// <param name="transposed_A">Treat the last two axes of <paramref name="A"/> as swapped.</param>
/// <param name="transposed_B">Treat the last two axes of <paramref name="B"/> as swapped.</param>
/// <param name="expectedType">Element type both operands must have.</param>
/// <returns>The output shape: two entries for 2-D <paramref name="A"/>, three for 3-D.</returns>
/// <exception cref="ArgumentException">
/// An operand has the wrong dtype, the rank combination is unsupported, or the contracted
/// dimensions of the two operands differ.
/// </exception>
public static int[] CheckMatmul(Tensor A, Tensor B, bool transposed_A, bool transposed_B, ScalarType expectedType = ScalarType.Int8)
{
    if (A.dtype != expectedType || B.dtype != expectedType)
    {
        throw new ArgumentException($"Expected {expectedType} input tensors A and B, but got {A.dtype} and {B.dtype}");
    }

    var sA = A.IntShape();
    var sB = B.IntShape();
    var tA = transposed_A;
    var tB = transposed_B;

    // Arrays do not format their elements when interpolated, so join them explicitly.
    var shapeA = string.Join(", ", sA.ToArray());
    var shapeB = string.Join(", ", sB.ToArray());

    var rankA = sA.Length;
    var rankB = sB.Length;
    var supported = (rankA == 2 && rankB == 2)
        || (rankA == 3 && rankB == 2)
        || (rankA == 3 && rankB == 3);
    if (!supported)
    {
        // Previously this fell through and returned null, which surfaced later as a
        // NullReferenceException in the caller.
        throw new ArgumentException(
            $"Unsupported tensor ranks for matrix multiplication: A x B: [{shapeA}] x [{shapeB}]. Supported rank combinations are 2x2, 3x2 and 3x3."
        );
    }

    // Row/column axis of each operand *after* applying its transpose flag.
    var aRow = tA ? rankA - 1 : rankA - 2;
    var aCol = tA ? rankA - 2 : rankA - 1;
    var bRow = tB ? rankB - 1 : rankB - 2;
    var bCol = tB ? rankB - 2 : rankB - 1;

    // [.., r, k] x [.., k, c]: the contracted dimension must agree.
    if (sA[aCol] != sB[bRow])
    {
        throw new ArgumentException(
            $"Tensor dimensions incorrect for matrix multiplication: A x B: [{shapeA}] x [{shapeB}] with transpose for A x B: {tA} x {tB}."
        );
    }

    return rankA == 2
        ? new[] { sA[aRow], sB[bCol] }
        : new[] { sA[0], sA[aRow], sB[bCol] };
}
424566
}

‎TorchSharp.BitsAndBytes/BitsAndBytes.cs

+21-39
Original file line numberDiff line numberDiff line change
@@ -10,45 +10,6 @@ public static class BitsAndBytesCudaNative
1010
{
1111
private const string DllName = "libbitsandbytes_cuda121";
1212

13-
/// <summary>
14-
/// Represents the CUDA __nv_bfloat16 type
15-
/// </summary>
16-
[StructLayout(LayoutKind.Sequential)]
17-
public struct NvBFloat16
18-
{
19-
public ushort Value;
20-
}
21-
22-
[DllImport(DllName)]
23-
public static extern void cdequantize_blockwise_fp32(
24-
IntPtr code, // float*
25-
IntPtr A, // float*
26-
IntPtr absmax, // float*
27-
IntPtr output, // unsigned char*
28-
int blocksize,
29-
int n, // total size
30-
IntPtr stream);
31-
32-
[DllImport(DllName)]
33-
public static extern void cdequantize_blockwise_fp16(
34-
IntPtr code, // float*
35-
IntPtr A, // float*
36-
IntPtr absmax, // float*
37-
IntPtr output, // unsigned char*
38-
int blocksize,
39-
int n, // total size
40-
IntPtr stream);
41-
42-
[DllImport(DllName)]
43-
public static extern void cdequantize_blockwise_bf16(
44-
IntPtr code, // float*
45-
IntPtr A, // float*
46-
IntPtr absmax, // float*
47-
IntPtr output, // unsigned char*
48-
int blocksize,
49-
int n, // total size
50-
IntPtr stream);
51-
5213
[DllImport(DllName)]
5314
public static extern void cdequantize_blockwise_fp32_fp4(
5415
IntPtr code, // float*
@@ -238,4 +199,25 @@ public static extern void dequantize(
238199
int size,
239200
IntPtr stream // cudaStream_t
240201
);
202+
203+
/// <summary>
/// Native int8 GEMM entry point of the bitsandbytes CUDA library.
/// </summary>
/// <remarks>
/// NOTE(review): the operand roles below follow the original declaration's comments.
/// <c>Function.Int8GEMM</c> passes the weight pointer as <paramref name="A"/> and the
/// input pointer as <paramref name="B"/> to account for row-major vs. column-major
/// layout. Confirm the expected layout against the native library.
/// </remarks>
/// <param name="context">Handle obtained from <see cref="get_context"/>.</param>
/// <param name="transposeA">Whether the first operand is transposed.</param>
/// <param name="transposeB">Whether the second operand is transposed.</param>
/// <param name="m">GEMM dimension m.</param>
/// <param name="n">GEMM dimension n.</param>
/// <param name="k">GEMM dimension k (contracted).</param>
/// <param name="A">Device pointer to the first operand.</param>
/// <param name="B">Device pointer to the second operand.</param>
/// <param name="C">Device pointer to the output buffer.</param>
/// <param name="lda">Leading dimension of <paramref name="A"/>.</param>
/// <param name="ldb">Leading dimension of <paramref name="B"/>.</param>
/// <param name="ldc">Leading dimension of <paramref name="C"/>.</param>
[DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
public static extern void cigemm(
    IntPtr context,
    // A C/C++ `bool` is one byte, while the default P/Invoke marshalling for
    // System.Boolean is the 4-byte Win32 BOOL; state the native size explicitly.
    [MarshalAs(UnmanagedType.U1)] bool transposeA,
    [MarshalAs(UnmanagedType.U1)] bool transposeB,
    int m,
    int n,
    int k,
    IntPtr A, // input
    IntPtr B, // weight
    IntPtr C, // output
    int lda,
    int ldb,
    int ldc);

/// <summary>
/// Returns the native context handle that <see cref="cigemm"/> expects as its first argument.
/// </summary>
[DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr get_context();

/// <summary>
/// Returns a native handle from the library's <c>get_cusparse</c> export.
/// NOTE(review): presumably a cuSPARSE context; not used by the code visible here.
/// </summary>
[DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr get_cusparse();
241223
}

‎TorchSharp.BitsAndBytes/Function.cs

+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
using static TorchSharp.torch;
7+
8+
namespace TorchSharp.BitsAndBytes;
9+
10+
/// <summary>
/// Matrix-multiplication entry points backed by the native bitsandbytes CUDA library.
/// </summary>
public class Function
{
    // Native context handles returned by get_context(), keyed by the device index of the
    // input tensor. Guarded by _contextLock because Dictionary is not safe for concurrent writes.
    private static readonly Dictionary<int, IntPtr> _context = new();
    private static readonly object _contextLock = new();

    /// <summary>
    /// Integer General Matrix Multiplication (IGEMM) for 8-bit integer data types.
    /// Computes <c>input x weight</c> and returns the result as an int32 tensor on the
    /// device of <paramref name="input"/>.
    /// </summary>
    /// <remarks>
    /// For a 2-D <paramref name="weight"/>, the transpose flags handed to the native call
    /// are re-derived from the tensors' strides when the layout is recognisable, so
    /// transposed views (for example <c>weight.t()</c>) can be passed directly.
    /// </remarks>
    /// <param name="input">Int8 left operand (2-D or 3-D).</param>
    /// <param name="weight">Int8 right operand (2-D or 3-D).</param>
    /// <param name="transposeWeight">Treat <paramref name="weight"/> as transposed.</param>
    /// <param name="transposeInput">Treat <paramref name="input"/> as transposed.</param>
    /// <returns>A new int32 tensor holding the product; the caller owns and must dispose it.</returns>
    /// <exception cref="ArgumentException">The operands fail shape or dtype validation.</exception>
    /// <exception cref="NotImplementedException">The operands describe a batched 3-D x 3-D product.</exception>
    public static Tensor Int8GEMM(
        Tensor input,
        Tensor weight,
        bool transposeWeight = false,
        bool transposeInput = false)
    {
        // CheckMatmul's flag order is (transposed_A, transposed_B) with A = input, B = weight.
        var sout = BitsAndByteUtils.CheckMatmul(input, weight, transposeInput, transposeWeight);

        // Batched igemm is not ported. Reject it before anything is allocated.
        if (input.shape.Length == 3 && weight.shape.Length == 3
            && input.shape[0] == weight.shape[0] && input.shape[2] == weight.shape[1])
        {
            throw new NotImplementedException("Batched int8 GEMM (3-D x 3-D) is not implemented.");
        }

        var inputShape = input.IntShape().ToArray();
        var weightShape = weight.IntShape().ToArray();
        if (transposeInput)
        {
            inputShape = SwapLastTwoAxes(inputShape);
        }
        if (transposeWeight)
        {
            weightShape = SwapLastTwoAxes(weightShape);
        }

        // this is a mess: cuBLAS expect column major, but PyTorch is row major.
        // So to perform the matrix multiplication, we have to treat A, B, and C matrices
        // (transpose of row major is column major)
        // This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
        // matrices in the input arguments for cuBLAS
        // column major: A @ B = C: [m, k] @ [k, n] = [m, n]
        // row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
        // column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
        int m = 0, n = 0, k = 0, lda = 0, ldb = 0, ldc = 0;

        if (weightShape.Length == 2)
        {
            // Infer the effective transpose flags from the memory layout. If neither
            // pattern matches, the caller-supplied flag is kept.
            if (weight.stride(0) == weight.shape[1])
            {
                transposeWeight = false;
            }
            else if (weight.stride(1) == weight.shape[0])
            {
                transposeWeight = true;
            }

            if (input.shape.Length == 2)
            {
                if (input.stride(0) == input.shape[1])
                {
                    transposeInput = false;
                }
                else if (input.stride(1) == input.shape[0])
                {
                    transposeInput = true;
                }
            }
            else
            {
                if (input.stride(1) == input.shape[2])
                {
                    transposeInput = false;
                }
                else if (input.stride(2) == input.shape[1])
                {
                    transposeInput = true;
                }
            }

            if (inputShape.Length == 2)
            {
                n = inputShape[0];
                ldb = (int)input.stride(transposeInput ? 1 : 0);
            }
            else if (inputShape.Length == 3)
            {
                // Leading two axes are flattened into the row count.
                n = inputShape[0] * inputShape[1];
                ldb = inputShape[2];
            }

            m = weightShape[1];
            k = weightShape[0];
            lda = (int)weight.stride(transposeWeight ? 1 : 0);
            ldc = weightShape[1];
        }
        else if (weightShape.Length == 3)
        {
            // special case
            if (!(inputShape[0] == weightShape[0] && inputShape[1] == weightShape[1]))
            {
                // Join the dimensions explicitly; interpolating an array prints its type name.
                throw new ArgumentException(
                    $"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: [{string.Join(", ", inputShape)}] x [{string.Join(", ", weightShape)}]");
            }

            transposeInput = true;
            transposeWeight = false;
            m = weightShape[2];
            n = inputShape[2];
            k = weightShape[0] * weightShape[1];

            lda = m;
            ldb = inputShape[2];
            ldc = m;
        }
        else
        {
            // Without this guard the native routine would be called with m = n = k = 0.
            throw new ArgumentException(
                $"Unsupported weight rank {weightShape.Length}; expected a 2-D or 3-D tensor.", nameof(weight));
        }

        IntPtr context;
        lock (_contextLock)
        {
            if (!_context.TryGetValue(input.device_index, out context))
            {
                context = BitsAndBytesCudaNative.get_context();
                _context[input.device_index] = context;
            }
        }

        // Allocate the output only after every validation path has passed, so a thrown
        // exception cannot strand a native tensor.
        var result = torch.zeros(Array.ConvertAll(sout, dimension => (long)dimension), dtype: torch.int32, device: input.device);
        try
        {
            var A = LibTorchNativeMethod.THSStorage_data_ptr(input.Handle);
            var B = LibTorchNativeMethod.THSStorage_data_ptr(weight.Handle);
            var C = LibTorchNativeMethod.THSStorage_data_ptr(result.Handle);
            BitsAndBytesCudaNative.cigemm(
                context: context,
                transposeA: transposeWeight, // cuBLAS expects column major, but PyTorch is row major
                transposeB: transposeInput, // So we have to transpose A and B
                m: m,
                n: n,
                k: k,
                A: B, // out_T = B_T @ A_T
                B: A,
                C: C,
                lda: lda,
                ldb: ldb,
                ldc: ldc);
        }
        catch
        {
            result.Dispose();
            throw;
        }

        return result;
    }

    // Returns the shape with its last two axes swapped for 2-D and 3-D shapes; other ranks are returned as-is.
    private static int[] SwapLastTwoAxes(int[] shape)
    {
        if (shape.Length == 2)
        {
            return new[] { shape[1], shape[0] };
        }

        if (shape.Length == 3)
        {
            // The batch axis stays in front. The last entry previously repeated axis 0
            // instead of using axis 1.
            return new[] { shape[0], shape[2], shape[1] };
        }

        return shape;
    }
}

0 commit comments

Comments
 (0)
Please sign in to comment.