-
Notifications
You must be signed in to change notification settings - Fork 49
/
Copy pathrun.py
99 lines (77 loc) · 2.97 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import random
from typing import List
import click
import numpy as np
# seeds
import torch
from indad.data import MVTEC_CLASSES, MVTecDataset
from indad.models import SPADE, PaDiM, PatchCore
from indad.utils import print_and_export_results
import warnings # for some torch warnings regarding depreciation
# Fix the seed of every RNG the pipeline may draw from (Python, NumPy, torch)
# so repeated runs are reproducible.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Silence all warnings globally; the original intent was to hide noisy
# torch deprecation notices.
warnings.filterwarnings("ignore")

# Class names available in the MVTec dataset (a dict view over MVTEC_CLASSES).
ALL_CLASSES = MVTEC_CLASSES.keys()
# Method names accepted on the command line (compared after lower-casing).
ALLOWED_METHODS = ["spade", "padim", "patchcore"]
def _build_model(method: str, backbone: str):
    """Return a fresh, untrained model for *method* built on *backbone*.

    Hyper-parameters are fixed here: SPADE(k=50), PaDiM(d_reduced=350),
    PatchCore(f_coreset=0.10).

    Raises:
        ValueError: if *method* is not "spade", "padim" or "patchcore".
    """
    if method == "spade":
        return SPADE(k=50, backbone_name=backbone)
    if method == "padim":
        return PaDiM(d_reduced=350, backbone_name=backbone)
    if method == "patchcore":
        return PatchCore(f_coreset=0.10, backbone_name=backbone)
    raise ValueError(
        f"Unknown method {method!r}; select from ['spade', 'padim', 'patchcore']."
    )


def run_model(method: str, classes: List[str], backbone: str):
    """Train and evaluate one model per MVTec class and aggregate the scores.

    A new model is built for every class, fitted on that class's training
    loader and evaluated on its test loader.

    Args:
        method: one of "spade", "padim", "patchcore" (lower-case).
        classes: class names to run on. Any iterable of str is accepted;
            the CLI passes ``MVTEC_CLASSES.keys()`` when running all classes.
        backbone: backbone name forwarded to the model constructor.

    Returns:
        dict with ``"per_class_results"`` (class -> [image_rocauc,
        pixel_rocauc]), the two averages over classes, and
        ``"model parameters"`` from the last model that was run.

    Raises:
        ValueError: if *classes* is empty or *method* is unknown.
    """
    # Materialise once: callers may pass a dict view, and the averages below
    # would otherwise divide by zero on an empty selection.
    classes = list(classes)
    if not classes:
        raise ValueError("run_model needs at least one class to evaluate.")

    results = {}
    for class_name in classes:
        # Built before any dataset I/O, so an unknown method fails fast.
        model = _build_model(method, backbone)

        print(f"\n█│ Running {method} on {class_name} dataset.")
        print(f" ╰{'─'*(len(method)+len(class_name)+23)}\n")
        train_ds, test_ds = MVTecDataset(class_name).get_dataloaders()

        print(" Training ...")
        model.fit(train_ds)
        print(" Testing ...")
        image_rocauc, pixel_rocauc = model.evaluate(test_ds)

        print(f"\n ╭{'─'*(len(class_name)+15)}┬{'─'*20}┬{'─'*20}╮")
        print(
            f" │ Test results {class_name} │ image_rocauc: {image_rocauc:.2f} │ pixel_rocauc: {pixel_rocauc:.2f} │"
        )
        print(f" ╰{'─'*(len(class_name)+15)}┴{'─'*20}┴{'─'*20}╯")
        results[class_name] = [float(image_rocauc), float(pixel_rocauc)]

    image_scores = [scores[0] for scores in results.values()]
    pixel_scores = [scores[1] for scores in results.values()]
    total_results = {
        "per_class_results": results,
        "average image rocauc": sum(image_scores) / len(image_scores),
        "average pixel rocauc": sum(pixel_scores) / len(pixel_scores),
        # Taken from the last class's model; _build_model uses identical
        # hyper-parameters for every class.
        "model parameters": model.get_parameters(),
    }
    return total_results
@click.command()
@click.argument("method")
@click.option(
    "--dataset", default="all", help="Dataset name, defaults to all datasets."
)
@click.option(
    "--backbone", default="wide_resnet50_2", help="The TIMM compatible backbone."
)
def cli_interface(method: str, dataset: str, backbone: str):
    """Run METHOD (spade, padim or patchcore) on MVTec classes and export the results."""
    method = method.lower()
    # Report bad input through click's usage error instead of `assert`, which
    # is stripped under `python -O` and would otherwise print a raw traceback.
    if method not in ALLOWED_METHODS:
        raise click.BadParameter(f"Select from {ALLOWED_METHODS}.", param_hint="METHOD")

    # "all" expands to every MVTec class; otherwise run the single named class.
    classes = ALL_CLASSES if dataset == "all" else [dataset]
    total_results = run_model(method, classes, backbone)
    print_and_export_results(total_results, method)


if __name__ == "__main__":
    cli_interface()