Skip to content

Commit 3ff9045

Browse files
authoredMar 23, 2022
Add files via upload
1 parent a8311be commit 3ff9045

File tree

1 file changed

+156
-0
lines changed

1 file changed

+156
-0
lines changed
 

‎main_test_scunet_color_gaussian.py

+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import os.path
2+
import logging
3+
import argparse
4+
5+
import numpy as np
6+
from datetime import datetime
7+
from collections import OrderedDict
8+
9+
import torch
10+
11+
from utils import utils_logger
12+
from utils import utils_model
13+
from utils import utils_image as util
14+
15+
16+
'''
17+
% If you have any question, please feel free to contact with me.
18+
% Kai Zhang (e-mail: cskaizhang@gmail.com; github: https://github.com/cszn)
19+
by Kai Zhang (2021/05-2021/11)
20+
'''
21+
22+
23+
def main():
24+
25+
# ----------------------------------------
26+
# Preparation
27+
# ----------------------------------------
28+
parser = argparse.ArgumentParser()
29+
parser.add_argument('--model_name', type=str, default='scunet_color_25', help='scunet_color_15, scunet_color_25, scunet_color_50')
30+
parser.add_argument('--testset_name', type=str, default='kodak24', help='test set, bsd68 | set12')
31+
parser.add_argument('--noise_level_img', type=int, default=25, help='noise level: 15, 25, 50')
32+
parser.add_argument('--x8', type=bool, default=False, help='x8 to boost performance')
33+
parser.add_argument('--show_img', type=bool, default=False, help='show the image')
34+
parser.add_argument('--model_zoo', type=str, default='model_zoo', help='path of model_zoo')
35+
parser.add_argument('--testsets', type=str, default='testsets', help='path of testing folder')
36+
parser.add_argument('--results', type=str, default='results', help='path of results')
37+
parser.add_argument('--need_degradation', type=bool, default=True, help='add noise or not')
38+
39+
args = parser.parse_args()
40+
41+
n_channels = 3 # fixed, 1 for grayscale image, 3 for color image
42+
43+
result_name = args.testset_name + '_' + args.model_name # fixed
44+
border = 0 # shave boader to calculate PSNR and SSIM
45+
model_path = os.path.join(args.model_zoo, args.model_name+'.pth')
46+
47+
# ----------------------------------------
48+
# L_path, E_path, H_path
49+
# ----------------------------------------
50+
L_path = os.path.join(args.testsets, args.testset_name) # L_path, for Low-quality images
51+
H_path = L_path # H_path, for High-quality images
52+
E_path = os.path.join(args.results, result_name) # E_path, for Estimated images
53+
util.mkdir(E_path)
54+
55+
if H_path == L_path:
56+
args.need_degradation = True
57+
logger_name = result_name
58+
utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log'))
59+
logger = logging.getLogger(logger_name)
60+
61+
need_H = True if H_path is not None else False
62+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
63+
64+
# ----------------------------------------
65+
# load model
66+
# ----------------------------------------
67+
from models.network_scunet import SCUNet as net
68+
model = net(in_nc=n_channels,config=[4,4,4,4,4,4,4],dim=64)
69+
70+
model.load_state_dict(torch.load(model_path), strict=True)
71+
model.eval()
72+
for k, v in model.named_parameters():
73+
v.requires_grad = False
74+
model = model.to(device)
75+
logger.info('Model path: {:s}'.format(model_path))
76+
number_parameters = sum(map(lambda x: x.numel(), model.parameters()))
77+
logger.info('Params number: {}'.format(number_parameters))
78+
79+
test_results = OrderedDict()
80+
test_results['psnr'] = []
81+
test_results['ssim'] = []
82+
83+
logger.info('model_name:{}, image sigma:{}'.format(args.model_name, args.noise_level_img))
84+
logger.info(L_path)
85+
L_paths = util.get_image_paths(L_path)
86+
H_paths = util.get_image_paths(H_path) if need_H else None
87+
88+
num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
89+
logger.info('{:>16s} : {:<.4f} [M]'.format('#Params', num_parameters/10**6))
90+
91+
for idx, img in enumerate(L_paths):
92+
93+
# ------------------------------------
94+
# (1) img_L
95+
# ------------------------------------
96+
img_name, ext = os.path.splitext(os.path.basename(img))
97+
# logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext))
98+
img_L = util.imread_uint(img, n_channels=n_channels)
99+
img_L = util.uint2single(img_L)
100+
101+
if args.need_degradation: # degradation process
102+
np.random.seed(seed=0) # for reproducibility
103+
img_L += np.random.normal(0, args.noise_level_img/255., img_L.shape)
104+
105+
util.imshow(util.single2uint(img_L), title='Noisy image with noise level {}'.format(args.noise_level_img)) if args.show_img else None
106+
107+
img_L = util.single2tensor4(img_L)
108+
img_L = img_L.to(device)
109+
110+
# ------------------------------------
111+
# (2) img_E
112+
# ------------------------------------
113+
#img_E = utils_model.test_mode(model, img_L, mode=2, refield=64)
114+
x8 = args.x8
115+
if not x8 and img_L.size(2)//8==0 and img_L.size(3)//8==0:
116+
img_E = model(img_L)
117+
elif not x8 and (img_L.size(2)//8!=0 or img_L.size(3)//8!=0):
118+
img_E = utils_model.test_mode(model, img_L, refield=64, mode=5)
119+
elif x8:
120+
img_E = utils_model.test_mode(model, img_L, mode=3)
121+
122+
#img_E = model(img_L)
123+
124+
img_E = util.tensor2uint(img_E)
125+
126+
if need_H:
127+
128+
# --------------------------------
129+
# (3) img_H
130+
# --------------------------------
131+
img_H = util.imread_uint(H_paths[idx], n_channels=n_channels)
132+
img_H = img_H.squeeze()
133+
134+
# --------------------------------
135+
# PSNR and SSIM
136+
# --------------------------------
137+
psnr = util.calculate_psnr(img_E, img_H, border=border)
138+
ssim = util.calculate_ssim(img_E, img_H, border=border)
139+
test_results['psnr'].append(psnr)
140+
test_results['ssim'].append(ssim)
141+
logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim))
142+
util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if args.show_img else None
143+
144+
# ------------------------------------
145+
# save results
146+
# ------------------------------------
147+
util.imsave(img_E, os.path.join(E_path, img_name+ext))
148+
149+
if need_H:
150+
ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
151+
ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
152+
logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim))
153+
154+
if __name__ == '__main__':
155+
156+
main()

0 commit comments

Comments
 (0)
Please sign in to comment.