update asm pa for blockscale (256/128,128) #231

Open · wants to merge 2 commits into base: main
3 changes: 2 additions & 1 deletion aiter/ops/attention.py
@@ -44,7 +44,8 @@ def pa_fwd_asm(
K_QScale: Optional[torch.Tensor],
V_QScale: Optional[torch.Tensor],
out_: Optional[torch.Tensor] = None,
high_precision: Optional[int] = 1 # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
high_precision: Optional[int] = 1, # [0, 1, 2] 2 is the highest precision, this is only for fp8 kvcache
block_shape: Optional[tuple[int,int]] = None,
) -> torch.Tensor: ...
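For orientation, a minimal sketch of how the extended Python entry point might be called once the KV cache has been quantized per 128-token block. The tensor names and the way they are produced are illustrative only (the updated test builds them with `pertoken_quant` and `asm_V_shuffle`); the new piece is the `block_shape` keyword.

```python
import aiter

# Hypothetical inputs: bf16 queries plus an fp8 (e4m3fnuz) KV cache already
# shuffled into the layout the ASM kernel expects, with one K/V scale per
# (kv_head, 128-token block).
out = aiter.pa_fwd_asm(
    query,                   # [num_seqs, num_heads, head_size], torch.bfloat16
    k_cache_fp8,             # fp8 K cache pages
    v_cache_fp8,             # fp8 V cache pages (asm_V_shuffle layout)
    block_tables,
    context_lens,
    max_num_blocks_per_seq,
    k_scale,                 # K_QScale: per-block K scales
    v_scale,                 # V_QScale: per-block V scales
    None,                    # out_: let the kernel allocate the output
    high_precision=1,
    block_shape=(128, 128),  # (tokens per scale block, head_size)
)
```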


3 changes: 2 additions & 1 deletion csrc/include/attention_asm.h
@@ -12,4 +12,5 @@ torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, hea
std::optional<torch::Tensor> &K_QScale,
std::optional<torch::Tensor> &V_QScale,
std::optional<torch::Tensor> &out_,
std::optional<int> high_precision = 1);
std::optional<int> high_precision = 1,
std::optional<std::tuple<int, int>> block_shape = std::nullopt);
3 changes: 2 additions & 1 deletion csrc/include/rocm_ops.hpp
100644 → 100755
@@ -40,7 +40,8 @@
py::arg("K_QScale") = std::nullopt, \
py::arg("V_QScale") = std::nullopt, \
py::arg("out_") = std::nullopt, \
py::arg("high_precision") = 1);
py::arg("high_precision") = 1, \
py::arg("block_shape") = std::nullopt);

#define ATTENTION_CK_PYBIND \
m.def("pa_fwd_naive", &pa_fwd_naive, "pa_fwd_naive", \
23 changes: 21 additions & 2 deletions csrc/py_itfs_cu/asm_pa.cpp
@@ -52,7 +52,8 @@ torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, hea
std::optional<torch::Tensor> &K_QScale,
std::optional<torch::Tensor> &V_QScale,
std::optional<torch::Tensor> &out_,
std::optional<int> high_precision = 1)
std::optional<int> high_precision = 1,
std::optional<std::tuple<int, int>> block_shape = std::nullopt)
{
torch::Tensor output = out_.value_or(torch::empty_like(Q));
int batch = context_lens.size(0);
@@ -106,7 +107,25 @@ torch::Tensor pa_fwd(torch::Tensor &Q, // [num_seqs, num_heads, hea
AiterAsmKernel *impl_ptr = nullptr;
if (K_QScale)
{
if (Q.dtype() == at::ScalarType::Half)
if (block_shape.has_value())
{
if (block_shape.value() == std::make_tuple(128, 128) && Q.dtype() == at::ScalarType::BFloat16 && K.dtype() == at::ScalarType::Float8_e4m3fnuz)
{
static AiterAsmKernel impl_a16w16_b16_f8_blockscale128("pa_a16w8_2tg_g8_f8_kv128_bf16", "pa_a16w8_2tg_g8_f8_kv128_bf16.co");
impl_ptr = &impl_a16w16_b16_f8_blockscale128;
}
else if (block_shape.value() == std::make_tuple(256, 128) && Q.dtype() == at::ScalarType::BFloat16 && K.dtype() == at::ScalarType::Float8_e4m3fnuz)
{
static AiterAsmKernel impl_a16w16_b16_f8_blockscale256("pa_a16w8_2tg_g8_f8_kv256_bf16", "pa_a16w8_2tg_g8_f8_kv256_bf16.co");
impl_ptr = &impl_a16w16_b16_f8_blockscale256;
}
else
{
TORCH_CHECK(false,
__func__, ": only support block_shape == (128, 128) | (256, 128), Q dtype == BFloat16 and quantType == fp8 for now");
}
}
else if (Q.dtype() == at::ScalarType::Half)
{
if (K.dtype() == at::ScalarType::Byte || K.dtype() == at::ScalarType::Char)
{
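In summary, the dispatch added above only reaches the new block-scale kernels when K/V scales are present, Q is bf16, and the KV cache is fp8 (e4m3fnuz). A hedged Python restatement of that selection logic, with the kernel names taken from the diff (this is a sketch, not the C++ dispatch itself):

```python
import torch

# Kernel objects added in asm_pa.cpp, keyed by block_shape.
BLOCKSCALE_KERNELS = {
    (128, 128): "pa_a16w8_2tg_g8_f8_kv128_bf16",  # hsa/pa_a16w8_2tg_g8_f8_kv128_bf16.co
    (256, 128): "pa_a16w8_2tg_g8_f8_kv256_bf16",  # hsa/pa_a16w8_2tg_g8_f8_kv256_bf16.co
}

def select_blockscale_kernel(block_shape, q_dtype, k_dtype):
    # Mirrors the new TORCH_CHECK: bf16 Q, fp8 KV cache, and one of two shapes.
    if q_dtype != torch.bfloat16 or k_dtype != torch.float8_e4m3fnuz:
        raise ValueError("block_shape path requires bf16 Q and an fp8 e4m3fnuz KV cache")
    try:
        return BLOCKSCALE_KERNELS[tuple(block_shape)]
    except KeyError:
        raise ValueError("only block_shape (128, 128) or (256, 128) is supported for now")
```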
Binary file added hsa/pa_a16w8_2tg_g8_f8_kv128_bf16.co
Binary file added hsa/pa_a16w8_2tg_g8_f8_kv256_bf16.co
129 changes: 98 additions & 31 deletions op_tests/test_pa.py
100644 → 100755
@@ -7,7 +7,7 @@
import torch
import aiter
from aiter import paged_attn as ops
from aiter.test_common import checkAllclose, perftest, tensor_dump, tensor_load
from aiter.test_common import checkAllclose, perftest, tensor_dump, tensor_load, benchmark
from aiter import pertoken_quant

uniform_range = (-1, 1)
@@ -28,6 +28,7 @@
# // same as 8bit per token quant but 4 bit
'KV_4BIT_PER_TOKEN',
'KV_8BIT_PER_TENSOR',
'KV_8BIT_PER_BLOCK',
]


@@ -376,7 +377,8 @@ def run_aiter_asm(query,
max_num_blocks,
k_scale=None,
v_scale=None,
high_precision=0):
high_precision=0,
block_shape=None):
return aiter.pa_fwd_asm(
query,
k_cache,
@@ -387,7 +389,8 @@ def run_aiter_asm(query,
k_scale,
v_scale,
None,
high_precision
high_precision,
block_shape
)


@@ -461,7 +464,7 @@ def asm_V_shuffle(VC):
VC = VC.permute(0, 1, 3, 2, 4).contiguous()
return VC


@benchmark()
def test_paged_attention(
ctx_lens: int,
num_seqs: int,
@@ -472,7 +475,8 @@ def test_paged_attention(
dtype: torch.dtype,
kv_cache_dtype: str,
seed: int,
device: str
device: str,
block_shape: Optional[Tuple[int, int]] = None,
) -> None:
torch.set_default_device(device)
# Using default kv_scale
@@ -486,6 +490,8 @@ def test_paged_attention(
num_queries_per_kv = num_query_heads // num_kv_heads
max_seq_len = ctx_lens
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
if block_shape is not None:
max_num_blocks_per_seq = ((max_seq_len + block_shape[0] - 1) // block_shape[0]) * block_shape[0] // block_size
num_blocks = max_num_blocks_per_seq*num_seqs
print(f'{debug_mode=}')
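A quick check of the rounding added above, as a sketch of the arithmetic only: with a block-scale shape, the context length is first rounded up to whole `block_shape[0]`-token scale blocks and then converted back into `block_size`-token pages.

```python
def blocks_per_seq(max_seq_len, block_size, block_shape=None):
    # Default paging: one page per block_size tokens, rounded up.
    if block_shape is None:
        return (max_seq_len + block_size - 1) // block_size
    # Block-scale paging: round up to whole scale blocks of block_shape[0]
    # tokens, then express that in block_size-token pages.
    return ((max_seq_len + block_shape[0] - 1) // block_shape[0]) * block_shape[0] // block_size

assert blocks_per_seq(257, 16) == 17                 # plain paging
assert blocks_per_seq(257, 16, (128, 128)) == 24     # 257 -> 384 tokens -> 24 pages
assert blocks_per_seq(4097, 16, (256, 128)) == 272   # 4097 -> 4352 tokens -> 272 pages
```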

@@ -510,10 +516,17 @@ def test_paged_attention(
# Create the block tables.
block_tables_lst: List[List[int]] = []
for _ in range(num_seqs):
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
if block_shape is None:
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
else:
serial_block_num = block_shape[0] // block_size
random_block = (random.randint(0, num_blocks - max_num_blocks_per_seq) //serial_block_num) * serial_block_num
block_table = list(
range(random_block, random_block + max_num_blocks_per_seq)
)
block_tables_lst.append(block_table)

block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
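The branch above also changes how block tables are built: with a block-scale shape, every `block_shape[0] // block_size` consecutive pages share one scale, so each sequence gets an aligned, contiguous run of pages instead of randomly chosen ones. A small sketch of that constraint, assuming `block_size=16` and `block_shape=(128, 128)` with illustrative sizes:

```python
import random

block_size = 16
block_shape = (128, 128)
serial_block_num = block_shape[0] // block_size      # 8 pages per scale block
num_blocks, max_num_blocks_per_seq = 1024, 24        # illustrative sizes

# Pick a start page aligned to the scale-block boundary, then take a
# contiguous run (mirrors the test's block-table construction).
start = (random.randint(0, num_blocks - max_num_blocks_per_seq)
         // serial_block_num) * serial_block_num
block_table = list(range(start, start + max_num_blocks_per_seq))

assert start % serial_block_num == 0
assert max_num_blocks_per_seq % serial_block_num == 0
```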
@@ -569,7 +582,8 @@ def test_paged_attention(
(2, torch.float8_e4m3fnuz),
(2, torch.int8),
(4, torch.float8_e4m3fnuz),
]:
] if block_shape is None else \
[(5, torch.float8_e4m3fnuz)]:
quant_algo = ck_naive_quant_algo[quant_algo_]
if quant_algo == "NO":
k_quant_, k_scale_, v_quant_, v_scale_ = k_cache, torch.empty(
@@ -634,6 +648,47 @@ def test_paged_attention(
# )
# checkAllclose(out_aiter_asm, out_aiter_naive,
# msg=f'golden vs ck_naive(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_}):{time_aiter_naive:.2f} us......')
elif quant_algo == "KV_8BIT_PER_BLOCK":
assert block_shape in [(128, 128), (256, 128)], "KV_8BIT_PER_BLOCK only supports block_shape (128, 128) or (256, 128)"
assert head_size == block_shape[1], "KV_8BIT_PER_BLOCK only supports head_size == block_shape[1]"
assert num_blocks % (block_shape[0] // block_size) == 0, "KV_8BIT_PER_BLOCK only supports num_blocks multiple of (block_shape[0] // block_size)"

x = k_cache.shape[-1]
k_cache_permute = k_cache.view(num_blocks // (block_shape[0] // block_size), (block_shape[0] // block_size), num_kv_heads, head_size//x, block_size, x).permute(0, 2, 1, 3, 4, 5).contiguous()
k_quant_, k_scale_asm = pertoken_quant(k_cache_permute.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, -1), quant_dtype=torch.float8_e4m3fnuz)
k_cache_permute = (k_quant_.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, -1).to(torch.float) * k_scale_asm.to(torch.float)).to(k_cache.dtype)
k_cache = k_cache_permute.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, (block_shape[0] // block_size), head_size//x, block_size, x).permute(0, 2, 1, 3, 4, 5).contiguous()
del k_cache_permute
k_cache = k_cache.view(num_blocks, num_kv_heads, head_size//x, block_size, x)
k_quant_ = k_quant_.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, (block_shape[0] // block_size), head_size//x, block_size, x).permute(0, 2, 1, 3, 5, 4).contiguous()
x = 16 // torch.float8_e4m3fnuz.itemsize
k_quant_ = k_quant_.view(num_blocks, num_kv_heads, head_size//x, x, block_size).permute(0, 1, 2, 4, 3).contiguous()
k_scale_asm = k_scale_asm.view(num_blocks // (block_shape[0] // block_size), num_kv_heads).permute(1, 0).contiguous()

v_cache_permute = v_cache.view(num_blocks // (block_shape[0] // block_size), (block_shape[0] // block_size), num_kv_heads, head_size, block_size).permute(0, 2, 1, 3, 4).contiguous()
v_quant_, v_scale_asm = pertoken_quant(v_cache_permute.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, -1), quant_dtype=torch.float8_e4m3fnuz)
v_cache_permute = (v_quant_.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, -1).to(torch.float) * v_scale_asm.to(torch.float)).to(v_cache.dtype)
v_cache = v_cache_permute.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, (block_shape[0] // block_size), head_size, block_size).permute(0, 2, 1, 3, 4).contiguous()
del v_cache_permute
v_cache = v_cache.view(num_blocks, num_kv_heads, head_size, block_size)
v_quant_ = v_quant_.view(num_blocks // (block_shape[0] // block_size), num_kv_heads, (block_shape[0] // block_size), head_size, block_size).permute(0, 2, 1, 3, 4).contiguous()
v_quant_ = v_quant_.view(num_blocks, num_kv_heads, head_size, block_size)
v_scale_asm = v_scale_asm.view(num_blocks // (block_shape[0] // block_size), num_kv_heads).permute(1, 0).contiguous()

out_golden, time_aiter = run_aiter(
query,
k_cache,
v_cache,
block_tables,
seq_lens,
max_seq_len,
kv_cache_dtype,
num_kv_heads,
scale,
alibi_slopes,
k_scale,
v_scale,
)

if quant_algo_ != 0:
out_aiter_asm, time_aiter_asm = run_aiter_asm(
@@ -650,9 +705,10 @@ def test_paged_attention(
max_num_blocks_per_seq,
k_scale_asm,
v_scale_asm,
block_shape=block_shape
)
checkAllclose(out_golden, out_aiter_asm,
msg=f'golden vs aiter_asm:{time_aiter_asm:.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})')
msg=f'golden vs aiter_asm:{time_aiter_asm:.2f} us......(quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_}, {block_shape=})')

if dtype in [torch.bfloat16, torch.float16] and quant_algo_ == 2 and cache_type_ == torch.float8_e4m3fnuz:
if dtype == torch.bfloat16:
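The KV_8BIT_PER_BLOCK branch above folds every `block_shape[0] // block_size` pages into one quantization group per kv head before calling `pertoken_quant`, then shuffles the result back into the ASM cache layout. A minimal sketch of just the K-side grouping and the resulting scale shape, leaving out the x-interleaving steps, and assuming (as in the test) that `pertoken_quant` returns one scale per row of the flattened 3-D view:

```python
import torch
from aiter import pertoken_quant

num_blocks, num_kv_heads, head_size, block_size = 64, 8, 128, 16
x = 16 // torch.bfloat16.itemsize                    # 8 elements per 16-byte chunk
blocks_per_scale = 128 // block_size                 # block_shape[0] // block_size
groups = num_blocks // blocks_per_scale              # pages sharing one scale

k_cache = torch.randn(num_blocks, num_kv_heads, head_size // x, block_size, x,
                      dtype=torch.bfloat16, device="cuda")

# Group the pages that share a scale, bring kv_heads forward, flatten the rest.
grouped = (k_cache.view(groups, blocks_per_scale, num_kv_heads, head_size // x, block_size, x)
                  .permute(0, 2, 1, 3, 4, 5).contiguous()
                  .view(groups, num_kv_heads, -1))
k_quant, k_scale = pertoken_quant(grouped, quant_dtype=torch.float8_e4m3fnuz)

# The ASM kernel takes the scales transposed to [num_kv_heads, groups].
k_scale_asm = k_scale.view(groups, num_kv_heads).permute(1, 0).contiguous()
```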
@@ -682,26 +738,27 @@ def test_paged_attention(
# if quant_algo == "KV_8BIT_PER_TENSOR":
# q_quant_, q_scale_ = aiter.per_tensor_quant(
# query, quant_dtype=cache_type_)
out_native, time_native = run_native(
query,
# q_quant_,
k_quant_,
v_quant_,
block_tables,
seq_lens,
max_seq_len,
kv_cache_dtype,
num_kv_heads,
scale,
# scale*q_scale_.item(),
alibi_slopes,
k_scale_,
v_scale_,
num_queries_per_kv,
dtype
)
checkAllclose(
out_golden, out_native, msg=f'golden vs torch_native: {time_native:.2f} us...... (quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})')
if quant_algo != "KV_8BIT_PER_BLOCK":
out_native, time_native = run_native(
query,
# q_quant_,
k_quant_,
v_quant_,
block_tables,
seq_lens,
max_seq_len,
kv_cache_dtype,
num_kv_heads,
scale,
# scale*q_scale_.item(),
alibi_slopes,
k_scale_,
v_scale_,
num_queries_per_kv,
dtype
)
checkAllclose(
out_golden, out_native, msg=f'golden vs torch_native: {time_native:.2f} us...... (quant:{ck_naive_quant_algo[quant_algo_]}, kvcache:{cache_type_})')

if debug_mode == DUMP:
dump_input(query,
@@ -751,3 +808,13 @@ def test_paged_attention(
for dtype in [torch.float16, torch.bfloat16]:
test_paged_attention(ctx_len, 128, num_heads, 128, False, 16,
dtype, "auto", 0, "cuda:0")

for num_heads in [(4, 1), (8, 1), (32, 8)]:
for ctx_len in [7, 26, 57, 66, 109, 128, 257, 282, 4097]:
for dtype in [torch.bfloat16]:
test_paged_attention(ctx_len, 128, num_heads, 128, False, 16, torch.bfloat16, "auto", 0, "cuda:0", block_shape=(128,128))

for num_heads in [(4, 1), (8, 1), (32, 8)]:
for ctx_len in [7, 26, 57, 66, 109, 128, 257, 282, 4097]:
for dtype in [torch.bfloat16]:
test_paged_attention(ctx_len, 128, num_heads, 128, False, 16, torch.bfloat16, "auto", 0, "cuda:0", block_shape=(256,128))