# benchmark.py
import os
import time
from pathlib import Path
import torch
# from my_flash_attn.util import set_env_vars, load_and_compile_sources
# We need to import the CUDA kernels after importing torch
import my_flash_attn_cuda
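# A small, optional guard (not in the original script): every benchmark below
# allocates tensors on "cuda", so fail fast with a clear message if no GPU is
# visible to PyTorch.
if not torch.cuda.is_available():
    raise SystemExit("These benchmarks require a CUDA-capable GPU")
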
def benchmark(f, *args):
    """Time f(*args) over iters calls and report the mean per-call latency in µs."""
    iters = 1000
    torch.cuda.synchronize()
    t0 = time.perf_counter_ns()
    for _ in range(iters):
        out = f(*args)
    torch.cuda.synchronize()
    t1 = time.perf_counter_ns()
    print(f"Time: {(t1 - t0) / iters / 1e3:.2f} µs")


# Disabled profiler pass, kept as a string so it is easy to re-enable:
"""
print("\nRunning PyTorch profiler...")
with torch.profiler.profile() as prof:
    for i in range(iters):
        out = my_flash_attn_cuda.matmul(m1, m2)
    torch.cuda.synchronize()
print(prof.key_averages().table())
"""
# Setup
# set_env_vars()
# my_flash_attn_cuda = load_and_compile_sources(Path("csrc"), verbose=False)
# n = 5000
# n = 4096
n = 1024
# n = 256
# n = 32
# # Benchmark matmul
# print(f"\nBenchmarking matmul ({n}x{n})...")
# m1 = torch.randn(n, n, device="cuda")
# m2 = torch.randn(n, n, device="cuda")
# benchmark(my_flash_attn_cuda.my_matmul, m1, m2)
# Benchmark matmul (cuBLAS)
print(f"\nBenchmarking matmul cuBLAS ({n}x{n})...")
m1 = torch.randn(n, n, device="cuda")
m2 = torch.randn(n, n, device="cuda")
benchmark(my_flash_attn_cuda.my_matmul_cublas, m1, m2)
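
# Optional sketch (not in the original script): torch.matmul as a reference
# baseline, plus a one-off correctness spot-check. This assumes my_matmul_cublas
# returns a CUDA tensor comparable to torch.matmul(m1, m2); loosen atol if the
# kernel accumulates in a different precision.
print(f"\nBenchmarking torch.matmul reference ({n}x{n})...")
benchmark(torch.matmul, m1, m2)
if not torch.allclose(my_flash_attn_cuda.my_matmul_cublas(m1, m2), torch.matmul(m1, m2), atol=1e-3):
    print("Warning: my_matmul_cublas output does not match torch.matmul")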
x = torch.randn(n, n, device="cuda")
# Benchmark the softmax kernels (variants 1-5); a RuntimeError from one kernel
# is printed and the sweep continues with the next.
for kernel in range(1, 6):
    print(f"\nBenchmarking softmax kernel {kernel} ({n}x{n})...")
    try:
        benchmark(my_flash_attn_cuda.my_softmax, x, kernel)
    except RuntimeError as e:
        print(e)
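
# Optional sketch (not in the original script): PyTorch's built-in softmax as a
# reference point for the custom kernels above. dim=1 (row-wise) is an
# assumption; the reduction axis used by my_softmax is not visible from this script.
print(f"\nBenchmarking torch.softmax reference ({n}x{n})...")
benchmark(torch.softmax, x, 1)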