Commit 4c20539

Merge branch 'HazyResearch:main' into main

2 parents: 08adf1f + 6b4a482
274 files changed: +23736 −854 lines

.gitignore (new file, +21)

@@ -0,0 +1,21 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+
+# C extensions
+*.so
+
+# Distribution / packaging
+bin/
+build/
+develop-eggs/
+dist/
+eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg

MANIFEST.in (new file, +9)

@@ -0,0 +1,9 @@
+recursive-include csrc *.cu
+recursive-include csrc *.h
+recursive-include csrc *.cuh
+recursive-include csrc *.cpp
+
+recursive-include flash_attn *.cu
+recursive-include flash_attn *.h
+recursive-include flash_attn *.cuh
+recursive-include flash_attn *.cpp

Makefile (new file, +9)

@@ -0,0 +1,9 @@
+
+clean_dist:
+	rm -rf dist/*
+
+create_dist: clean_dist
+	python setup.py sdist
+
+upload_package: create_dist
+	twine upload dist/*
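
The targets above chain together (upload_package depends on create_dist, which depends on clean_dist), so a release can be driven by a single make invocation. A minimal sketch of how they might be used, assuming twine is installed and PyPI credentials are already configured:

```sh
# Hypothetical release flow using the Makefile targets above.
make create_dist     # clean dist/ and build a fresh source distribution
make upload_package  # rebuild and upload dist/* to PyPI via twine
```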

README.md (+33 −8)

@@ -8,7 +8,27 @@ Paper: https://arxiv.org/abs/2205.14135
 IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
 ![FlashAttention](assets/flashattn_banner.jpg)
 
-#### Triton implementation of FlashAttention
+## Usage
+
+We've been very happy to see FlashAttention being widely adopted in such a short
+time after its release. This [page](https://github.com/HazyResearch/flash-attention/blob/main/usage.md)
+contains a partial list of places where FlashAttention is being used.
+
+## Full model code and training script
+
+We have released the full GPT model
+[implementation](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
+We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
+cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
+compared to the baseline implementation from Huggingface, reaching up to 189
+TFLOPs/sec per A100, equivalent to 60.6\% model FLOPs utilization (we don't need
+any activation checkpointing).
+
+We also include a training
+[script](https://github.com/HazyResearch/flash-attention/tree/main/training) to
+train GPT2 on Openwebtext and GPT3 on The Pile.
+
+## Triton implementation of FlashAttention
 
 Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
 https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
@@ -18,9 +38,14 @@ and experiment with. The notations in the Triton implementation are also closer
 to what's used in our paper.
 
 
-## Alpha release (0.1).
+## Beta release (0.2).
+
+To install (requiring CUDA 11, NVCC, and a Turing or Ampere GPU):
+```sh
+pip install flash-attn
+```
 
-To compile (requiring CUDA 11, NVCC, and an Turing or Ampere GPU):
+Alternatively you can compile from source:
 ```
 python setup.py install
 ```
@@ -38,15 +63,15 @@ FlashAttention currently supports:
 3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100.
 
 Our tentative roadmap:
-1. [Jun 2022] Make package pip-installable.
+1. ~~[Jun 2022] Make package pip-installable~~[Done, thanks to lucidrains].
 2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
 3. [Jun 2022] Refactor to use Cutlass.
 4. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
 5. ~~[Jun 2022] Support bf16~~[Done].
 6. ~~[Jul 2022] Implement cross-attention~~[Done].
 7. ~~[Jul 2022] Support head dimension 128~~[Done].
 8. [Jul 2022] Support SM70 GPUs (V100).
-9. [Aug 2022] Fuse rotary embedding.
+9. ~~[Aug 2022] Fuse rotary embedding~~[Done].
 10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).
 
 ## Speedup and Memory Savings
@@ -148,10 +173,10 @@ and for his thoughtful answers to our questions about CUDA.
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:
 ```
-@article{dao2022flashattention,
-  title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
+@inproceedings{dao2022flashattention,
+  title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
   author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
-  journal={arXiv preprint arXiv:2205.14135},
+  booktitle={Advances in Neural Information Processing Systems},
   year={2022}
 }
 ```
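
For readers skimming this commit, here is a rough sketch of how the pip-installed package is typically called with the packed, unpadded layout that the C++ bindings below expect. This is illustrative only: the import path, function name (flash_attn_unpadded_qkvpacked_func), and keyword names are assumptions about the flash_attn Python interface, not something shown in this diff.

```python
# Hypothetical usage sketch; function name and signature are assumed, not taken from this commit.
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func  # assumed import path

batch, seqlen, nheads, headdim = 4, 512, 16, 64   # head dim must be a multiple of 8, <= 128
# Packed, unpadded layout: (total_tokens, 3, nheads, headdim) with total_tokens = sum of seqlens.
qkv = torch.randn(batch * seqlen, 3, nheads, headdim, dtype=torch.float16, device="cuda")
# cu_seqlens marks sequence boundaries in the packed layout: [0, s_0, s_0 + s_1, ...].
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda")

out = flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, seqlen, dropout_p=0.0, causal=True)
# out has shape (batch * seqlen, nheads, headdim)
```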

assets/gpt2_training_curve.jpg (168 KB, image not shown)
assets/gpt2_training_efficiency.jpg (367 KB, image not shown)
assets/gpt3_training_curve.jpg (183 KB, image not shown)
assets/gpt3_training_efficiency.jpg (382 KB, image not shown)

csrc/flash_attn/fmha_api.cpp (+31 −18)

@@ -176,6 +176,16 @@ void set_params_dgrad(FMHA_dgrad_params &params,
     params.dsoftmax_sum = dsoftmax_sum_d;
 }
 
+void run_fmha_fwd(Launch_params<FMHA_fprop_params> &launch_params) {
+    if (launch_params.params.d <= 32) {
+        run_fmha_fwd_hdim32(launch_params);
+    } else if (launch_params.params.d <= 64) {
+        run_fmha_fwd_hdim64(launch_params);
+    } else if (launch_params.params.d <= 128) {
+        run_fmha_fwd_hdim128(launch_params);
+    }
+}
+
 std::vector<at::Tensor>
 mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
         const at::Tensor &k,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@@ -299,21 +309,29 @@ mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
-    at::PhiloxCudaState rng_engine_inputs;
 
     if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
     }
 
-    run_fmha_fp16_sm80(launch_params);
+    run_fmha_fwd(launch_params);
 
     std::vector<at::Tensor> result = {softmax_lse};
     if (return_softmax) {result.push_back(s);}
     return result;
 }
 
+void run_fmha_bwd(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
+    if (params.d <= 32) {
+        run_fmha_bwd_hdim32(params, stream, configure);
+    } else if (params.d <= 64) {
+        run_fmha_bwd_hdim64(params, stream, configure);
+    } else if (params.d <= 128) {
+        run_fmha_bwd_hdim128(params, stream, configure);
+    }
+}
 
 std::vector<at::Tensor>
 mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
@@ -341,7 +359,7 @@ mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
     TORCH_CHECK(is_sm8x || is_sm75);
-    auto launch = &run_fmha_dgrad_fp16_sm80;
+    auto launch = &run_fmha_bwd;
 
     bool is_dropout = p_dropout > 0.0;
     auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -454,17 +472,13 @@ mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
 
     launch(params, stream, /*configure=*/true);
 
-    at::Tensor dk_accum, dv_accum;
     if (params.num_splits > 1) {
-        // dk_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
-        // dv_accum = torch::zeros({total_k, num_heads, head_size}, opts.dtype(at::kFloat));
-        // params.dk_accum_ptr = dk_accum.data_ptr();
-        // params.dv_accum_ptr = dv_accum.data_ptr();
-        dk.zero_();
-        dv.zero_();
-    } else {
-        // params.dk_accum_ptr = nullptr;
-        // params.dv_accum_ptr = nullptr;
+        if (!dq_tmp.defined()) {
+            dq_tmp = torch::zeros({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
+            params.o_tmp_ptr = dq_tmp.data_ptr();  // o_tmp stores dq_tmp in the backward pass
+        } else {
+            dq_tmp.zero_();
+        }
     }
 
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -481,10 +495,10 @@ mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
 
     launch(params, stream, /*configure=*/false);
 
-    // if (params.num_splits > 1) {
-    //     dk.copy_(dk_accum);
-    //     dv.copy_(dv_accum);
-    // }
+    if (params.num_splits > 1) {
+        dq.copy_(dq_tmp);
+    }
+
     return { dq, dk, dv, softmax_d };
 }
 
@@ -597,7 +611,6 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     int64_t counter_offset = launch_params.elts_per_thread;
-    at::PhiloxCudaState rng_engine_inputs;
 
     if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
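
A note on the mha_bwd change above: when the backward pass is split over the key/value sequence (num_splits > 1), partial dq contributions are accumulated into a float32 buffer (dq_tmp, reusing the o_tmp pointer) and copied back into dq at the end; the Gmem_tile_o::atomic_add added later in this commit does the per-split accumulation with atomicAdd. A conceptual sketch of that pattern, not the repo's API (compute_partial_dq is a hypothetical stand-in for one kernel split):

```python
# Conceptual sketch of split accumulation in fp32; not the actual kernel code.
import torch

def accumulate_dq(num_splits, total_q, num_heads, head_size, compute_partial_dq):
    # Accumulate per-split dq contributions in float32 for numerical accuracy
    # (the kernel does this with atomicAdd into dq_tmp), then cast back at the end.
    dq_tmp = torch.zeros(total_q, num_heads, head_size, dtype=torch.float32, device="cuda")
    for split in range(num_splits):
        dq_tmp += compute_partial_dq(split)  # hypothetical per-split contribution
    return dq_tmp.to(torch.float16)          # mha_bwd then does dq.copy_(dq_tmp)
```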

csrc/flash_attn/src/.DS_Store (−6 KB, binary file not shown)

csrc/flash_attn/src/fmha.h (+8 −3)

@@ -36,7 +36,8 @@
 #include <ATen/cuda/CUDAGeneratorImpl.h>
 #endif
 
-#include <ATen/cuda/CUDAGraphsUtils.cuh>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/UnpackRaw.cuh>
 
 #include <fmha_utils.h>
 
@@ -195,9 +196,13 @@ struct Launch_params{
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-void run_fmha_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params);
+void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params);
+void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params);
+void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params);
 
-void run_fmha_dgrad_fp16_sm80(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
+void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure);
 
 void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);

csrc/flash_attn/src/fmha/gmem_tile.h (+21 −51)

@@ -34,20 +34,6 @@
 
 namespace fmha {
 
-// template <typename half2_t>
-// inline __device__ void atomic_add_CAS(half2_t *address, const half2_t val) {
-//     uint32_t *address_as_ui = (uint32_t *)address;
-//     uint32_t old = *address_as_ui;
-//     uint32_t assumed;
-//     do {
-//         assumed = old;
-//         half2_t sum = __hadd2(val, reinterpret_cast<half2_t(&)>(old));
-//         old = atomicCAS(address_as_ui, assumed, reinterpret_cast<uint32_t(&)>(sum));
-//     } while (assumed != old);
-// }
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<
     // The dimensions of the tile computed by the CTA.
     typename Cta_tile_,
@@ -148,43 +134,6 @@ struct Gmem_tile_qkv {
         }
     }
 
-    template <typename elem_type>
-    inline __device__ void atomic_add(const uint4 (&data)[LDGS]) {
-        int row_ = tidx_ / THREADS_PER_ROW;
-        #pragma unroll
-        for( int ii = 0; ii < LDGS; ++ii ) {
-            using elem2_type = typename std::conditional<std::is_same<elem_type, __half>::value, __half2, __nv_bfloat162>::type;
-            // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
-            elem2_type *ptr_ = reinterpret_cast<elem2_type *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
-            if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
-                #pragma unroll
-                for (int jj = 0; jj < 4; ++jj) {
-                    atomicAdd(ptr_ + jj, reinterpret_cast<const elem2_type(&)[4]>(data[ii])[jj]);
-                    // atomic_add_CAS(ptr_ + jj, reinterpret_cast<const elem2_type(&)[4]>(data[ii])[jj]);
-                }
-            }
-        }
-    }
-
-    // Not being used. This only supports converting from fp16 -> fp32 for now (not bf16 -> fp32).
-    inline __device__ void atomic_add_float(const uint4 (&data)[LDGS]) {
-        static_assert(BYTES_PER_ELEMENT == 4);  // Only support fp32
-        int row_ = tidx_ / THREADS_PER_ROW;
-        #pragma unroll
-        for( int ii = 0; ii < LDGS; ++ii ) {
-            // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
-            float *ptr_ = reinterpret_cast<float *>(ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes);
-            if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
-                #pragma unroll
-                for (int jj = 0; jj < 4; ++jj) {
-                    const float2 data_f = fmha::half2_unpack<__half>(reinterpret_cast<const uint32_t(&)[4]>(data[ii])[jj]);
-                    atomicAdd(ptr_ + jj * 2, data_f.x);
-                    atomicAdd(ptr_ + jj * 2 + 1, data_f.y);
-                }
-            }
-        }
-    }
-
     inline __device__ void move(const int steps = 1) {
         // ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
         ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
@@ -306,6 +255,27 @@ struct Gmem_tile_o {
         }
     }
 
+    // Store data to global memory with atomicAdd.
+    inline __device__ void atomic_add(const uint4 (&src)[STGS_PER_LOOP], int mi) {
+        static_assert(BYTES_PER_ELEMENT == 4);  // Only do atomic add on floats
+        int row_ = tidx_ / THREADS_PER_ROW;
+        #pragma unroll
+        for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
+            int jj = mi * STGS_PER_LOOP + ii;
+            if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
+                break;
+            }
+
+            if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
+                float *ptr_ = reinterpret_cast<float *>(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
+                #pragma unroll
+                for (int jj = 0; jj < 4; ++jj) {
+                    atomicAdd(ptr_ + jj, reinterpret_cast<const float(&)[4]>(src[ii])[jj]);
+                }
+            }
+        }
+    }
+
     // Load data from global memory.
     inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
         static_assert(BYTES_PER_ELEMENT == 4);

(new file, +12; filename not shown in this view)

@@ -0,0 +1,12 @@
+// Copyright (c) 2022, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "fmha_bwd_launch_template.h"
+
+void run_fmha_bwd_hdim128(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
+    FP16_SWITCH(params.is_bf16, ([&] {
+        using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
+        run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
+    }));
+}

(new file, +17; filename not shown in this view)

@@ -0,0 +1,17 @@
+// Copyright (c) 2022, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+
+#include "fmha_bwd_launch_template.h"
+
+void run_fmha_bwd_hdim32(FMHA_dgrad_params &params, cudaStream_t stream, const bool configure) {
+    FP16_SWITCH(params.is_bf16, ([&] {
+        if (params.seqlen_k == 128) {
+            using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
+            run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
+        } else if (params.seqlen_k >= 256) {
+            using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
+            run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
+        }
+    }));
+}
