Commit 285068c

Merge branch 'tridao-flashattn' into 'main'

Integrate FlashAttention into Megatron-LM. See merge request ADLR/megatron-lm!488.

2 parents: b707199 + c92f10b

File tree (3 files changed: +100 -5 lines):

- README.md
- megatron/arguments.py
- megatron/model/transformer.py

README.md (+12 lines)

````diff
@@ -333,6 +333,18 @@ Theoretical memory savings vary depending on the combination of the model's para
 | bf16 param, fp32 grads | 18 | 6 + 12/d |
 | fp32 param, fp32 grads | 16 | 8 + 8/d |
 
+## FlashAttention
+
+Usage: `--use-flash-attn`. Supports attention head dimensions of at most 128.
+
+[FlashAttention](https://github.com/HazyResearch/flash-attention) is a fast and
+memory-efficient algorithm to compute exact attention. It speeds up model
+training and reduces memory requirements.
+
+To install FlashAttention:
+```sh
+pip install flash-attn
+```
 
 ## GPT-3 Example
````

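As a quick sanity check on those two requirements (the package is installed and the per-head dimension does not exceed 128), something like the following standalone snippet can be run before adding the flag. It is not part of the commit, and the `hidden_size`/`num_heads` values are placeholders:

```python
# Hypothetical pre-flight check mirroring the README's requirements; the
# hidden_size / num_heads values below are placeholders for your model config.
import importlib.util

hidden_size, num_heads = 1024, 16      # example GPT config
head_dim = hidden_size // num_heads    # 64 here; the flash path needs <= 128

assert importlib.util.find_spec("flash_attn") is not None, \
    "flash-attn is not installed; run `pip install flash-attn`"
assert head_dim <= 128, \
    f"head dimension {head_dim} exceeds the 128 supported by --use-flash-attn"
print(f"OK: head_dim={head_dim}, flash-attn available; safe to pass --use-flash-attn")
```
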
megatron/arguments.py (+3 lines)

```diff
@@ -652,6 +652,9 @@ def _add_training_args(parser):
     group.add_argument('--no-bias-dropout-fusion', action='store_false',
                        help='Disable bias and dropout fusion.',
                        dest='bias_dropout_fusion')
+    group.add_argument('--use-flash-attn', action='store_true',
+                       help='use FlashAttention implementation of attention. '
+                       'https://arxiv.org/abs/2205.14135')
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
```

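For illustration only (not from the commit), a bare `argparse` parser shows how the new `store_true` flag surfaces as `args.use_flash_attn`, defaulting to `False` unless `--use-flash-attn` is passed:

```python
# Standalone sketch of the new flag's behaviour with a minimal argparse parser.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('training')
group.add_argument('--use-flash-attn', action='store_true',
                   help='use FlashAttention implementation of attention. '
                        'https://arxiv.org/abs/2205.14135')

print(parser.parse_args([]).use_flash_attn)                    # False (default)
print(parser.parse_args(['--use-flash-attn']).use_flash_attn)  # True
```
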
megatron/model/transformer.py (+85 -5 lines)

```diff
@@ -15,6 +15,16 @@
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
 
+try:
+    from einops import rearrange
+except ImportError:
+    rearrange = None
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+except ImportError:
+    flash_attn_unpadded_func = None
+
 """ We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
```

```diff
@@ -305,6 +315,48 @@ def forward(self, query_layer, key_layer,
         return context_layer
 
 
+class FlashSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
+                                                      'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        batch_size, seqlen = q.shape[0], q.shape[1]
+        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+        max_s = seqlen
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                  device=q.device)
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
+            self.dropout_p if self.training else 0.0,
+            softmax_scale=self.softmax_scale, causal=self.causal
+        )
+        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        return output
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
```
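
The wrapper's forward pass is only a few calls into flash-attn and einops. The sketch below is not part of the commit; it assumes flash-attn is installed and a CUDA device is available, uses placeholder shapes, and replays the same steps directly to show that for fixed-length batches the cumulative-sequence-length tensor is just an arithmetic range:

```python
# Sketch of what FlashSelfAttention.forward does, calling the flash-attn and
# einops APIs directly (assumes flash-attn is installed and CUDA is available).
import torch
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_func

batch, seqlen, heads, head_dim = 2, 4, 8, 64   # placeholder (B, S, H, D)
q = torch.randn(batch, seqlen, heads, head_dim, device='cuda', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Flatten (B, S, H, D) -> (B*S, H, D) and describe batch boundaries with a
# cumulative-sequence-length tensor; for fixed-length batches it is a range.
q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in (q, k, v)]
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device=q.device)
print(cu_seqlens)                              # [0, 4, 8]

out = flash_attn_unpadded_func(q, k, v, cu_seqlens, cu_seqlens, seqlen, seqlen,
                               0.0, softmax_scale=None, causal=True)
print(rearrange(out, '(b s) ... -> b s ...', b=batch).shape)   # (2, 4, 8, 64)
```
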
```diff
@@ -322,6 +374,19 @@ def __init__(self, init_method,
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
+        self.sequence_parallel = args.sequence_parallel
+
+        self.use_flash_attn = args.use_flash_attn
+        if self.use_flash_attn:
+            if flash_attn_unpadded_func is None:
+                raise ImportError('FlashAttention is not installed, please install with '
+                                  'pip install flash-attn')
+            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
+                                                          'self-attention for now')
+            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
+                                                                'supports causal mask for now')
+            if rearrange is None:
+                raise ImportError('einops is not installed, please install with pip install einops')
 
 projection_size = args.kv_channels * args.num_attention_heads
```

```diff
@@ -364,6 +429,11 @@ def __init__(self, init_method,
                                             self.attn_mask_type)
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'
 
+        if self.use_flash_attn:
+            self.core_attention_flash = FlashSelfAttention(
+                causal=True, attention_dropout=args.attention_dropout
+            )
+
         # Output.
         self.dense = tensor_parallel.RowParallelLinear(
             projection_size,
```

```diff
@@ -486,12 +556,22 @@ def forward(self, hidden_states, attention_mask,
         # core attention computation
         # ==================================
 
-        if self.checkpoint_core_attention:
-            context_layer = self._checkpointed_attention_forward(
-                query_layer, key_layer, value_layer, attention_mask)
+        if not self.use_flash_attn:
+            if self.checkpoint_core_attention:
+                context_layer = self._checkpointed_attention_forward(
+                    query_layer, key_layer, value_layer, attention_mask)
+            else:
+                context_layer = self.core_attention(
+                    query_layer, key_layer, value_layer, attention_mask)
         else:
-            context_layer = self.core_attention(
-                query_layer, key_layer, value_layer, attention_mask)
+            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
+                       for x in (query_layer, key_layer, value_layer)]
+            if not self.sequence_parallel:
+                with tensor_parallel.get_cuda_rng_tracker().fork():
+                    context_layer = self.core_attention_flash(q, k, v)
+            else:
+                context_layer = self.core_attention_flash(q, k, v)
+            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
 
         # =================
         # Output. [sq, b, h]
```

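Megatron keeps attention tensors sequence-first ([sq, b, np, hn]) while the FlashAttention kernel expects batch-first, which is what the two `rearrange` calls in this hunk handle, along with merging the head and head-dimension axes on the way back. A CPU-only shape walk-through (illustrative, not the commit's code):

```python
# Shape walk-through of the layout conversions around the flash path.
import torch
from einops import rearrange

s, b, h, d = 4, 2, 8, 64                   # seq, batch, heads, head_dim (placeholders)
query_layer = torch.randn(s, b, h, d)      # Megatron layout: [sq, b, np, hn]

q = rearrange(query_layer, 's b ... -> b s ...').contiguous()
print(q.shape)                             # torch.Size([2, 4, 8, 64]) -- batch-first

context = torch.randn(b, s, h, d)          # what core_attention_flash returns
out = rearrange(context, 'b s h d -> s b (h d)').contiguous()
print(out.shape)                           # torch.Size([4, 2, 512]) == [sq, b, h*d]
```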