#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
+#include "dispatch_utils.h"

#include "ggml-common.h"
#include "vecdotq.cuh"
#include "mmq.cuh"

// Q8 gemv
-static __global__ void quantize_q8_1(const half* __restrict__ x,
+template <typename scalar_t>
+static __global__ void quantize_q8_1(const scalar_t* __restrict__ x,
                                      void* __restrict__ vy, const int kx,
                                      const int kx_padded) {
  const int ix = blockDim.x * blockIdx.x + threadIdx.x;
@@ -28,7 +30,7 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
  const int ib = i_padded / QK8_1;   // block index
  const int iqs = i_padded % QK8_1;  // quant index

-  const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f;
+  const float xi = ix < kx ? static_cast<float>(x[iy * kx + ix]) : 0.0f;
  float amax = fabsf(xi);
  float sum = xi;
@@ -51,14 +53,16 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
51
53
y[ib].ds .y = __float2half (sum);
52
54
}
53
55
54
- static void quantize_row_q8_1_cuda (const half* x, void * vy, const int kx,
56
+ template <typename scalar_t >
57
+ static void quantize_row_q8_1_cuda (const scalar_t * x, void * vy, const int kx,
55
58
const int ky, cudaStream_t stream) {
56
59
const int64_t kx_padded = (kx + 512 - 1 ) / 512 * 512 ;
57
60
const int block_num_x =
58
61
(kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1 ) / CUDA_QUANTIZE_BLOCK_SIZE;
59
62
const dim3 num_blocks (block_num_x, ky, 1 );
60
63
const dim3 block_size (CUDA_DEQUANTIZE_BLOCK_SIZE, 1 , 1 );
61
- quantize_q8_1<<<num_blocks, block_size, 0 , stream>>> (x, vy, kx, kx_padded);
64
+ quantize_q8_1<scalar_t >
65
+ <<<num_blocks, block_size, 0 , stream>>> (x, vy, kx, kx_padded);
62
66
}
63
67
64
68
torch::Tensor ggml_dequantize (torch::Tensor W, // quant weight
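For reference, the quant_X buffers allocated in the two functions below are sized as padded / 32 * 9 int32 elements. That arithmetic follows from the Q8_1 block layout; the struct in this sketch is an assumed rendering of the block_q8_1 definition from ggml-common.h, not code taken from this diff.

// Sketch only: assumed block_q8_1 layout, shown to explain the
// padded / 32 * 9 sizing of the int32 quant_X buffers below.
#include <cstdint>
#include <cuda_fp16.h>

constexpr int QK8_1 = 32;  // activations quantized per block

struct block_q8_1_sketch {
  half2 ds;                // x: scale d, y: sum of the block's 32 values
  int8_t qs[QK8_1];        // 32 signed 8-bit quants
};

static_assert(sizeof(block_q8_1_sketch) == 36, "4 B scale/sum + 32 B quants");
// One block per 32 padded activations occupies 36 bytes = 9 int32 words,
// hence torch::empty({..., padded / 32 * 9}, kInt32) for quant_X.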
@@ -79,101 +83,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
  int col = X.sizes()[1];
  const int padded = (col + 512 - 1) / 512 * 512;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
-  auto options =
-      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
  at::Tensor Y = torch::empty({1, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
-  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1,
-                         stream);
-  switch (type) {
-    case 2:
-      mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 3:
-      mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 6:
-      mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 7:
-      mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 8:
-      mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 10:
-      mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 11:
-      mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 12:
-      mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 13:
-      mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 14:
-      mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
-                                 (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 16:
-      mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(),
-                                    (void*)quant_X.data_ptr(),
-                                    (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 17:
-      mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(),
-                                   (void*)quant_X.data_ptr(),
-                                   (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 18:
-      mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(),
-                                    (void*)quant_X.data_ptr(),
-                                    (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 19:
-      mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 20:
-      mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(),
-                                   (void*)quant_X.data_ptr(),
-                                   (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 21:
-      mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 22:
-      mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 23:
-      mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(),
-                                   (void*)quant_X.data_ptr(),
-                                   (half*)Y.data_ptr(), col, row, stream);
-      break;
-    case 29:
-      mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(),
-                                  (void*)quant_X.data_ptr(),
-                                  (half*)Y.data_ptr(), col, row, stream);
-      break;
-  }
+  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
+    quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
+                                     (void*)quant_X.data_ptr(), col, 1, stream);
+    switch (type) {
+      case 2:
+        mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 3:
+        mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 6:
+        mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 7:
+        mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 8:
+        mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 10:
+        mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 11:
+        mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 12:
+        mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 13:
+        mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 14:
+        mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 16:
+        mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 17:
+        mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 18:
+        mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 19:
+        mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 20:
+        mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 21:
+        mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 22:
+        mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 23:
+        mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+      case 29:
+        mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, stream);
+        break;
+    }
+  });
  return Y;
}

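For context, VLLM_DISPATCH_FLOATING_TYPES comes from the newly included dispatch_utils.h. It is assumed here to follow ATen's AT_DISPATCH_SWITCH pattern, so inside the lambda scalar_t resolves to float, at::Half, or at::BFloat16 according to X.scalar_type(); the macros below are a sketch of that assumption, not the actual header contents.

#include <ATen/Dispatch.h>

// Assumed shape of the dispatch macros (ATen-style runtime -> compile-time
// dtype dispatch); the real definitions live in dispatch_utils.h.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)          \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

This is also why the quantize kernel above swaps __half2float for static_cast<float>: the same source can instantiate for half, bfloat16, or float activations instead of being tied to __half.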
@@ -184,66 +199,67 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
  int padded = (col + 512 - 1) / 512 * 512;
  int batch = X.sizes()[0];
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
-  auto options =
-      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+  auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
  at::Tensor Y = torch::empty({batch, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
-  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col,
-                         batch, stream);
-
-  switch (type) {
-    case 2:
-      ggml_mul_mat_q4_0_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 3:
-      ggml_mul_mat_q4_1_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 6:
-      ggml_mul_mat_q5_0_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 7:
-      ggml_mul_mat_q5_1_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 8:
-      ggml_mul_mat_q8_0_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 10:
-      ggml_mul_mat_q2_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 11:
-      ggml_mul_mat_q3_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 12:
-      ggml_mul_mat_q4_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 13:
-      ggml_mul_mat_q5_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-    case 14:
-      ggml_mul_mat_q6_K_q8_1_cuda(
-          (void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
-          col, row, batch, padded, row, stream);
-      break;
-  }
+  VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
+    quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
+                           col, batch, stream);
+
+    switch (type) {
+      case 2:
+        ggml_mul_mat_q4_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 3:
+        ggml_mul_mat_q4_1_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 6:
+        ggml_mul_mat_q5_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 7:
+        ggml_mul_mat_q5_1_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 8:
+        ggml_mul_mat_q8_0_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 10:
+        ggml_mul_mat_q2_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 11:
+        ggml_mul_mat_q3_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 12:
+        ggml_mul_mat_q4_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 13:
+        ggml_mul_mat_q5_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+      case 14:
+        ggml_mul_mat_q6_K_q8_1_cuda(
+            (void*)W.data_ptr(), (void*)quant_X.data_ptr(),
+            (scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
+        break;
+    }
+  });
  return Y;
}
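A usage sketch of the dtype change from the caller's side. The full parameter list of ggml_mul_mat_vec_a8 is not shown in this diff, so the (W, X, type, row) order and the integer types below are assumptions; the type value 12 corresponds to the Q4_K case in the switch above.

#include <torch/all.h>

// Assumed signature (only the first parameter is visible in this diff).
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                  int64_t type, int64_t row);

// With this change, Y is returned in X's dtype instead of always float16.
torch::Tensor q4_k_gemv_example(const torch::Tensor& W_q4_k,
                                const torch::Tensor& x, int64_t rows) {
  TORCH_CHECK(x.scalar_type() == torch::kHalf ||
              x.scalar_type() == torch::kBFloat16);
  torch::Tensor y = ggml_mul_mat_vec_a8(W_q4_k, x, /*type=*/12, rows);
  TORCH_CHECK(y.scalar_type() == x.scalar_type());
  return y;
}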