-
Notifications
You must be signed in to change notification settings - Fork 2
/
filter_bs.py
70 lines (64 loc) · 2.76 KB
/
filter_bs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json
from threading import currentThread
# Module-level mapping; the original note suggests the intended shape is
# {model_name: {...}}.
# NOTE(review): nothing in this file reads or writes it -- candidate for removal.
results = dict()
def _read_tflops(model):
    """Return ``{batch_size: tflops}`` built from ``model['results']['details']``.

    Each detail entry must provide ``'batch_size'`` and ``'tflops'``; they are
    coerced with ``int()`` / ``float()`` (presumably because the producer may
    store them as strings -- verify). A falsy ``details`` (``None``, ``[]``)
    yields an empty dict. If a batch size repeats, the last entry wins.
    """
    details = model['results']['details']
    if not details:
        return {}
    return {int(entry['batch_size']): float(entry['tflops']) for entry in details}


def _has_tflops_drop(tflops, batch_sizes, min_drop=0.05):
    """Return True if tflops falls by at least *min_drop* between neighbours.

    *batch_sizes* is walked in order. A batch size missing from *tflops*
    counts as 0, is never itself reported as a drop, and resets the
    comparison, so the next measured value is compared against 0.
    """
    previous = 0
    for batch_size in batch_sizes:
        current = tflops.get(batch_size, 0)
        if current != 0 and previous - current >= min_drop:
            return True
        previous = current
    return False


def _best_batch_size(tflops, batch_sizes, ratio=0.99):
    """Return the first batch size in *batch_sizes* reaching *ratio* of peak.

    The peak is ``max(tflops.values())``. It covers every measured batch size,
    including ones not listed in *batch_sizes*, so 0 is returned when no
    listed batch size gets within *ratio* of it. *tflops* must be non-empty.
    """
    threshold = max(tflops.values()) * ratio
    for batch_size in batch_sizes:
        if tflops.get(batch_size, 0) >= threshold:
            return batch_size
    return 0


def work(file_path, *, output_path="proper_bs.csv", max_exp=30):
    """Summarise per-model tflops by batch size into a comma-separated report.

    Reads *file_path* as a JSON list of models. Each model needs a ``'name'``
    and ``model['results']['details']``, a list of entries carrying
    ``'batch_size'`` and ``'tflops'``.

    *output_path* is overwritten. It gets one header line, then one line per
    model containing:

    * the model name;
    * only if the model has details:
      * the tflops at batch sizes ``1, 2, 4, ..., 2**(max_exp - 1)``, formatted
        ``%.2f`` (``0.00`` when unmeasured);
      * the first of those batch sizes reaching 99% of the peak tflops, or 0;
      * the word ``special`` when tflops drops by >= 0.05 between two adjacent
        batch sizes.

    Models without details therefore produce a row holding only their name.
    Every field is followed by ", " (``special`` by a single space), which
    keeps the layout identical to the script's historical output.

    Args:
        file_path: path of the input JSON file (read as UTF-8).
        output_path: report path (written as UTF-8); defaults to
            ``proper_bs.csv`` in the current working directory.
        max_exp: number of power-of-two batch sizes to report.

    Raises:
        OSError: if a file cannot be opened.
        json.JSONDecodeError: if the input is not valid JSON.
        KeyError / ValueError: if a model or detail entry lacks a field, or a
            field cannot be converted to ``int`` / ``float``.
    """
    with open(file_path, encoding="utf-8") as fin:
        model_list = json.load(fin)

    batch_sizes = [2 ** exp for exp in range(max_exp)]
    # The trailing ", " here plus the leading space in the header f-string
    # reproduce the original header exactly (double space before the last column).
    header_sizes = "".join(f"{batch_size}, " for batch_size in batch_sizes)

    with open(output_path, "w", encoding="utf-8") as fout:
        fout.write(f"model, {header_sizes} best_tflops_bs\n")
        for model in model_list:
            fout.write(f"{model['name']}, ")
            tflops = _read_tflops(model)
            if tflops:
                fout.write("".join(f"{tflops.get(bs, 0):.2f}, " for bs in batch_sizes))
                fout.write(f"{_best_batch_size(tflops, batch_sizes)}, ")
                if _has_tflops_drop(tflops, batch_sizes):
                    fout.write("special ")
            fout.write("\n")
if __name__ == "__main__":
    import sys

    # Run only when executed as a script (not on import). The input JSON may
    # be given as the first argument; otherwise fall back to the path this
    # script was originally hard-coded to use.
    work(
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/yhao/d/ml_optimizations/tb-output-eval-full.json"
    )