1. Pix2pix的简介:
图像处理中的很多问题都是将一张输入的图片转变成一张对应的输出图像,比如将一张灰度图转换为一张彩色图,将一张素描图转换为一张实物图, 这类问题的本质上是像素到像素的映射。2017年的CVPR上发表了一篇文章提出了一种基于GAN的Pix2pix网络来解决这类问题,pix2pix可以实现两个领域中匹配图像直接的转换,而且所得的结果比较清晰。
2. Pix2pix的网络结构:
该结构中生成器G的输入为Img_A,大小为(batch_size, A_channel, cols, rows), 输出为Img_B,大小为(batch_size, B_channel, cols, rows)。判别器D的输入为Img_A和Img_B的图像对,需要将两个图像在channel的维度上进行拼接,因此判别器输入数据的尺寸为(batch_size, A_channel + B_channel, cols, rows),判别器的输出为(batch_size, 1, s1, s2)。
3. Pix2pix的损失函数
pix2pix使用的是cGAN结构,除了cGAN的基本损失函数,生成器还增加了一个像素损失。
(1)生成器的损失函数:生成器的损失函数由对抗损失和像素损失构成。
对抗损失:
像素损失:
生成器总的损失为:
fake_B = generator(real_A)
pred_fake = discriminator(fake_B, real_A)
loss_GAN = torch.nn.MSELoss(pred_fake, valid) # 对抗损失
loss_pixel = torch.nn.L1Loss(fake_B, real_B) # 像素损失
loos_G = loss_GAN + lambda * loss_pixel # 生成器的总损失
(2)判别器的损失函数: pix2pix中判别器的损失与cGAN相同。
判别器总的损失为:
pred_real = discriminator(real_A, real_B)
loss_real = torch.nn.MSELoss(pred_loss, valid) # 真实图像对的损失
pred_fake = discriminator(fake_B.detach(), real_A)
loss_fake = torch.nn.MSELoss(pred_fake, fake) # 生成图像对的损失
loss_D = (loss_real + loss_fake) / 2 # 判别器的总损失