|
| 1 | +import os.path |
| 2 | +import logging |
| 3 | +import argparse |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +from datetime import datetime |
| 7 | +from collections import OrderedDict |
| 8 | + |
| 9 | +import torch |
| 10 | + |
| 11 | +from utils import utils_logger |
| 12 | +from utils import utils_model |
| 13 | +from utils import utils_image as util |
| 14 | + |
| 15 | + |
| 16 | +''' |
| 17 | +% If you have any question, please feel free to contact with me. |
| 18 | +% Kai Zhang (e-mail: cskaizhang@gmail.com; github: https://github.com/cszn) |
| 19 | +by Kai Zhang (2021/05-2021/11) |
| 20 | +''' |
| 21 | + |
| 22 | + |
| 23 | +def main(): |
| 24 | + |
| 25 | + # ---------------------------------------- |
| 26 | + # Preparation |
| 27 | + # ---------------------------------------- |
| 28 | + parser = argparse.ArgumentParser() |
| 29 | + parser.add_argument('--model_name', type=str, default='scunet_color_25', help='scunet_color_15, scunet_color_25, scunet_color_50') |
| 30 | + parser.add_argument('--testset_name', type=str, default='kodak24', help='test set, bsd68 | set12') |
| 31 | + parser.add_argument('--noise_level_img', type=int, default=25, help='noise level: 15, 25, 50') |
| 32 | + parser.add_argument('--x8', type=bool, default=False, help='x8 to boost performance') |
| 33 | + parser.add_argument('--show_img', type=bool, default=False, help='show the image') |
| 34 | + parser.add_argument('--model_zoo', type=str, default='model_zoo', help='path of model_zoo') |
| 35 | + parser.add_argument('--testsets', type=str, default='testsets', help='path of testing folder') |
| 36 | + parser.add_argument('--results', type=str, default='results', help='path of results') |
| 37 | + parser.add_argument('--need_degradation', type=bool, default=True, help='add noise or not') |
| 38 | + |
| 39 | + args = parser.parse_args() |
| 40 | + |
| 41 | + n_channels = 3 # fixed, 1 for grayscale image, 3 for color image |
| 42 | + |
| 43 | + result_name = args.testset_name + '_' + args.model_name # fixed |
| 44 | + border = 0 # shave boader to calculate PSNR and SSIM |
| 45 | + model_path = os.path.join(args.model_zoo, args.model_name+'.pth') |
| 46 | + |
| 47 | + # ---------------------------------------- |
| 48 | + # L_path, E_path, H_path |
| 49 | + # ---------------------------------------- |
| 50 | + L_path = os.path.join(args.testsets, args.testset_name) # L_path, for Low-quality images |
| 51 | + H_path = L_path # H_path, for High-quality images |
| 52 | + E_path = os.path.join(args.results, result_name) # E_path, for Estimated images |
| 53 | + util.mkdir(E_path) |
| 54 | + |
| 55 | + if H_path == L_path: |
| 56 | + args.need_degradation = True |
| 57 | + logger_name = result_name |
| 58 | + utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) |
| 59 | + logger = logging.getLogger(logger_name) |
| 60 | + |
| 61 | + need_H = True if H_path is not None else False |
| 62 | + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| 63 | + |
| 64 | + # ---------------------------------------- |
| 65 | + # load model |
| 66 | + # ---------------------------------------- |
| 67 | + from models.network_scunet import SCUNet as net |
| 68 | + model = net(in_nc=n_channels,config=[4,4,4,4,4,4,4],dim=64) |
| 69 | + |
| 70 | + model.load_state_dict(torch.load(model_path), strict=True) |
| 71 | + model.eval() |
| 72 | + for k, v in model.named_parameters(): |
| 73 | + v.requires_grad = False |
| 74 | + model = model.to(device) |
| 75 | + logger.info('Model path: {:s}'.format(model_path)) |
| 76 | + number_parameters = sum(map(lambda x: x.numel(), model.parameters())) |
| 77 | + logger.info('Params number: {}'.format(number_parameters)) |
| 78 | + |
| 79 | + test_results = OrderedDict() |
| 80 | + test_results['psnr'] = [] |
| 81 | + test_results['ssim'] = [] |
| 82 | + |
| 83 | + logger.info('model_name:{}, image sigma:{}'.format(args.model_name, args.noise_level_img)) |
| 84 | + logger.info(L_path) |
| 85 | + L_paths = util.get_image_paths(L_path) |
| 86 | + H_paths = util.get_image_paths(H_path) if need_H else None |
| 87 | + |
| 88 | + num_parameters = sum(map(lambda x: x.numel(), model.parameters())) |
| 89 | + logger.info('{:>16s} : {:<.4f} [M]'.format('#Params', num_parameters/10**6)) |
| 90 | + |
| 91 | + for idx, img in enumerate(L_paths): |
| 92 | + |
| 93 | + # ------------------------------------ |
| 94 | + # (1) img_L |
| 95 | + # ------------------------------------ |
| 96 | + img_name, ext = os.path.splitext(os.path.basename(img)) |
| 97 | + # logger.info('{:->4d}--> {:>10s}'.format(idx+1, img_name+ext)) |
| 98 | + img_L = util.imread_uint(img, n_channels=n_channels) |
| 99 | + img_L = util.uint2single(img_L) |
| 100 | + |
| 101 | + if args.need_degradation: # degradation process |
| 102 | + np.random.seed(seed=0) # for reproducibility |
| 103 | + img_L += np.random.normal(0, args.noise_level_img/255., img_L.shape) |
| 104 | + |
| 105 | + util.imshow(util.single2uint(img_L), title='Noisy image with noise level {}'.format(args.noise_level_img)) if args.show_img else None |
| 106 | + |
| 107 | + img_L = util.single2tensor4(img_L) |
| 108 | + img_L = img_L.to(device) |
| 109 | + |
| 110 | + # ------------------------------------ |
| 111 | + # (2) img_E |
| 112 | + # ------------------------------------ |
| 113 | + #img_E = utils_model.test_mode(model, img_L, mode=2, refield=64) |
| 114 | + x8 = args.x8 |
| 115 | + if not x8 and img_L.size(2)//8==0 and img_L.size(3)//8==0: |
| 116 | + img_E = model(img_L) |
| 117 | + elif not x8 and (img_L.size(2)//8!=0 or img_L.size(3)//8!=0): |
| 118 | + img_E = utils_model.test_mode(model, img_L, refield=64, mode=5) |
| 119 | + elif x8: |
| 120 | + img_E = utils_model.test_mode(model, img_L, mode=3) |
| 121 | + |
| 122 | + #img_E = model(img_L) |
| 123 | + |
| 124 | + img_E = util.tensor2uint(img_E) |
| 125 | + |
| 126 | + if need_H: |
| 127 | + |
| 128 | + # -------------------------------- |
| 129 | + # (3) img_H |
| 130 | + # -------------------------------- |
| 131 | + img_H = util.imread_uint(H_paths[idx], n_channels=n_channels) |
| 132 | + img_H = img_H.squeeze() |
| 133 | + |
| 134 | + # -------------------------------- |
| 135 | + # PSNR and SSIM |
| 136 | + # -------------------------------- |
| 137 | + psnr = util.calculate_psnr(img_E, img_H, border=border) |
| 138 | + ssim = util.calculate_ssim(img_E, img_H, border=border) |
| 139 | + test_results['psnr'].append(psnr) |
| 140 | + test_results['ssim'].append(ssim) |
| 141 | + logger.info('{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.'.format(img_name+ext, psnr, ssim)) |
| 142 | + util.imshow(np.concatenate([img_E, img_H], axis=1), title='Recovered / Ground-truth') if args.show_img else None |
| 143 | + |
| 144 | + # ------------------------------------ |
| 145 | + # save results |
| 146 | + # ------------------------------------ |
| 147 | + util.imsave(img_E, os.path.join(E_path, img_name+ext)) |
| 148 | + |
| 149 | + if need_H: |
| 150 | + ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) |
| 151 | + ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) |
| 152 | + logger.info('Average PSNR/SSIM(RGB) - {} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(result_name, ave_psnr, ave_ssim)) |
| 153 | + |
| 154 | +if __name__ == '__main__': |
| 155 | + |
| 156 | + main() |
0 commit comments