-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathtest.py
100 lines (74 loc) · 3.92 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import gc
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from metrics import pt_psnr, calculate_ssim, calculate_psnr
from pytorch_msssim import ssim
from utils import save_rgb
def _to_hwc_unit_array(image):
    """Return ``image`` as a float32 NumPy array clipped to ``[0, 1]``.

    The tensor axes are reordered with ``permute(1, 2, 0)`` before conversion
    (assumes a channel-first ``(C, H, W)`` image -- TODO confirm against the
    dataset classes).
    """
    return np.clip(image.permute(1, 2, 0).cpu().detach().numpy(), 0, 1).astype(np.float32)


def test_model(model, language_model, lm_head, testsets, device, promptify, savepath="results/"):
    """Evaluate ``model`` on every dataset in ``testsets`` and print the metrics.

    For each dataset the function prints the mean PSNR of the degraded input
    against the reference ("base"), and the mean PSNR / SSIM of the restored
    output against the reference. For the deraining datasets listed in
    ``derain_datasets`` it additionally prints PSNR / SSIM computed with
    ``test_y_channel=True`` on uint8 images.

    Args:
        model: Restoration network, called as ``model(y, text_embd)``. It is
            switched to eval mode.
        language_model: Callable that receives the list of prompts and returns
            an embedding exposing ``.to(device)``. Switched to eval mode when
            truthy. It is required as soon as there is at least one batch to
            evaluate.
        lm_head: Callable applied to the language-model embedding; must return
            a pair whose first element is passed to ``model`` (the second
            element is ignored here).
        testsets: Iterable of datasets exposing ``name``, ``degradation`` and
            ``deg_class``. Each item is indexed as ``batch[0]`` (HQ image),
            ``batch[1]`` (LQ image) and ``batch[2]`` (filename).
        device: Device the image tensors and embeddings are moved to.
        promptify: Callable mapping ``testset.degradation`` to a prompt.
        savepath: Root directory for restored images. When truthy, images are
            written with ``save_rgb`` to ``<savepath>/<testset.name>/``;
            missing directories are created. Pass a falsy value to skip saving.

    Returns:
        None. All results are reported through ``print``.

    Raises:
        ValueError: If a batch has to be processed but ``language_model`` is
            falsy, because no text embedding can be produced for ``model``.
    """
    model.eval()
    if language_model:
        language_model.eval()
        lm_head.eval()

    derain_datasets = ('Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800')

    with torch.no_grad():
        for testset in testsets:
            testset_name = testset.name

            if savepath:
                dt_results_path = os.path.join(savepath, testset_name)
                # makedirs also creates a missing ``savepath`` and tolerates an
                # already existing directory (no exists()/mkdir() race).
                os.makedirs(dt_results_path, exist_ok=True)

            print (">>> Eval on", testset.name, testset.degradation, testset.deg_class)

            test_dataloader = DataLoader(testset, batch_size=1, num_workers=4, drop_last=True, shuffle=False)
            psnr_dataset = []
            ssim_dataset = []
            psnr_noisy = []
            psnr_y_dataset = []
            ssim_y_dataset = []
            use_y_channel = testset_name in derain_datasets

            for idx, batch in enumerate(test_dataloader):
                x = batch[0].to(device)  # HQ image
                y = batch[1].to(device)  # LQ image
                f = batch[2][0]          # filename
                t = [promptify(testset.degradation) for _ in range(x.shape[0])]

                if not language_model:
                    # Without a language model there is no text embedding to
                    # condition on; fail with a clear message instead of an
                    # unbound-variable error further down.
                    raise ValueError(
                        "test_model needs a language_model (and lm_head) to "
                        "build the text embedding passed to the model"
                    )

                if idx < 5:
                    # print the input prompt for debugging
                    print("\tInput prompt:", t)

                lm_embd = language_model(t)
                lm_embd = lm_embd.to(device)
                text_embd, _ = lm_head(lm_embd)

                x_hat = model(y, text_embd)

                psnr_restore = torch.mean(pt_psnr(x, x_hat))
                psnr_dataset.append(psnr_restore.item())

                ssim_restore = ssim(x, x_hat, data_range=1., size_average=True)
                ssim_dataset.append(ssim_restore.item())

                psnr_base = torch.mean(pt_psnr(x, y))
                psnr_noisy.append(psnr_base.item())

                # The clipped HWC copy of the restored image feeds both the
                # Y-channel metrics and the saved file, so build it only once.
                if use_y_channel or savepath:
                    restored_img = _to_hwc_unit_array(x_hat[0])

                if use_y_channel:
                    # NOTE: astype(np.uint8) truncates instead of rounding;
                    # kept unchanged so the metric values stay the same.
                    _x_hat = (restored_img * 255).astype(np.uint8)
                    _x = (_to_hwc_unit_array(x[0]) * 255).astype(np.uint8)

                    psnr_y = calculate_psnr(_x, _x_hat, crop_border=0, input_order='HWC', test_y_channel=True)
                    ssim_y = calculate_ssim(_x, _x_hat, crop_border=0, input_order='HWC', test_y_channel=True)
                    psnr_y_dataset.append(psnr_y)
                    ssim_y_dataset.append(ssim_y)

                ## SAVE RESULTS
                if savepath:
                    img_name = os.path.basename(f)
                    save_rgb (restored_img, os.path.join(dt_results_path, img_name))

            print(f"{testset_name}_base", np.mean(psnr_noisy), "Total images:", len(psnr_dataset))
            print(f"{testset_name}_psnr", np.mean(psnr_dataset))
            print(f"{testset_name}_ssim", np.mean(ssim_dataset))
            if use_y_channel:
                print(f"{testset_name}_psnr-Y", np.mean(psnr_y_dataset), len(psnr_y_dataset))
                print(f"{testset_name}_ssim-Y", np.mean(ssim_y_dataset))

            print (); print (25 * "***")

            # Release the loader (and its worker processes) before the next dataset.
            del test_dataloader, psnr_dataset, psnr_noisy
            gc.collect()


# The name starts with "test_", so pytest would try to collect this evaluation
# routine as a test case (and fail on its parameters) wherever it is imported
# into a test module. Mark it explicitly as "not a test".
test_model.__test__ = False
# END OF FUNCTION