Skip to content

Ttruong/v2.6.3 alibi as bh bias #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
aebc5a7
convert alibi bias to simply bh bias (same as alibi without dependenc…
timt51 Sep 6, 2024
cfca2e8
publish fewer versions
timt51 Sep 6, 2024
14b704d
fix diagonal bias
timt51 Sep 7, 2024
5996687
try just setting bias to 0 when Col_idx_only
timt51 Sep 7, 2024
3b267ff
try disabling Col_idx_only if alibi
timt51 Sep 7, 2024
b9ec215
fix bh bias diagonal handling
timt51 Sep 8, 2024
42a895e
compile for pytorch 2.5.1
timt51 Feb 5, 2025
f124d98
fix TORCH_CUDA_VERSION env var for pytorch 2.5
timt51 Feb 5, 2025
4f8b153
adjust alibi for non causal too
timt51 Feb 7, 2025
d2dd0fd
compile for torch 2.4.1 too
timt51 Feb 7, 2025
980f524
make a version that always adds the alibi slope...
timt51 Feb 7, 2025
c42bf6b
...and only compile for py310 and torch24
timt51 Feb 7, 2025
a8d7fcc
redo causal and noncausal alibi for diagonal
timt51 Feb 7, 2025
599e322
bidirectional diagonal handling - take into account seqlens
timt51 Feb 7, 2025
fdc50f3
fix diagonal formula
timt51 Feb 7, 2025
a645edd
diagonal noncausal try accounting for max_seqlens too
timt51 Feb 8, 2025
f5ce6ee
follow the original expression more closely
timt51 Feb 8, 2025
efbbaf4
also modify alibi.h
timt51 Feb 8, 2025
88e1cc9
Update mask.h
timt51 Feb 8, 2025
4bda2db
alibi.h use the non max seqlen formula which seems more correct actua…
timt51 Feb 8, 2025
5b937eb
Revert "alibi.h use the non max seqlen formula which seems more corre…
timt51 Mar 24, 2025
cb7632e
Revert "Update mask.h"
timt51 Mar 24, 2025
7104b34
publish for pytorch 2.5.1 and 2.6.0 too
timt51 Mar 24, 2025
4ebb8be
export minv and maxv for pytorch 2.6
timt51 Mar 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 37 additions & 37 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ jobs:
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0']
cuda-version: ['11.8.0', '12.3.2']
python-version: ['3.10']
torch-version: ['2.4.1', '2.5.1', '2.6.0']
cuda-version: ['11.8.0']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ['FALSE', 'TRUE']
cxx11_abi: ['FALSE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# Pytorch < 2.2 does not support Python 3.12
Expand Down Expand Up @@ -118,8 +118,8 @@ jobs:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121, '2.5': 124, '2.6': 126}[env['MATRIX_TORCH_VERSION']]; \
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
Expand Down Expand Up @@ -179,34 +179,34 @@ jobs:
asset_name: ${{env.wheel_name}}
asset_content_type: application/*

publish_package:
name: Publish package
needs: [build_wheels]

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3

- uses: actions/setup-python@v4
with:
python-version: '3.10'

- name: Install dependencies
run: |
pip install ninja packaging setuptools wheel twine
# We don't want to download anything CUDA-related here
pip install torch --index-url https://download.pytorch.org/whl/cpu

- name: Build core package
env:
FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
run: |
python setup.py sdist --dist-dir=dist

- name: Deploy
env:
TWINE_USERNAME: "__token__"
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload dist/*
# publish_package:
# name: Publish package
# needs: [build_wheels]

# runs-on: ubuntu-latest

# steps:
# - uses: actions/checkout@v3

# - uses: actions/setup-python@v4
# with:
# python-version: '3.10'

# - name: Install dependencies
# run: |
# pip install ninja packaging setuptools wheel twine
# # We don't want to download anything CUDA-related here
# pip install torch --index-url https://download.pytorch.org/whl/cpu

# - name: Build core package
# env:
# FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
# run: |
# python setup.py sdist --dist-dir=dist

# - name: Deploy
# env:
# TWINE_USERNAME: "__token__"
# TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
# run: |
# python -m twine upload dist/*
20 changes: 13 additions & 7 deletions csrc/flash_attn/src/alibi.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,20 @@ struct Alibi {
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope);
}
}
}
}
Expand All @@ -61,7 +67,7 @@ struct Alibi {
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
tensor(make_coord(i, mi), make_coord(j, nj)) += (((row_idx + max_seqlen_k - max_seqlen_q - col_idx) == 0) ? 0 : alibi_slope);
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions csrc/flash_attn/src/mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ struct Mask {
// Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
// Do we need both row and column indices, or just column indices?
static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
static constexpr bool Col_idx_only = !Has_alibi && !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Col_idx_only) {
Expand Down Expand Up @@ -178,9 +178,10 @@ struct Mask {
const int col_idx = col_idx_base + j;
if constexpr (Has_alibi) {
if constexpr (Is_causal) {
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope);

} else {
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
tensor(make_coord(i, mi), make_coord(j, nj)) += (((row_idx + max_seqlen_k - max_seqlen_q - col_idx) == 0) ? 0 : alibi_slope);

}
}
Expand Down
Loading