@@ -69,6 +69,88 @@ public void Test4BitQuant(ScalarType inputDType, string quantizedDType, int bloc
69
69
Assert . True ( avg . First ( ) <= 0.2 ) ;
70
70
}
71
71
72
[CudaTheory]
[InlineData(32, 1, false, false, 16)]
[InlineData(32, 1, false, true, 16)]
[InlineData(32, 1, true, false, 16)]
[InlineData(32, 1, true, true, 16)]
[InlineData(64, 1, true, true, 16)]
[InlineData(128, 1, true, true, 16)]
[InlineData(512, 1, true, true, 16)]
[InlineData(32, 1, true, true, 512)]
[InlineData(32, 16, false, false, 16)]
[InlineData(32, 16, false, true, 16)]
[InlineData(32, 8, true, false, 16)]
[InlineData(32, 4, true, true, 16)]
[InlineData(128, 32, true, true, 16)]
[InlineData(512, 32, true, true, 16)]
[InlineData(32, 4, true, true, 512)]
/// <summary>
/// Verifies <c>Function.Int8GEMM</c> against a float32 <c>torch.matmul</c> baseline
/// for both 2-D and 3-D int8 inputs, across all transpose combinations.
/// Int8 products summed over up to 512 terms stay well below 2^24, so the
/// float32 baseline is exact and the results must match to within 1e-5.
/// </summary>
/// <param name="hiddenDim">Inner (contraction) dimension of the GEMM.</param>
/// <param name="batchDim">Batch dimension of the input.</param>
/// <param name="transposeInput">Whether the input is passed transposed.</param>
/// <param name="transposeWeight">Whether the weight is passed transposed.</param>
/// <param name="seqDim">Sequence dimension used only by the 3-D case.</param>
public void TestInt8GEMM(int hiddenDim, int batchDim, bool transposeInput, bool transposeWeight, int seqDim)
{
    // Single RNG for the whole test: constructing `new Random()` inside the
    // loop can repeat sequences on runtimes with time-based seeding.
    var rng = new Random();

    // 2-D input
    foreach (int i in Enumerable.Range(0, 20))
    {
        long[] inputShape = !transposeInput ? [batchDim, hiddenDim] : [hiddenDim, batchDim];
        var outputChannel = 32 * rng.Next(1, 10);
        long[] weightShape = transposeWeight ? [outputChannel, hiddenDim] : [hiddenDim, outputChannel];

        // randint's upper bound is exclusive: use 128 so the full int8
        // range [-128, 127] is actually exercised (127 was never generated
        // with an upper bound of 127).
        using var input = torch.randint(-128, 128, inputShape, ScalarType.Int8).cuda();
        using var weight = torch.randint(-128, 128, weightShape, ScalarType.Int8).cuda();
        using var baseline = (transposeInput, transposeWeight) switch
        {
            (false, false) => torch.matmul(input.to_type(ScalarType.Float32), weight.to_type(ScalarType.Float32)),
            (false, true) => torch.matmul(input.to_type(ScalarType.Float32), weight.to_type(ScalarType.Float32).t()),
            (true, false) => torch.matmul(input.to_type(ScalarType.Float32).t(), weight.to_type(ScalarType.Float32)),
            (true, true) => torch.matmul(input.to_type(ScalarType.Float32).t(), weight.to_type(ScalarType.Float32).t()),
        };
        using var result = (transposeInput, transposeWeight) switch
        {
            (false, false) => Function.Int8GEMM(input, weight),
            (false, true) => Function.Int8GEMM(input, weight.t()),
            (true, false) => Function.Int8GEMM(input.t(), weight),
            (true, true) => Function.Int8GEMM(input.t(), weight.t()),
        };

        var diff = baseline - result.to_type(ScalarType.Float32);
        var avg = diff.abs().mean().data<float>();

        Assert.True(avg[0] <= 1e-5);
    }

    // 3-dim input
    foreach (int i in Enumerable.Range(0, 20))
    {
        if (transposeInput)
        {
            // A transposed 3-D input is not a supported layout for this kernel;
            // only the non-transposed-input combinations are checked below.
            continue;
        }
        long[] inputShape = [batchDim, seqDim, hiddenDim];
        var outputChannel = 32 * rng.Next(1, 10);
        long[] weightShape = transposeWeight ? [outputChannel, hiddenDim] : [hiddenDim, outputChannel];

        using var input = torch.randint(-128, 128, inputShape, ScalarType.Int8).cuda();
        using var weight = torch.randint(-128, 128, weightShape, ScalarType.Int8).cuda();
        using var baseline = (transposeInput, transposeWeight) switch
        {
            (false, false) => torch.matmul(input.to_type(ScalarType.Float32), weight.to_type(ScalarType.Float32)),
            (false, true) => torch.matmul(input.to_type(ScalarType.Float32), weight.to_type(ScalarType.Float32).t()),
            _ => throw new NotImplementedException()
        };
        using var result = (transposeInput, transposeWeight) switch
        {
            (false, false) => Function.Int8GEMM(input, weight),
            (false, true) => Function.Int8GEMM(input, weight.t()),
            _ => throw new NotImplementedException()
        };

        var diff = baseline - result.to_type(ScalarType.Float32);
        var avg = diff.abs().mean().data<float>();

        Assert.True(avg[0] <= 1e-5);
    }
}
72
154
[ CudaTheory ]
73
155
[ InlineData ( ScalarType . Float32 , "fp4" , 64 , 1024 ) ]
74
156
[ InlineData ( ScalarType . Float32 , "nf4" , 64 , 1024 ) ]
@@ -174,4 +256,46 @@ public void TestGemv4Bit3D128(ScalarType dtype, string quantizedDType, int block
174
256
Assert . Equal ( 1 , avg . Count ) ;
175
257
Assert . True ( avg . First ( ) == 0 ) ;
176
258
}
259
+
260
[Fact]
public void TestCheckMatmul_ValidInputs()
{
    // A (2x3) left operand against a (3x2) right operand is a valid
    // matmul; CheckMatmul is expected to report a 2x2 output shape.
    var left = torch.randint(0, 10, new long[] { 2, 3 }, ScalarType.Int8);
    var right = torch.randint(0, 10, new long[] { 3, 2 }, ScalarType.Int8);

    var outputShape = BitsAndByteUtils.CheckMatmul(left, right, false, false, ScalarType.Int8);

    Assert.Equal([2, 2], outputShape);
}
270
+
271
+ [ Fact ]
272
+ public void TestCheckMatmul_InvalidInputs ( )
273
+ {
274
+ var A = torch . randint ( 0 , 10 , new long [ ] { 2 , 3 } , ScalarType . Int8 ) ;
275
+ var B = torch . randint ( 0 , 10 , new long [ ] { 2 , 2 } , ScalarType . Int8 ) ;
276
+
277
+ Assert . Throws < ArgumentException > ( ( ) => BitsAndByteUtils . CheckMatmul ( A , B , false , false , ScalarType . Int8 ) ) ;
278
+ }
279
+
280
+ [ Fact ]
281
+ public void TestCheckMatmul_TransposedInputs ( )
282
+ {
283
+ var A = torch . randint ( 0 , 10 , new long [ ] { 3 , 2 } , ScalarType . Int8 ) ;
284
+ var B = torch . randint ( 0 , 10 , new long [ ] { 3 , 2 } , ScalarType . Int8 ) ;
285
+
286
+ var result = BitsAndByteUtils . CheckMatmul ( A , B , true , false , ScalarType . Int8 ) ;
287
+
288
+ Assert . Equal ( [ 2 , 2 ] , result ) ;
289
+ }
290
+
291
+ [ Fact ]
292
+ public void TestCheckMatmul_NullOutput ( )
293
+ {
294
+ var A = torch . randint ( 0 , 10 , new long [ ] { 2 , 3 } , ScalarType . Int8 ) ;
295
+ var B = torch . randint ( 0 , 10 , new long [ ] { 3 , 2 } , ScalarType . Int8 ) ;
296
+
297
+ var result = BitsAndByteUtils . CheckMatmul ( A , B , false , false , ScalarType . Int8 ) ;
298
+
299
+ Assert . Equal ( [ 2 , 2 ] , result ) ;
300
+ }
177
301
}
0 commit comments