forked from thu-ml/SageAttention
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcore.py
631 lines (513 loc) · 26.9 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import triton
import triton.language as tl
from .triton.quant_per_block import per_block_int8 as per_block_int8_triton
from .triton.quant_per_block_varlen import per_block_int8 as per_block_int8_varlen_triton
from .triton.quant_per_block_hd96 import per_block_int8_hd96
from .triton.attn_qk_int8_per_block_h96 import forward as attn_h96_false
from .triton.attn_qk_int8_per_block_h96_causal import forward as attn_h96_true
from .triton.attn_qk_int8_per_block import forward as attn_false
from .triton.attn_qk_int8_per_block_causal import forward as attn_true
from .triton.attn_qk_int8_block_varlen import forward as attn_false_varlen
from .triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen
from ._qattn import qk_int8_sv_f16_accum_f32_attn_per_warp
from ._qattn import qk_int8_sv_f16_accum_f16_attn_per_warp, qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_per_warp
from ._qattn import qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_per_warp, qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_per_warp
from ._qattn import qk_int8_sv_f16_accum_f16_attn_per_warp_buf, qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_per_warp_buf
from .quant import per_block_int8 as per_block_int8_cuda
from .quant import per_warp_int8
from .quant import sub_mean
from .quant import per_channel_fp8
from typing import Any, List, Literal, Optional, Tuple, Union
import warnings
def get_cuda_arch_versions():
    """
    List the compute-capability tag of every CUDA device visible to torch.

    Returns
    -------
    list of str
        One entry per device index in ``range(torch.cuda.device_count())``, formatted as
        ``"sm{major}{minor}"`` (for example ``"sm80"``). The list is empty when
        ``torch.cuda.device_count()`` is 0.
    """
    capabilities = (
        torch.cuda.get_device_capability(device_id)
        for device_id in range(torch.cuda.device_count())
    )
    return [f"sm{major}{minor}" for major, minor in capabilities]
def sageattn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    return_lse: bool = False,
    **kwargs: Any,
):
    """
    Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute capability.

    The compute capability of the device holding `q` selects the backend:

    - ``sm80`` / ``sm90``: `sageattn_qk_int8_pv_fp16_cuda` with ``pv_accum_dtype="fp32"``.
    - ``sm86``: `sageattn_qk_int8_pv_fp16_triton`.
    - ``sm89``: `sageattn_qk_int8_pv_fp8_cuda` with ``pv_accum_dtype="fp32+fp32"``.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD".
        Default: "HND".

    is_causal : bool
        Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.

    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.

    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    **kwargs : Any
        Accepted for call-site compatibility only; not forwarded to the selected backend.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``.
        Only returned if `return_lse` is True.

    Raises
    ------
    ValueError
        If `q` is not on a CUDA device, or if the device's compute capability is not one of
        sm80, sm86, sm89 or sm90.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
    - All tensors must be on the same cuda device.
    """
    # A CPU tensor has ``device.index is None``; fail with a clear message instead of an
    # opaque indexing error.
    if not q.is_cuda:
        raise ValueError("Input tensors must be on cuda.")

    # Only the device that actually holds the inputs matters for dispatch.
    major, minor = torch.cuda.get_device_capability(q.device)
    arch = f"sm{major}{minor}"

    common = dict(tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse)

    if arch in ("sm80", "sm90"):
        return sageattn_qk_int8_pv_fp16_cuda(q, k, v, pv_accum_dtype="fp32", **common)
    if arch == "sm86":
        return sageattn_qk_int8_pv_fp16_triton(q, k, v, **common)
    if arch == "sm89":
        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, pv_accum_dtype="fp32+fp32", **common)
    raise ValueError(f"Unsupported CUDA architecture: {arch}")
def sageattn_qk_int8_pv_fp16_triton(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tensor_layout: str = "HND",
    quantization_backend: str = "triton",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
    return_lse: bool = False,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    SageAttention with per-block INT8 quantization for Q and K, FP16 PV with FP16 accumulation, implemented using Triton.
    The FP16 accumulator is added to a FP32 buffer immediately after each iteration.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD".
        Default: "HND".

    quantization_backend : str
        The quantization backend, either "triton" or "cuda".
        "cuda" backend offers better performance due to kernel fusion.
        Not consulted when ``head_dim == 96``, which always uses the dedicated hd96 quantizer.

    is_causal : bool
        Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.

    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.

    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.

    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    **kwargs : Any
        Accepted for call-site compatibility only; ignored.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``.
        Only returned if `return_lse` is True.

    Raises
    ------
    AssertionError
        If the inputs are not CUDA tensors of matching device and dtype (``torch.float16`` or
        ``torch.bfloat16``), if ``head_dim`` is not 64, 96 or 128, or if the last dimension is not contiguous.
    ValueError
        If `quantization_backend` is neither "triton" nor "cuda" (checked only when ``head_dim != 96``).

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    """
    dtype = q.dtype
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    headdim = q.size(-1)
    assert headdim in [64, 96, 128], "headdim should be in [64, 96, 128]."

    # assert last dim is contiguous
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."

    is_nhd = tensor_layout == "NHD"
    seq_dim, head_axis = (1, 2) if is_nhd else (2, 1)

    km = None
    lse_correction = None
    if smooth_k:
        km = k.mean(dim=seq_dim, keepdim=True)
        if return_lse:
            # q @ mean(k) is the per-row logit offset that smoothing removes; it is added back
            # to the LSE below. With fewer KV heads than query heads the mean must be expanded
            # first, otherwise the batched matmul cannot broadcast over the head axis.
            # NOTE(review): assumes query head h reads KV head h // num_kv_groups - confirm against the kernels.
            num_kv_groups = q.size(head_axis) // k.size(head_axis)
            km_q = km.repeat_interleave(num_kv_groups, dim=head_axis) if num_kv_groups > 1 else km
            if is_nhd:
                lse_correction = torch.matmul(q.transpose(1, 2), km_q.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
            else:
                lse_correction = torch.matmul(q, km_q.transpose(2, 3)).squeeze(-1).to(torch.float32)

    # The kernels are handed V in FP16; the output is produced in the input dtype via output_dtype.
    if dtype != torch.float16:
        v = v.to(torch.float16)

    if sm_scale is None:
        sm_scale = 1.0 / (headdim ** 0.5)

    if headdim == 96:
        q_int8, q_scale, k_int8, k_scale = per_block_int8_hd96(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
        attn_fn = attn_h96_true if is_causal else attn_h96_false
    else:
        if quantization_backend == "triton":
            q_int8, q_scale, k_int8, k_scale = per_block_int8_triton(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
        elif quantization_backend == "cuda":
            q_int8, q_scale, k_int8, k_scale = per_block_int8_cuda(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
        else:
            raise ValueError(f"Unsupported quantization backend: {quantization_backend}")
        attn_fn = attn_true if is_causal else attn_false

    o, lse = attn_fn(q_int8, k_int8, v, q_scale, k_scale, tensor_layout=tensor_layout, output_dtype=dtype, return_lse=return_lse)

    if not return_lse:
        return o

    # 1.44269504 ~= log2(e): dividing rescales a base-2 logarithm to a natural one
    # (presumably the kernel reports the LSE in base 2 - confirm against the kernel).
    lse = lse / 1.44269504
    if smooth_k:
        lse = lse + lse_correction * sm_scale
    return o, lse
def sageattn_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
    **kwargs: Any,
) -> torch.Tensor:
    """
    SageAttention for packed variable-length sequences, using the Triton per-block INT8 varlen kernels.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
        The caller's tensor is not modified.

    v : torch.Tensor
        The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.

    cu_seqlens_q : torch.Tensor
        The cumulative sequence lengths for the query sequences in the batch, used to index into `q`.
        Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.

    cu_seqlens_k : torch.Tensor
        The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`.
        Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.

    max_seqlen_q : int
        The maximum sequence length for the query tensor in the batch.

    max_seqlen_k : int
        The maximum sequence length for the key and value tensors in the batch.

    is_causal : bool
        Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len for each sequence.
        Default: False.

    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.

    smooth_k : bool
        Whether to smooth the key tensor by subtracting its mean. The mean is taken over all tokens
        of all sequences in the batch (dimension 0), not per sequence.
        Default: True.

    **kwargs : Any
        Accepted for call-site compatibility only; ignored.

    Returns
    -------
    torch.Tensor
        The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.

    Raises
    ------
    AssertionError
        If the inputs are not CUDA tensors of matching device and dtype (``torch.float16`` or
        ``torch.bfloat16``), if ``head_dim`` is not 64 or 128, if the last dimension of q/k/v is not
        contiguous, or if `cu_seqlens_q` / `cu_seqlens_k` are not contiguous.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``.
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    """
    dtype = q.dtype
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    head_dim = q.size(-1)
    assert head_dim in [64, 128], "varlen only support head_dim [64, 128]."

    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
    assert cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous(), "cu_seqlens_q and cu_seqlens_k must be contiguous."

    # Resolve the documented default here rather than relying on the quantizer to do it.
    if sm_scale is None:
        sm_scale = head_dim ** -0.5

    # The kernels are handed V in FP16; the output is produced in the input dtype via output_dtype.
    if dtype != torch.float16:
        v = v.to(torch.float16)

    if smooth_k:
        # ! km is calculated on the all the batches. Calculate over each individual sequence requires dedicated kernel.
        km = k.mean(dim=0, keepdim=True)
        # Out-of-place on purpose: an in-place subtraction would silently alter the caller's key tensor.
        k = k - km

    q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale = per_block_int8_varlen_triton(q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale)

    attn_fn = attn_true_varlen if is_causal else attn_false_varlen
    o = attn_fn(q_int8, k_int8, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, output_dtype=dtype)

    return o
def sageattn_qk_int8_pv_fp16_cuda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    pv_accum_dtype: str = "fp16",
    smooth_k: bool = True,
    smooth_v: bool = True,
    return_lse: bool = False,
    **kwargs: Any,
):
    """
    SageAttention with per-warp INT8 quantization for Q and K, FP16 PV with a selectable accumulator
    (see `pv_accum_dtype`), implemented using CUDA.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD".
        Default: "HND".

    is_causal : bool
        Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.

    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.

    pv_accum_dtype : str
        The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp16", "fp16+fp32" or "fp32".
        - "fp16": PV accumulation is done in fully in FP16. This is the fastest option but may lead to numerical instability. `smooth_v` option will increase the accuracy in cases when the value tensor has a large bias (like in CogVideoX-2b).
        - "fp32": PV accumulation is done in FP32. This is the most accurate option but may be slower than "fp16" due to CUDA core overhead.
        - "fp16+fp32": PV accumulation is done in FP16, but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
        Default: "fp16".

    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.

    smooth_v : bool
        Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
        smooth_v will be ignored (with a warning) if pv_accum_dtype is "fp32" or "fp16+fp32".
        Default: True.

    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    **kwargs : Any
        Accepted for call-site compatibility only; ignored.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``.
        Only returned if `return_lse` is True.

    Raises
    ------
    ValueError
        If `pv_accum_dtype` is not one of "fp16", "fp16+fp32" or "fp32".
    AssertionError
        If the inputs are not CUDA tensors of matching device and dtype (``torch.float16`` or
        ``torch.bfloat16``), if ``head_dim`` is not 64 or 128, or if the last dimension is not contiguous.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    - `smooth_v` will introduce slight overhead but will improve the accuracy under some circumstances, as observed in CogVideoX.
    """
    # Reject an unknown accumulator mode before doing any quantization or allocation.
    if pv_accum_dtype not in ("fp16", "fp16+fp32", "fp32"):
        raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")

    dtype = q.dtype
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    # Integer flags in the form the CUDA extension entry points are called with.
    is_nhd = tensor_layout == "NHD"
    _tensor_layout = 0 if is_nhd else 1
    _is_causal = 1 if is_causal else 0
    _return_lse = 1 if return_lse else 0

    head_dim = q.size(-1)
    assert head_dim in [64, 128], "sageattn_qk_int8_pv_fp16_cuda only support head_dim [64, 128]."

    # assert last dim is contiguous
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."

    if sm_scale is None:
        sm_scale = head_dim**-0.5

    seq_dim, head_axis = (1, 2) if is_nhd else (2, 1)

    km = None
    lse_correction = None
    if smooth_k:
        km = k.mean(dim=seq_dim, keepdim=True)
        if return_lse:
            # q @ mean(k) is the per-row logit offset that smoothing removes; it is added back
            # to the LSE below. With fewer KV heads than query heads the mean must be expanded
            # first, otherwise the batched matmul cannot broadcast over the head axis.
            # NOTE(review): assumes query head h reads KV head h // num_kv_groups - confirm against the kernels.
            num_kv_groups = q.size(head_axis) // k.size(head_axis)
            km_q = km.repeat_interleave(num_kv_groups, dim=head_axis) if num_kv_groups > 1 else km
            if is_nhd:
                lse_correction = torch.matmul(q.transpose(1, 2), km_q.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
            else:
                lse_correction = torch.matmul(q, km_q.transpose(2, 3)).squeeze(-1).to(torch.float32)

    q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, km, tensor_layout=tensor_layout)

    # Output buffer handed to the kernel alongside the quantized inputs.
    o = torch.empty(q.size(), dtype=dtype, device=q.device)

    if pv_accum_dtype in ["fp32", "fp16+fp32"] and smooth_v:
        warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.")
        smooth_v = False

    if pv_accum_dtype == "fp32":
        v = v.to(torch.float16)
        lse = qk_int8_sv_f16_accum_f32_attn_per_warp(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_causal, sm_scale, _return_lse)
    elif pv_accum_dtype == "fp16+fp32":
        v = v.to(torch.float16)
        lse = qk_int8_sv_f16_accum_f16_attn_per_warp_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_causal, sm_scale, _return_lse)
    else:  # "fp16"
        if smooth_v:
            smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
            lse = qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_per_warp(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_causal, sm_scale, _return_lse)
        else:
            v = v.to(torch.float16)
            lse = qk_int8_sv_f16_accum_f16_attn_per_warp(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_causal, sm_scale, _return_lse)

    if not return_lse:
        return o

    # 1.44269504 ~= log2(e): dividing rescales a base-2 logarithm to a natural one
    # (presumably the kernel reports the LSE in base 2 - confirm against the kernel).
    lse = lse / 1.44269504
    if smooth_k:
        lse = lse + lse_correction * sm_scale
    return o, lse
def sageattn_qk_int8_pv_fp8_cuda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    pv_accum_dtype: str = "fp32",
    smooth_k: bool = True,
    smooth_v: bool = True,
    return_lse: bool = False,
    **kwargs: Any,
):
    """
    SageAttention with per-warp INT8 quantization for Q and K, FP8 PV with FP32 accumulation, implemented using CUDA.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD".
        Default: "HND".

    is_causal : bool
        Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
        Default: False.

    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.

    pv_accum_dtype : str
        The dtype of the accumulation of the product of the value tensor and the attention weights, either "fp32" or "fp32+fp32".
        - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator.
        - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy.
        Default: "fp32".

    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.

    smooth_v : bool
        Whether to smooth the value tensor by subtracting the mean along the sequence dimension.
        smooth_v will be ignored (with a warning) if pv_accum_dtype is "fp32+fp32".
        Default: True.

    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    **kwargs : Any
        Accepted for call-site compatibility only; ignored.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor).
        Shape: ``[batch_size, num_qo_heads, qo_len]``.
        Only returned if `return_lse` is True.

    Raises
    ------
    ValueError
        If `pv_accum_dtype` is neither "fp32" nor "fp32+fp32".
    AssertionError
        If the inputs are not CUDA tensors of matching device and dtype (``torch.float16`` or
        ``torch.bfloat16``), if ``head_dim`` is not 64 or 128, or if the last dimension is not contiguous.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    - `smooth_v` will introduce little overhead but will improve the accuracy under some circumstances, as observed in CogVideoX.
    """
    # Without this check an unknown mode would skip every kernel launch below and hand back
    # the uninitialized output buffer.
    if pv_accum_dtype not in ("fp32", "fp32+fp32"):
        raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")

    dtype = q.dtype
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    # Integer flags in the form the CUDA extension entry points are called with.
    is_nhd = tensor_layout == "NHD"
    _tensor_layout = 0 if is_nhd else 1
    _is_causal = 1 if is_causal else 0
    _return_lse = 1 if return_lse else 0

    head_dim = q.size(-1)
    assert head_dim in [64, 128], "sageattn_qk_int8_pv_fp8_cuda only support head_dim [64, 128]."

    # assert last dim is contiguous
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."

    if sm_scale is None:
        sm_scale = head_dim**-0.5

    seq_dim, head_axis = (1, 2) if is_nhd else (2, 1)

    km = None
    lse_correction = None
    if smooth_k:
        km = k.mean(dim=seq_dim, keepdim=True)
        if return_lse:
            # q @ mean(k) is the per-row logit offset that smoothing removes; it is added back
            # to the LSE below. With fewer KV heads than query heads the mean must be expanded
            # first, otherwise the batched matmul cannot broadcast over the head axis.
            # NOTE(review): assumes query head h reads KV head h // num_kv_groups - confirm against the kernels.
            num_kv_groups = q.size(head_axis) // k.size(head_axis)
            km_q = km.repeat_interleave(num_kv_groups, dim=head_axis) if num_kv_groups > 1 else km
            if is_nhd:
                lse_correction = torch.matmul(q.transpose(1, 2), km_q.transpose(1, 2).transpose(2, 3)).squeeze(-1).to(torch.float32)
            else:
                lse_correction = torch.matmul(q, km_q.transpose(2, 3)).squeeze(-1).to(torch.float32)

    q_int8, q_scale, k_int8, k_scale = per_warp_int8(q, k, km, tensor_layout=tensor_layout)

    # Output buffer handed to the kernel alongside the quantized inputs.
    o = torch.empty(q.size(), dtype=dtype, device=q.device)

    if pv_accum_dtype == 'fp32+fp32' and smooth_v:
        warnings.warn("pv_accum_dtype is 'fp32+fp32', smooth_v will be ignored.")
        smooth_v = False

    v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout=tensor_layout, smooth_v=smooth_v)

    if pv_accum_dtype == "fp32":
        if smooth_v:
            lse = qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_per_warp(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_causal, sm_scale, _return_lse)
        else:
            lse = qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_per_warp(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, sm_scale, _return_lse)
    else:  # "fp32+fp32"
        lse = qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_per_warp_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_causal, sm_scale, _return_lse)

    if not return_lse:
        return o

    # 1.44269504 ~= log2(e): dividing rescales a base-2 logarithm to a natural one
    # (presumably the kernel reports the LSE in base 2 - confirm against the kernel).
    lse = lse / 1.44269504
    if smooth_k:
        lse = lse + lse_correction * sm_scale
    return o, lse