WIP: causal prefix mask with adjusted tests #2

Draft · wants to merge 24 commits into base: main
2 changes: 1 addition & 1 deletion .github/workflows/cuda/cu116-Linux.sh
@@ -1,6 +1,6 @@
#!/bin/bash

-OS=ubuntu1804
+OS=ubuntu2004

wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
9 changes: 9 additions & 0 deletions .github/workflows/cuda/cu117-Linux-env.sh
@@ -0,0 +1,9 @@
#!/bin/bash

CUDA_HOME=/usr/local/cuda-11.7
LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
PATH=${CUDA_HOME}/bin:${PATH}

export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
export CUDA_HOME=/usr/local/cuda-11.7
15 changes: 15 additions & 0 deletions .github/workflows/cuda/cu117-Linux.sh
@@ -0,0 +1,15 @@
#!/bin/bash

OS=ubuntu2004

wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
wget -nv https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
sudo dpkg -i cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
sudo cp /var/cuda-repo-${OS}-11-7-local/cuda-*-keyring.gpg /usr/share/keyrings/

sudo apt-get -qq update
sudo apt install -y cuda cuda-nvcc-11-7 cuda-libraries-dev-11-7
sudo apt clean

rm -f cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
9 changes: 9 additions & 0 deletions .github/workflows/cuda/cu118-Linux-env.sh
@@ -0,0 +1,9 @@
#!/bin/bash

CUDA_HOME=/usr/local/cuda-11.8
LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
PATH=${CUDA_HOME}/bin:${PATH}

export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
export CUDA_HOME=/usr/local/cuda-11.8
15 changes: 15 additions & 0 deletions .github/workflows/cuda/cu118-Linux.sh
@@ -0,0 +1,15 @@
#!/bin/bash

OS=ubuntu2004

wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
wget -nv https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda-repo-${OS}-11-8-local_11.8.0-520.61.05-1_amd64.deb
sudo dpkg -i cuda-repo-${OS}-11-8-local_11.8.0-520.61.05-1_amd64.deb
sudo cp /var/cuda-repo-${OS}-11-8-local/cuda-*-keyring.gpg /usr/share/keyrings/

sudo apt-get -qq update
sudo apt install -y cuda cuda-nvcc-11-8 cuda-libraries-dev-11-8
sudo apt clean

rm -f cuda-repo-${OS}-11-8-local_11.8.0-520.61.05-1_amd64.deb
35 changes: 23 additions & 12 deletions .github/workflows/publish.yml
@@ -5,7 +5,7 @@
name: Python Package

on:
-  create:
+  push:
    tags:
      - '**'

@@ -36,15 +36,26 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
-       # os: [ubuntu-20.04]
-       os: [ubuntu-18.04]
-       python-version: ['3.7', '3.8', '3.9', '3.10']
-       torch-version: [1.11.0, 1.12.0, 1.12.1]
-       cuda-version: ['113', '116']
+       os: [ubuntu-20.04]
+       # python-version: ['3.8', '3.9']
+       # torch-version: [1.12.1, 2.0.0]
+       # cuda-version: ['116', '118']
+       python-version: ['3.9']
+       torch-version: [1.12.1, 1.13.1, 2.0.1]
+       cuda-version: ['116', '117', '118']
        exclude:
-         - torch-version: 1.11.0
+         - torch-version: 1.12.1
+           cuda-version: '117'
+         - torch-version: 1.12.1
+           cuda-version: '118'
+         - torch-version: 1.13.1
+           cuda-version: '116'
+         - torch-version: 1.13.1
+           cuda-version: '118'
+         - torch-version: 2.0.1
+           cuda-version: '116'
+         - torch-version: 2.0.1
+           cuda-version: '117'
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -81,8 +92,8 @@ jobs:

- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
-         pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses && conda clean -ya
-         pip install --no-index --no-cache-dir torch==${{ matrix.torch-version }} -f https://download.pytorch.org/whl/cu${{ matrix.cuda-version }}/torch_stable.html
+         pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses packaging einops setuptools && conda clean -ya
+         pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${{ matrix.cuda-version }} --extra-index-url https://pypi.org/simple
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('CUDA:', torch.version.cuda)"
@@ -107,7 +118,7 @@
export FORCE_CUDA="1"
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
-         export CUDA_INSTALL_DIR=/usr/local/cuda-11.3$CUDA_INSTALL_DIR
+         export CUDA_INSTALL_DIR=/usr/local/cuda$CUDA_INSTALL_DIR
pip install wheel
python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }}
@@ -124,4 +135,4 @@
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
-         asset_content_type: application/*
\ No newline at end of file
+         asset_content_type: application/*
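The matrix and exclude list in this workflow pin each PyTorch release to exactly one CUDA toolkit. A quick enumeration in plain Python (illustrative only, not part of the workflow) confirms the three wheel combinations the job will actually build:

from itertools import product

torch_versions = ['1.12.1', '1.13.1', '2.0.1']
cuda_versions = ['116', '117', '118']
excluded = {('1.12.1', '117'), ('1.12.1', '118'),
            ('1.13.1', '116'), ('1.13.1', '118'),
            ('2.0.1', '116'), ('2.0.1', '117')}

combos = [(t, c) for t, c in product(torch_versions, cuda_versions)
          if (t, c) not in excluded]
print(combos)  # [('1.12.1', '116'), ('1.13.1', '117'), ('2.0.1', '118')]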
4 changes: 3 additions & 1 deletion csrc/flash_attn/src/fmha/mask.h
@@ -37,6 +37,7 @@ struct Mask {
template<typename BInfo>
__device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0)
        : actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N)
+       , actual_seqlen_q(binfo.actual_seqlen_q)
        , loop_step_idx(loop_step_idx_) {

const int warp = tidx / Cta_tile::THREADS_PER_WARP;
@@ -67,7 +68,7 @@ struct Mask {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 1)) {
// printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
// }
-       return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
+       return Is_causal ? col_valid && (current_col <= current_row + actual_seqlen_k - actual_seqlen_q) : col_valid;
// return row_valid && col_valid;
}

@@ -85,6 +86,7 @@
    int col;
    const int loop_step_idx;
    const int actual_seqlen_k;
+   const int actual_seqlen_q;
};

} // namespace fmha
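With a key sequence longer than the query sequence, the new condition above aligns the causal diagonal to the bottom-right of the attention matrix: every query row sees the whole key prefix plus its own causal span. A minimal NumPy sketch of the rule (an illustration of the formula in is_valid, not the kernel code itself):

import numpy as np

def causal_prefix_mask(seqlen_q, seqlen_k):
    # Query row i may attend to key column j iff
    # j <= i + (seqlen_k - seqlen_q), mirroring the kernel's
    # current_col <= current_row + actual_seqlen_k - actual_seqlen_q.
    rows = np.arange(seqlen_q)[:, None]
    cols = np.arange(seqlen_k)[None, :]
    return cols <= rows + (seqlen_k - seqlen_q)

# seqlen_q=2, seqlen_k=4: row 0 sees keys 0..2, row 1 sees keys 0..3.
print(causal_prefix_mask(2, 4).astype(int))
# [[1 1 1 0]
#  [1 1 1 1]]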
16 changes: 14 additions & 2 deletions csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
@@ -270,7 +270,19 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);

    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
-   int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+   int begin;
+   if (Is_causal) {
+       int test_val = loop_step_idx * Cta_tile_p::N - (binfo.actual_seqlen_k - binfo.actual_seqlen_q);
+       if (loop_step_idx * Cta_tile_p::N < binfo.actual_seqlen_k - binfo.actual_seqlen_q) {
+           begin = 0;
+           // printf("%d, %d, %d, %d, %d, %d done1\n", Cta_tile_p::N, Cta_tile_p::M, binfo.actual_seqlen_k, binfo.actual_seqlen_q, loop_step_idx, begin);
+       } else {
+           begin = test_val / Cta_tile_p::M;
+           // printf("%d, %d, %d, %d, %d, %d done2\n", Cta_tile_p::N, Cta_tile_p::M, binfo.actual_seqlen_k, binfo.actual_seqlen_q, loop_step_idx, begin);
+       }
+   } else {
+       begin = 0;
+   }
// Otherwise we'd be reading out-of-bound memory before the loop
if (begin * Cta_tile_p::M >= binfo.actual_seqlen_q) {
// Still need to zero out dk and dv before returning
@@ -679,7 +691,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
    const bool is_final_write =
        Is_last
        || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
-       || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
+       || ((Is_causal) && ((begin + l + 1) * Cta_tile_p::M + (binfo.actual_seqlen_k - binfo.actual_seqlen_q - 1) < (loop_step_idx + 1) * Cta_tile_p::N));
if (is_final_write) {
// if (Is_dropout) {
// dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
17 changes: 15 additions & 2 deletions csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -272,7 +272,20 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i

// Wind gmem tiles to the correct position.
    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
-   int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+   // int begin = Is_causal ? (loop_step_idx * Cta_tile_p::N) / Cta_tile_p::M : 0;
+   int begin;
+   if (Is_causal) {
+       int test_val = loop_step_idx * Cta_tile_p::N - (binfo.actual_seqlen_k - binfo.actual_seqlen_q);
+       if (loop_step_idx * Cta_tile_p::N < binfo.actual_seqlen_k - binfo.actual_seqlen_q) {
+           begin = 0;
+           // printf("%d, %d, %d, %d, %d, %d done1\n", Cta_tile_p::N, Cta_tile_p::M, binfo.actual_seqlen_k, binfo.actual_seqlen_q, loop_step_idx, begin);
+       } else {
+           begin = test_val / Cta_tile_p::M;
+           // printf("%d, %d, %d, %d, %d, %d done2\n", Cta_tile_p::N, Cta_tile_p::M, binfo.actual_seqlen_k, binfo.actual_seqlen_q, loop_step_idx, begin);
+       }
+   } else {
+       begin = 0;
+   }
// We want begin to be a multiple of gridDim.z
// This is because the row indices processed by each threadblock must align between the
// loop steps, otherwise we have a dependency between the blocks.
@@ -620,7 +633,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
    const bool is_final_write =
        Is_last
        || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
-       || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
+       || ((Is_causal) && ((begin + l + 1) * Cta_tile_p::M + (binfo.actual_seqlen_k - binfo.actual_seqlen_q - 1) < (loop_step_idx + 1) * Cta_tile_p::N));
#pragma unroll
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
float sum = p_sum_o[jj][0];
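Both kernels above (dgrad and fprop) replace the old begin = loop_step_idx * Cta_tile_p::N / Cta_tile_p::M with a prefix-aware version. A host-side sketch of that index arithmetic (plain Python mirroring the device logic; the tile sizes and sequence lengths below are hypothetical):

def first_q_tile(loop_step_idx, N, M, seqlen_k, seqlen_q):
    # First query-row tile that can attend to key tile loop_step_idx:
    # query rows below loop_step_idx*N - (seqlen_k - seqlen_q) are fully
    # masked for this key tile, so whole M-row tiles can be skipped.
    offset = loop_step_idx * N - (seqlen_k - seqlen_q)
    return 0 if offset < 0 else offset // M

# With a long key prefix (seqlen_k >> seqlen_q) begin stays 0 for almost
# every key tile and only rises near the end, e.g. with N = M = 128:
print(first_q_tile(0, 128, 128, seqlen_k=35000, seqlen_q=350))    # 0
print(first_q_tile(270, 128, 128, seqlen_k=35000, seqlen_q=350))  # 0
print(first_q_tile(272, 128, 128, seqlen_k=35000, seqlen_q=350))  # 1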
18 changes: 15 additions & 3 deletions flash_attn/flash_attn_triton.py
@@ -102,7 +102,12 @@ def _fwd_kernel(
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
-       b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
+       off_hb_b = off_hb.to(tl.int64)
+       off_b_b = off_hb_b // nheads
+       off_h_b = off_hb_b % nheads
+       start_m_b = start_m.to(tl.int64)
+       offs_m_b = start_m_b * BLOCK_M + tl.arange(0, BLOCK_M)
+       b_ptrs = Bias + off_b_b * stride_bb + off_h_b * stride_bh + (offs_m_b[:, None] * stride_bm + offs_n[None, :])
# initialize pointer to m and l
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -311,7 +316,11 @@ def _bwd_kernel_one_col_block(
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + offs_n
    elif BIAS_TYPE == 'matrix':
-       b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
+       start_n_b = start_n.to(tl.int64)
+       begin_m_b = 0 if not IS_CAUSAL else ((start_n_b * BLOCK_N) // BLOCK_M) * BLOCK_M
+       offs_qm_b = begin_m_b + tl.arange(0, BLOCK_M)
+       offs_n_b = start_n_b * BLOCK_N + tl.arange(0, BLOCK_N)
+       b_ptrs = Bias + (offs_qm_b[:, None] * stride_bm + offs_n_b[None, :])
# initialize dv and dk
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
@@ -538,7 +547,10 @@ def _bwd_kernel(
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    if BIAS_TYPE != 'none':
-       Bias += off_b * stride_bb + off_h * stride_bh
+       if BIAS_TYPE == 'matrix':
+           Bias += off_b.to(tl.int64) * stride_bb + off_h.to(tl.int64) * stride_bh
+       else:
+           Bias += off_b * stride_bb + off_h * stride_bh
# pointer to row-wise quantities in value-like data
D += off_hb * seqlen_q_rounded
LSE += off_hb * seqlen_q_rounded
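The int64 promotions in the three hunks above all target the same failure mode: Triton computes pointer offsets in 32-bit arithmetic unless promoted, and for a 'matrix' bias of shape (batch, nheads, seqlen_q, seqlen_k) the worst-case element offset can exceed 2**31 - 1. A back-of-the-envelope check in plain Python (the shapes here are hypothetical):

B, H, seqlen_q, seqlen_k = 16, 12, 350, 35000   # hypothetical bias shape
stride_bh = seqlen_q * seqlen_k                 # elements per head
stride_bb = H * stride_bh                       # elements per batch
worst = (B - 1) * stride_bb + (H - 1) * stride_bh + (seqlen_q - 1) * seqlen_k

print(worst)              # 2351965000
print(worst > 2**31 - 1)  # True: an int32 offset would wrap around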
1 change: 1 addition & 0 deletions setup.py
@@ -139,6 +139,7 @@ def append_nvcc_threads(nvcc_extra_args):
"nvcc": append_nvcc_threads(
[
"-O3",
"-t4",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
127 changes: 127 additions & 0 deletions tests/test_causal_prefix_mask.py
@@ -0,0 +1,127 @@
"""
Test adapted from https://github.com/openai/triton/blob/0d7e7532279e45672555e344646f5c19c3972331/python/tutorials/06-fused-attention.py
"""
from contextlib import nullcontext
import math
import time

from scipy import stats

import torch

from flash_attn.flash_attn_interface import flash_attn_unpadded_func


def create_causal_mask(q: int, k: int, dtype: torch.dtype, device: torch.device):
return (
(torch.ones((q, k), device=device) - torch.inf).triu(k - q + 1).type(dtype)
)


def attention_ref(q, k, v, sm_scale, causal, device):
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        # bottom-right-aligned causal mask so the key prefix stays visible
        M = create_causal_mask(q.size(2), k.size(2), dtype=q.dtype, device=device)
        p += M
    p = torch.softmax(p.float(), dim=-1).type(q.dtype)
    ref_out = torch.matmul(p, v)
    return ref_out


torch.manual_seed(0)
dropout_p = 0.0
causal = True
dtype = torch.bfloat16
device = 'cuda'
test_backward = True


with torch.inference_mode() if not test_backward else nullcontext():
B = 8
H = 12
    Q_N_CTX = 350
    KV_N_CTX = 350 * 100  # key prefix 100x longer than the query length
D_HEAD = 64

torch.manual_seed(20)
q = torch.empty((B, H, Q_N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0, std=.5)
k = torch.empty((B, H, KV_N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0, std=.5)
v = torch.empty((B, H, KV_N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0, std=.5)
if test_backward:
q = q.requires_grad_()
k = k.requires_grad_()
v = v.requires_grad_()
cu_seqlens_q = torch.arange(
0, (B + 1) * Q_N_CTX, step=Q_N_CTX, dtype=torch.int32, device=device
)
cu_seqlens_k = torch.arange(
0, (B + 1) * KV_N_CTX, step=KV_N_CTX, dtype=torch.int32, device=device
)

s = time.time()
flash_out = flash_attn_unpadded_func(
q.transpose(1, 2).reshape(B * Q_N_CTX, H, D_HEAD),
k.transpose(1, 2).reshape(B * KV_N_CTX, H, D_HEAD),
v.transpose(1, 2).reshape(B * KV_N_CTX, H, D_HEAD),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=Q_N_CTX,
max_seqlen_k=KV_N_CTX,
dropout_p=dropout_p,
causal=causal,
)
torch.cuda.synchronize()
flash_took = time.time() - s
s = time.time()
ref_out = attention_ref(
q, k, v, sm_scale=1/math.sqrt(D_HEAD), causal=causal, device=device
).transpose(1,2).reshape(B*Q_N_CTX, H, D_HEAD)
torch.cuda.synchronize()
ref_took = time.time() - s

print("allclose", torch.allclose(flash_out, ref_out))
print("max delta", (flash_out - ref_out).abs().max().item())
print("relative max delta", ((flash_out - ref_out).abs().max() / ref_out.abs().mean()).item())
print(stats.spearmanr(flash_out[0,0].float().detach().cpu().numpy(), ref_out[0,0].float().detach().cpu().numpy()))
print(f"ref took: {ref_took:.5f}")
print(f"flash attn took: {flash_took:.5f}")

if test_backward:
dout = torch.randn_like(q).transpose(1, 2).reshape(B * Q_N_CTX, H, D_HEAD)
s = time.time()
ref_out.backward(dout)
torch.cuda.synchronize()
ref_took = time.time() - s
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None

s = time.time()
flash_out.backward(dout)
torch.cuda.synchronize()
flash_took = time.time() - s
flash_dv, v.grad = v.grad.clone(), None
flash_dk, k.grad = k.grad.clone(), None
flash_dq, q.grad = q.grad.clone(), None

for name, ref, flash in zip(
["dv", "dk", "dq"],
[ref_dv, ref_dk, ref_dq],
[flash_dv, flash_dk, flash_dq],
):
print(f"=== evaling {name} ===")
print("allclose", torch.allclose(flash, ref))
print("max delta", (flash - ref).abs().max().item())
print("relative max delta", ((flash - ref).abs().max() / ref.abs().mean()).item())
print(stats.spearmanr(flash[0,0].flatten().float().detach().cpu().numpy(), ref[0,0].flatten().float().detach().cpu().numpy()))
print(f"ref took: {ref_took:.5f}")
print(f"flash attn took: {flash_took:.5f}")