SRGAN是2017年CVPR中备受关注的超分辨率论文,把超分辨率的效果带到了一个新的高度。所谓超分辨率重建就是将低分辨率的图像恢复成对应的高分辨率图像。由于地分辨率图像存在大量的信息缺失,这是一个病态的求逆解问题,尤其在恢复高倍分辨率图像的时候。传统方法通过加入一些先验信息来恢复高分辨率图像,如,插值法、稀疏学习、还有基于回归方法的随机森林等,CNN在超分辨率问题上取得了非常好的效果。
SRGAN是基于CNN采用GAN方法进行训练来实现图像的超分辨率重建的。它包含一个生成器和一个判别器,判别器的主体是VGG19,生成器的主体是一连串的Residual block,同时在模型的后部加入了subpixel模块,借鉴了Shi et al 的Subpixel Network的思想,让图片在最后的网络层才增加分辨率,使得提升分别率的同时减少了计算量。论文中给出的网络结构如图所示:
论文中还给出了生成器和判别器的损失函数的形式:
1.生成器的损失函数为:
其中,为本文所提出的感知损失函数,。
内容损失:; 训练网络时使用均方差损失可以获得较高的峰值信噪比,一般的超分辨率重建方法中,内容损失都选择使用生成图像和目标图像的均方差损失(MSELoss),但是使用均方差损失恢复的图像会丢失很多高频细节。因此,本文先将生成图像和目标图像分别输入到VGG网络中,然后对他们经过VGG后得到的feature map求欧式距离,并将其作为VGG loss。
对抗损失:; 为了避免当判别器训练较好时生成器出现梯度消失,本文将生成器的损失函数 进行了修改。
gen_hr = generator(img_lr) ## img_lr 为输入的地分辨率图像
Loss_GAN = torch.BCELoss(discriminator(gen_hr), valid) ## gen_hr 为生成的高分辨率图像
gen_features = VGG_feature_extract(gen_hr)
real_features = VGG_feature_extract(img_hr) ## img_hr 为输入的目标高分辨率图像
Loss_content = torch.nn.L2Loss(gen_feature, real_feature)
Loss_G = Loss_content + 1e-3 * Loss_GAN
2.判别器的损失函数为:
与普通的生成对抗网络判别器的的损失函数类似。
Loss_real = torch.BCELoss(disciminator(img_hr), valid)
Loss_fake = torch.BCELoss(discriminator(generator(gen_hr)), fake)
Loss_D = (Loss_real + Loss_fake) / 2