Enable flash_v3 backward (#2445)
Summary: Pull Request resolved: #2445

Reviewed By: xuzhao9

Differential Revision: D61924864

Pulled By: bertmaher

fbshipit-source-id: 760036820c1196a921eaff4d99bf8647e25264ee
bertmaher authored and facebook-github-bot committed Aug 28, 2024
1 parent c0409aa commit e6251f1
Showing 1 changed file with 2 additions and 4 deletions: torchbenchmark/operators/flash_attention/operator.py
@@ -73,7 +73,7 @@
 try:
     torch_lib_path = os.path.join(os.path.dirname(__file__), "lib")
     with add_ld_library_path(torch_lib_path):
-        import flashattn_hopper_cuda
+        from flash_attn_interface import flash_attn_func as flash_attn_v3
 except (ImportError, IOError, AttributeError):
     HAS_FLASH_V3 = False
     pass
@@ -223,9 +223,7 @@ def flash_v3(
         q = q.transpose(1, 2).contiguous()
         k = k.transpose(1, 2).contiguous()
         v = v.transpose(1, 2).contiguous()
-        fn = lambda: flashattn_hopper_cuda.fwd(
-            q, k, v, None, self.sm_scale, self.causal
-        )
+        fn = lambda: flash_attn_v3(q, k, v, self.sm_scale, self.causal)
         return fn

     @register_benchmark()
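
Note on the change: routing the call through the flash_attn_interface Python wrapper means the benchmark lambda goes through an autograd-aware function rather than the raw flashattn_hopper_cuda.fwd binding, which only ran the forward kernel; that is what lets the flash_v3 backward benchmark work. Below is a minimal sketch of exercising the new call with gradients. It assumes a Hopper GPU with the FlashAttention-3 build installed and the positional (q, k, v, softmax_scale, causal) signature used in the diff; the run_fwd_bwd helper, shapes, and dtype are illustrative and not part of the benchmark.

import torch
from flash_attn_interface import flash_attn_func as flash_attn_v3  # FlashAttention-3 Python wrapper

def run_fwd_bwd(batch=4, seqlen=2048, heads=16, head_dim=128, sm_scale=None, causal=True):
    # FA3 expects (batch, seqlen, heads, head_dim) half-precision tensors on a Hopper GPU;
    # the operator's transpose(1, 2) calls above produce exactly this layout.
    shape = (batch, seqlen, heads, head_dim)
    q, k, v = (
        torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True)
        for _ in range(3)
    )
    out = flash_attn_v3(q, k, v, sm_scale, causal)
    # Some FA3 builds return (out, softmax_lse); keep only the attention output.
    if isinstance(out, tuple):
        out = out[0]
    # The wrapper registers an autograd function, so the backward pass now runs end to end.
    out.sum().backward()
    return q.grad, k.grad, v.grad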
