 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

+try:
+    from einops import rearrange
+except ImportError:
+    rearrange = None
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+except ImportError:
+    flash_attn_unpadded_func = None
+
 """ We use the following notation throughout this file:
      h: hidden size
      n: number of attention heads
@@ -305,6 +315,48 @@ def forward(self, query_layer, key_layer,
         return context_layer


+class FlashSelfAttention(torch.nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                       (default: 1/sqrt(d_keys) where d_keys is computed at
+                       runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.0)
+    """
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
+                 device=None, dtype=None):
+        super().__init__()
+        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
+                                                      'e.g., with pip install flash-attn')
+        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
+        self.causal = causal
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+
+    def forward(self, q, k, v):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
+        """
+        assert q.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda
+        batch_size, seqlen = q.shape[0], q.shape[1]
+        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
+        max_s = seqlen
+        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                                  device=q.device)
+        output = flash_attn_unpadded_func(
+            q, k, v, cu_seqlens, cu_seqlens, max_s, max_s,
+            self.dropout_p if self.training else 0.0,
+            softmax_scale=self.softmax_scale, causal=self.causal
+        )
+        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        return output
+
+
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
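For context, here is a minimal usage sketch of the new `FlashSelfAttention` module. It is not part of the patch: it assumes a CUDA device with the flash-attn and einops packages installed, assumes the class is importable from the patched transformer module, and uses illustrative tensor sizes.

```python
# Hypothetical standalone check of FlashSelfAttention; sizes are illustrative.
import torch
from megatron.model.transformer import FlashSelfAttention  # assumed module path

batch, seqlen, heads, head_dim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, heads, head_dim, dtype=torch.float16, device='cuda')
k, v = torch.randn_like(q), torch.randn_like(q)

attn = FlashSelfAttention(causal=True, attention_dropout=0.1)
attn.eval()                      # dropout_p is only applied when self.training is True
with torch.no_grad():
    out = attn(q, k, v)          # internally packed to (batch*seqlen, heads, head_dim)
print(out.shape)                 # torch.Size([2, 1024, 16, 64]), i.e. (B, S, H, D)
```

Because every sequence in the batch has the same length, `cu_seqlens` degenerates to `[0, seqlen, 2*seqlen, ...]`, which is exactly what the `torch.arange` call in `forward` builds.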
@@ -322,6 +374,19 @@ def __init__(self, init_method,
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
         self.params_dtype = args.params_dtype
+        self.sequence_parallel = args.sequence_parallel
+
+        self.use_flash_attn = args.use_flash_attn
+        if self.use_flash_attn:
+            if flash_attn_unpadded_func is None:
+                raise ImportError('FlashAttention is not installed, please install with '
+                                  'pip install flash-attn')
+            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
+                                                          'self-attention for now')
+            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
+                                                                'supports causal mask for now')
+            if rearrange is None:
+                raise ImportError('einops is not installed, please install with pip install einops')

         projection_size = args.kv_channels * args.num_attention_heads

@@ -364,6 +429,11 @@ def __init__(self, init_method,
             self.attn_mask_type)
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'

+        if self.use_flash_attn:
+            self.core_attention_flash = FlashSelfAttention(
+                causal=True, attention_dropout=args.attention_dropout
+            )
+
         # Output.
         self.dense = tensor_parallel.RowParallelLinear(
             projection_size,
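As a sanity check on the `causal=True` construction above, the flash path can be compared against a plain PyTorch softmax-attention reference on small tensors. This is a sketch only, not part of the patch: `reference_causal_attention` is a hypothetical helper, the import path is assumed, and the comparison expects CUDA, fp16 inputs, and loose tolerances.

```python
# Sketch: compare FlashSelfAttention(causal=True) with an explicit masked softmax.
# reference_causal_attention is a hypothetical helper written for this check.
import math
import torch
from megatron.model.transformer import FlashSelfAttention  # assumed module path

def reference_causal_attention(q, k, v):
    # q, k, v: (batch, seqlen, heads, head_dim); computed in fp32 for accuracy.
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum('bshd,bthd->bhst', q.float(), k.float()) * scale
    causal_mask = torch.triu(
        torch.ones(q.shape[1], q.shape[1], dtype=torch.bool, device=q.device),
        diagonal=1)
    scores = scores.masked_fill(causal_mask, float('-inf'))
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum('bhst,bthd->bshd', probs, v.float())

q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device='cuda')
k, v = torch.randn_like(q), torch.randn_like(q)

flash = FlashSelfAttention(causal=True, attention_dropout=0.0)
flash.eval()
out_flash = flash(q, k, v).float()
out_ref = reference_causal_attention(q, k, v)
print(torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-2))  # should print True within fp16 tolerance
```

With `softmax_scale=None`, flash-attn falls back to 1/sqrt(head_dim), which is why the reference uses the same scale.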
@@ -486,12 +556,22 @@ def forward(self, hidden_states, attention_mask,
         # core attention computation
         # ==================================

-        if self.checkpoint_core_attention:
-            context_layer = self._checkpointed_attention_forward(
-                query_layer, key_layer, value_layer, attention_mask)
+        if not self.use_flash_attn:
+            if self.checkpoint_core_attention:
+                context_layer = self._checkpointed_attention_forward(
+                    query_layer, key_layer, value_layer, attention_mask)
+            else:
+                context_layer = self.core_attention(
+                    query_layer, key_layer, value_layer, attention_mask)
         else:
-            context_layer = self.core_attention(
-                query_layer, key_layer, value_layer, attention_mask)
+            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
+                       for x in (query_layer, key_layer, value_layer)]
+            if not self.sequence_parallel:
+                with tensor_parallel.get_cuda_rng_tracker().fork():
+                    context_layer = self.core_attention_flash(q, k, v)
+            else:
+                context_layer = self.core_attention_flash(q, k, v)
+            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

         # =================
         # Output. [sq, b, h]
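The rearranges in the new branch bridge two layouts: Megatron keeps activations as [sq, b, np, hn], `FlashSelfAttention` expects (B, S, H, D), and the final merge restores the [sq, b, h] layout the rest of `forward` expects. A small sketch of that round-trip on dummy tensors (sizes are illustrative):

```python
# Sketch of the layout round-trip around the flash branch; sizes are illustrative.
import torch
from einops import rearrange

sq, b, heads, head_dim = 1024, 2, 16, 64
query_layer = torch.randn(sq, b, heads, head_dim)      # Megatron layout: [sq, b, np, hn]

q = rearrange(query_layer, 's b ... -> b s ...').contiguous()
print(q.shape)                                         # (b, sq, np, hn) -- flash-attn layout

flash_out = torch.randn(b, sq, heads, head_dim)        # stand-in for core_attention_flash output
context_layer = rearrange(flash_out, 'b s h d -> s b (h d)').contiguous()
print(context_layer.shape)                             # (sq, b, np*hn), i.e. [sq, b, h]
```

Note that the CUDA RNG tracker is forked only when sequence parallelism is off, presumably so that attention dropout draws from the tensor-model-parallel RNG state as the existing core attention does; with sequence parallelism the flash module is called directly.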