Skip to content

Commit 89cdaa8

Browse files
authored
[Kernel] Add more dtype support for GGUF kernels (vllm-project#14043)
Signed-off-by: SzymonOzog <[email protected]> Signed-off-by: SzymonOzog <[email protected]>
1 parent b0746fa commit 89cdaa8

File tree

6 files changed

+319
-267
lines changed

6 files changed

+319
-267
lines changed

csrc/quantization/gguf/gguf_kernel.cu

+168-152
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <c10/cuda/CUDAGuard.h>
66

77
#include "cuda_compat.h"
8+
#include "dispatch_utils.h"
89

910
#include "ggml-common.h"
1011
#include "vecdotq.cuh"
@@ -13,7 +14,8 @@
1314
#include "mmq.cuh"
1415

1516
// Q8 gemv
16-
static __global__ void quantize_q8_1(const half* __restrict__ x,
17+
template <typename scalar_t>
18+
static __global__ void quantize_q8_1(const scalar_t* __restrict__ x,
1719
void* __restrict__ vy, const int kx,
1820
const int kx_padded) {
1921
const int ix = blockDim.x * blockIdx.x + threadIdx.x;
@@ -28,7 +30,7 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
2830
const int ib = i_padded / QK8_1; // block index
2931
const int iqs = i_padded % QK8_1; // quant index
3032

31-
const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f;
33+
const float xi = ix < kx ? static_cast<float>(x[iy * kx + ix]) : 0.0f;
3234
float amax = fabsf(xi);
3335
float sum = xi;
3436

@@ -51,14 +53,16 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
5153
y[ib].ds.y = __float2half(sum);
5254
}
5355

54-
static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
56+
template <typename scalar_t>
57+
static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
5558
const int ky, cudaStream_t stream) {
5659
const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
5760
const int block_num_x =
5861
(kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
5962
const dim3 num_blocks(block_num_x, ky, 1);
6063
const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
61-
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
64+
quantize_q8_1<scalar_t>
65+
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
6266
}
6367

6468
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
@@ -79,101 +83,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
7983
int col = X.sizes()[1];
8084
const int padded = (col + 512 - 1) / 512 * 512;
8185
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
82-
auto options =
83-
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
86+
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
8487
at::Tensor Y = torch::empty({1, row}, options);
8588
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
8689
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
8790
at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
88-
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1,
89-
stream);
90-
switch (type) {
91-
case 2:
92-
mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
93-
(half*)Y.data_ptr(), col, row, stream);
94-
break;
95-
case 3:
96-
mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
97-
(half*)Y.data_ptr(), col, row, stream);
98-
break;
99-
case 6:
100-
mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
101-
(half*)Y.data_ptr(), col, row, stream);
102-
break;
103-
case 7:
104-
mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
105-
(half*)Y.data_ptr(), col, row, stream);
106-
break;
107-
case 8:
108-
mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
109-
(half*)Y.data_ptr(), col, row, stream);
110-
break;
111-
case 10:
112-
mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
113-
(half*)Y.data_ptr(), col, row, stream);
114-
break;
115-
case 11:
116-
mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
117-
(half*)Y.data_ptr(), col, row, stream);
118-
break;
119-
case 12:
120-
mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
121-
(half*)Y.data_ptr(), col, row, stream);
122-
break;
123-
case 13:
124-
mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
125-
(half*)Y.data_ptr(), col, row, stream);
126-
break;
127-
case 14:
128-
mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
129-
(half*)Y.data_ptr(), col, row, stream);
130-
break;
131-
case 16:
132-
mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(),
133-
(void*)quant_X.data_ptr(),
134-
(half*)Y.data_ptr(), col, row, stream);
135-
break;
136-
case 17:
137-
mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(),
138-
(void*)quant_X.data_ptr(),
139-
(half*)Y.data_ptr(), col, row, stream);
140-
break;
141-
case 18:
142-
mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(),
143-
(void*)quant_X.data_ptr(),
144-
(half*)Y.data_ptr(), col, row, stream);
145-
break;
146-
case 19:
147-
mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(),
148-
(void*)quant_X.data_ptr(),
149-
(half*)Y.data_ptr(), col, row, stream);
150-
break;
151-
case 20:
152-
mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(),
153-
(void*)quant_X.data_ptr(),
154-
(half*)Y.data_ptr(), col, row, stream);
155-
break;
156-
case 21:
157-
mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(),
158-
(void*)quant_X.data_ptr(),
159-
(half*)Y.data_ptr(), col, row, stream);
160-
break;
161-
case 22:
162-
mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(),
163-
(void*)quant_X.data_ptr(),
164-
(half*)Y.data_ptr(), col, row, stream);
165-
break;
166-
case 23:
167-
mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(),
168-
(void*)quant_X.data_ptr(),
169-
(half*)Y.data_ptr(), col, row, stream);
170-
break;
171-
case 29:
172-
mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(),
173-
(void*)quant_X.data_ptr(),
174-
(half*)Y.data_ptr(), col, row, stream);
175-
break;
176-
}
91+
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
92+
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
93+
(void*)quant_X.data_ptr(), col, 1, stream);
94+
switch (type) {
95+
case 2:
96+
mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
97+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
98+
(scalar_t*)Y.data_ptr(), col, row, stream);
99+
break;
100+
case 3:
101+
mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
102+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
103+
(scalar_t*)Y.data_ptr(), col, row, stream);
104+
break;
105+
case 6:
106+
mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
107+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
108+
(scalar_t*)Y.data_ptr(), col, row, stream);
109+
break;
110+
case 7:
111+
mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
112+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
113+
(scalar_t*)Y.data_ptr(), col, row, stream);
114+
break;
115+
case 8:
116+
mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
117+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
118+
(scalar_t*)Y.data_ptr(), col, row, stream);
119+
break;
120+
case 10:
121+
mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
122+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
123+
(scalar_t*)Y.data_ptr(), col, row, stream);
124+
break;
125+
case 11:
126+
mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
127+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
128+
(scalar_t*)Y.data_ptr(), col, row, stream);
129+
break;
130+
case 12:
131+
mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
132+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
133+
(scalar_t*)Y.data_ptr(), col, row, stream);
134+
break;
135+
case 13:
136+
mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
137+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
138+
(scalar_t*)Y.data_ptr(), col, row, stream);
139+
break;
140+
case 14:
141+
mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
142+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
143+
(scalar_t*)Y.data_ptr(), col, row, stream);
144+
break;
145+
case 16:
146+
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
147+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
148+
(scalar_t*)Y.data_ptr(), col, row, stream);
149+
break;
150+
case 17:
151+
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
152+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
153+
(scalar_t*)Y.data_ptr(), col, row, stream);
154+
break;
155+
case 18:
156+
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
157+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
158+
(scalar_t*)Y.data_ptr(), col, row, stream);
159+
break;
160+
case 19:
161+
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
162+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
163+
(scalar_t*)Y.data_ptr(), col, row, stream);
164+
break;
165+
case 20:
166+
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
167+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
168+
(scalar_t*)Y.data_ptr(), col, row, stream);
169+
break;
170+
case 21:
171+
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
172+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
173+
(scalar_t*)Y.data_ptr(), col, row, stream);
174+
break;
175+
case 22:
176+
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
177+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
178+
(scalar_t*)Y.data_ptr(), col, row, stream);
179+
break;
180+
case 23:
181+
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
182+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
183+
(scalar_t*)Y.data_ptr(), col, row, stream);
184+
break;
185+
case 29:
186+
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
187+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
188+
(scalar_t*)Y.data_ptr(), col, row, stream);
189+
break;
190+
}
191+
});
177192
return Y;
178193
}
179194

@@ -184,66 +199,67 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
184199
int padded = (col + 512 - 1) / 512 * 512;
185200
int batch = X.sizes()[0];
186201
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
187-
auto options =
188-
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
202+
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
189203
at::Tensor Y = torch::empty({batch, row}, options);
190204
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
191205
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
192206
at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
193-
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col,
194-
batch, stream);
195-
196-
switch (type) {
197-
case 2:
198-
ggml_mul_mat_q4_0_q8_1_cuda(
199-
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
200-
col, row, batch, padded, row, stream);
201-
break;
202-
case 3:
203-
ggml_mul_mat_q4_1_q8_1_cuda(
204-
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
205-
col, row, batch, padded, row, stream);
206-
break;
207-
case 6:
208-
ggml_mul_mat_q5_0_q8_1_cuda(
209-
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
210-
col, row, batch, padded, row, stream);
211-
break;
212-
case 7:
213-
ggml_mul_mat_q5_1_q8_1_cuda(
214-
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
215-
col, row, batch, padded, row, stream);
216-
break;
217-
case 8:
218-
ggml_mul_mat_q8_0_q8_1_cuda(
219-
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
220-
col, row, batch, padded, row, stream);
221-
break;
222-
case 10:
223-
ggml_mul_mat_q2_K_q8_1_cuda(
224-
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
225-
col, row, batch, padded, row, stream);
226-
break;
227-
case 11:
228-
ggml_mul_mat_q3_K_q8_1_cuda(
229-
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
230-
col, row, batch, padded, row, stream);
231-
break;
232-
case 12:
233-
ggml_mul_mat_q4_K_q8_1_cuda(
234-
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
235-
col, row, batch, padded, row, stream);
236-
break;
237-
case 13:
238-
ggml_mul_mat_q5_K_q8_1_cuda(
239-
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
240-
col, row, batch, padded, row, stream);
241-
break;
242-
case 14:
243-
ggml_mul_mat_q6_K_q8_1_cuda(
244-
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
245-
col, row, batch, padded, row, stream);
246-
break;
247-
}
207+
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
208+
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
209+
col, batch, stream);
210+
211+
switch (type) {
212+
case 2:
213+
ggml_mul_mat_q4_0_q8_1_cuda(
214+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
215+
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
216+
break;
217+
case 3:
218+
ggml_mul_mat_q4_1_q8_1_cuda(
219+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
220+
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
221+
break;
222+
case 6:
223+
ggml_mul_mat_q5_0_q8_1_cuda(
224+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
225+
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
226+
break;
227+
case 7:
228+
ggml_mul_mat_q5_1_q8_1_cuda(
229+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
230+
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
231+
break;
232+
case 8:
233+
ggml_mul_mat_q8_0_q8_1_cuda(
234+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
235+
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
236+
break;
237+
case 10:
238+
ggml_mul_mat_q2_K_q8_1_cuda(
239+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
240+
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
241+
break;
242+
case 11:
243+
ggml_mul_mat_q3_K_q8_1_cuda(
244+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
245+
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
246+
break;
247+
case 12:
248+
ggml_mul_mat_q4_K_q8_1_cuda(
249+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
250+
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
251+
break;
252+
case 13:
253+
ggml_mul_mat_q5_K_q8_1_cuda(
254+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
255+
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
256+
break;
257+
case 14:
258+
ggml_mul_mat_q6_K_q8_1_cuda(
259+
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
260+
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
261+
break;
262+
}
263+
});
248264
return Y;
249265
}

0 commit comments

Comments
 (0)