import os
import math
import time
import datetime
from functools import reduce
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc as misc
from skimage.restoration import denoise_bilateral
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
class timer():
    """Stopwatch that can also accumulate several measured intervals.

    ``tic``/``toc`` measure one interval; ``hold`` adds the current interval
    to a running total that ``release`` returns and clears.
    """

    def __init__(self):
        # Start with an empty accumulator and the clock already running.
        self.acc = 0
        self.tic()

    def tic(self):
        """Mark the current wall-clock time as the start of an interval."""
        self.t0 = time.time()

    def toc(self):
        """Return the seconds elapsed since the last ``tic``."""
        now = time.time()
        return now - self.t0

    def hold(self):
        """Add the time elapsed since the last ``tic`` to the accumulator."""
        elapsed = self.toc()
        self.acc += elapsed

    def release(self):
        """Return the accumulated time and clear the accumulator."""
        total, self.acc = self.acc, 0
        return total

    def reset(self):
        """Clear the accumulator without returning it."""
        self.acc = 0
class checkpoint():
    """Manage one experiment directory: logs, PSNR history, plots and images.

    All output goes under ``../experiment/<name>``, where ``<name>`` is
    ``args.save`` for a fresh run or ``args.load`` for a resumed one.
    """

    def __init__(self, args):
        """Resolve the experiment directory, optionally wipe it, open the logs.

        Args:
            args: Namespace-like options object. ``load``, ``save`` and
                ``reset`` are read here, and ``load``/``save`` may be
                overwritten in place. Other methods read ``data_test``,
                ``scale`` and ``rgb_range``.
        """
        import shutil  # local import: only the reset path needs it

        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if args.load == '.':
            # Fresh run: fall back to a timestamp when no name was given.
            if args.save == '.':
                args.save = now
            self.dir = '../experiment/' + args.save
        else:
            self.dir = '../experiment/' + args.load
            if not os.path.exists(self.dir):
                # Nothing to resume from; treat as a fresh run.
                args.load = '.'
            else:
                self.log = torch.load(self.dir + '/psnr_log.pt')
                print('Continue from epoch {}...'.format(len(self.log)))

        if args.reset:
            # Remove the tree through the filesystem API rather than a shell
            # command, so names with spaces or shell metacharacters are
            # handled correctly. A missing directory is not an error.
            shutil.rmtree(self.dir, ignore_errors=True)
            # The PSNR history that may have been loaded above belonged to
            # the directory that was just deleted.
            self.log = torch.Tensor()
            args.load = '.'

        for sub in ('', '/model', '/results', '/residuals', '/branches'):
            path = self.dir + sub
            if not os.path.exists(path):
                os.makedirs(path)

        # Append when a previous log exists, otherwise start new files.
        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
        # Kept open for the lifetime of the object; closed in done().
        self.log_file = open(self.dir + '/log.txt', open_type)
        with open(self.dir + '/config.txt', open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

    def save(self, trainer, epoch, is_best=False):
        """Save model, loss state, loss/PSNR plots, PSNR log and optimizer.

        Args:
            trainer: Object exposing ``model``, ``loss`` and ``optimizer``.
            epoch: Epoch number forwarded to the model/loss/plot helpers.
            is_best: Forwarded to ``trainer.model.save``.
        """
        trainer.model.save(self.dir, epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )

    def add_log(self, log):
        """Append ``log`` (a tensor) to the PSNR history."""
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Write one line to ``log.txt``.

        Args:
            log: Text to write; a newline is appended.
            refresh: When true, close and reopen the file in append mode so
                the content written so far reaches disk.
        """
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.dir + '/log.txt', 'a')

    def done(self):
        """Close the log file."""
        self.log_file.close()

    def plot_psnr(self, epoch):
        """Plot the PSNR history per scale and save it as a PDF.

        Args:
            epoch: Number of points on the x axis (1..epoch); the history is
                expected to hold one row per epoch.
        """
        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.args.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(self.args.scale):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
        plt.close(fig)

    def save_results(self, filename, save_list, scale):
        """Write the first image of each tensor in ``save_list`` as a PNG.

        Files are named ``results/<filename>_x<scale>_<SR|LR|HR>.png``; the
        suffixes are assigned to ``save_list`` entries in that order.
        """
        # NOTE(review): scipy.misc.imsave is no longer shipped by recent
        # SciPy releases -- confirm the pinned SciPy version.
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
        postfix = ('SR', 'LR', 'HR')
        for v, p in zip(save_list, postfix):
            # Rescale from [0, rgb_range] to [0, 255] before the byte cast.
            normalized = v[0].data.mul(255 / self.args.rgb_range)
            ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
            if ndarr.shape[-1] == 1:
                # Single channel: drop the channel axis.
                ndarr = ndarr[:, :, 0]
            misc.imsave('{}{}.png'.format(filename, p), ndarr)

    def save_residuals(self, filename, save_list, scale):
        """Save ``|HR - SR|`` as ``residuals/<filename>_x<scale>.png``.

        SR is taken from the first entry of ``save_list`` and HR from the
        last; only the first image of each tensor is used.
        """
        filename = '{}/residuals/{}_x{}'.format(self.dir, filename, scale)
        sr, hr = save_list[0], save_list[-1]

        def _prepare(x):
            # Scale to [0, 1] and move the channel axis last.
            normalized = x[0].data.mul(1. / self.args.rgb_range)
            out = normalized.permute(1, 2, 0).cpu().numpy()
            if out.shape[-1] == 1:
                out = out[:, :, 0]
            return out

        ndarr_sr, ndarr_hr = _prepare(sr), _prepare(hr)
        out = np.abs(ndarr_hr - ndarr_sr)
        misc.imsave('{}.png'.format(filename), out)

    def save_branches(self, filename, save_list, scale):
        """Save each entry of ``save_list`` as ``branches/..._branch<i>.png``.

        Entry 0 is written as-is; every later entry is written as its
        absolute value.
        """
        filename = '{}/branches/{}_x{}'.format(self.dir, filename, scale)

        def _prepare(x, residual):
            # Scale to [0, 1] and move the channel axis last.
            normalized = x[0].data.mul(1. / self.args.rgb_range)
            out = normalized.permute(1, 2, 0).cpu().numpy()
            if residual:
                out = np.abs(out)
            if out.shape[-1] == 1:
                out = out[:, :, 0]
            return out

        for i, branch_output in enumerate(save_list):
            ndarr = _prepare(branch_output, i != 0)
            misc.imsave(
                '{}{}.png'.format(filename, '_branch{}'.format(i)), ndarr
            )
def get_bilateral(tensor, rgb_range):
    """Apply skimage's ``denoise_bilateral`` to every image of a batch.

    Args:
        tensor: 4-D tensor; axis 1 is moved to the last position before
            filtering and moved back afterwards. ``.numpy()`` is called
            directly on it.
        rgb_range: Value the data is divided by before filtering and
            multiplied by afterwards.

    Returns:
        A new ``torch.Tensor`` with the filtered batch, same axis order as
        the input.
    """
    scaled = tensor.numpy().transpose(0, 2, 3, 1) / rgb_range
    filtered = np.zeros_like(scaled)
    # Filter image by image, writing into the preallocated array.
    for idx, image in enumerate(scaled):
        filtered[idx] = denoise_bilateral(image)
    restored = filtered.transpose(0, 3, 1, 2)
    return torch.Tensor(restored) * rgb_range
def quantize(img, rgb_range):
    """Round ``img`` to the nearest of 256 levels spanning ``[0, rgb_range]``.

    The tensor is scaled so that ``rgb_range`` maps to 255, clamped to
    ``[0, 255]``, rounded, and scaled back.

    Args:
        img: Tensor of pixel values.
        rgb_range: Maximum pixel value of the input range.

    Returns:
        A new tensor with quantized values in the original range.
    """
    levels_per_unit = 255 / rgb_range
    scaled = img.mul(levels_per_unit)
    snapped = scaled.clamp(0, 255).round()
    return snapped.div(levels_per_unit)
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    """Compute the PSNR (in dB) between ``sr`` and ``hr``.

    Args:
        sr: Super-resolved batch, indexed as 4-D ``[:, :, rows, cols]``.
        hr: Reference batch of the same shape.
        scale: Upscaling factor; determines how many border pixels are
            excluded from the comparison.
        rgb_range: Maximum pixel value; the difference is divided by it.
        benchmark: When true, ``scale`` border pixels are shaved and
            multi-channel differences are reduced to one weighted channel.
            Otherwise ``scale + 6`` border pixels are shaved.

    Returns:
        PSNR as a float, or ``inf`` when the compared regions are identical.
    """
    diff = (sr - hr).data.div(rgb_range)
    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            # Per-channel weights; presumably the ITU-R BT.601 RGB->Y
            # coefficients -- TODO confirm.
            convert = diff.new(1, 3, 1, 1)
            convert[0, 0, 0, 0] = 65.738
            convert[0, 1, 0, 0] = 129.057
            convert[0, 2, 0, 0] = 25.064
            # diff is a fresh tensor, so the in-place ops do not touch the
            # caller's inputs.
            diff.mul_(convert).div_(256)
            diff = diff.sum(dim=1, keepdim=True)
    else:
        shave = scale + 6

    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = float(valid.pow(2).mean())
    if mse == 0:
        # Identical images: log10(0) would raise ValueError.
        return float('inf')

    return -10 * math.log10(mse)
def make_optimizer(args, my_model):
    """Build the optimizer selected by ``args.optimizer``.

    Only parameters with ``requires_grad`` set are handed to the optimizer.

    Args:
        args: Options object. Reads ``optimizer`` (``'SGD'``, ``'ADAM'`` or
            ``'RMSprop'``), ``lr``, ``weight_decay`` and, depending on the
            choice, ``momentum``, ``beta1``/``beta2`` and ``epsilon``.
        my_model: Object exposing ``parameters()``.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.
    """
    trainable = (p for p in my_model.parameters() if p.requires_grad)

    if args.optimizer == 'SGD':
        optimizer_function = optim.SGD
        kwargs = {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        optimizer_function = optim.Adam
        kwargs = {
            'betas': (args.beta1, args.beta2),
            'eps': args.epsilon
        }
    elif args.optimizer == 'RMSprop':
        optimizer_function = optim.RMSprop
        kwargs = {'eps': args.epsilon}
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            'Unsupported optimizer: {!r}'.format(args.optimizer)
        )

    kwargs['lr'] = args.lr
    kwargs['weight_decay'] = args.weight_decay

    return optimizer_function(trainable, **kwargs)
def make_scheduler(args, my_optimizer):
    """Build the learning-rate scheduler selected by ``args.decay_type``.

    Args:
        args: Options object. Reads ``decay_type``, ``gamma`` and, for the
            plain ``'step'`` type, ``lr_decay``.

            * ``'step'`` gives ``StepLR`` with ``step_size=args.lr_decay``.
            * Any other string containing ``'step'`` is split on ``'_'``;
              the first token is dropped and the rest are parsed as integer
              milestones for ``MultiStepLR`` (e.g. ``'step_200_400'``).
        my_optimizer: Optimizer to schedule.

    Returns:
        The constructed scheduler.

    Raises:
        ValueError: If ``args.decay_type`` does not contain ``'step'``.
    """
    if args.decay_type == 'step':
        return lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay,
            gamma=args.gamma
        )

    if 'step' in args.decay_type:
        # Drop the leading name token; the remaining tokens are milestones.
        tokens = args.decay_type.split('_')[1:]
        milestones = [int(token) for token in tokens]
        return lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma
        )

    # Fail with a clear message instead of an UnboundLocalError.
    raise ValueError(
        'Unsupported decay_type: {!r}'.format(args.decay_type)
    )
utility.py
最后编辑于 :
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...