这里使用 InstanceNorm 来校正色偏,但其效果与真正的白平衡相比仍有很大差距。
情况1:图片归一化(min-max 缩放到 [0, 1])时,使用全局(所有通道共用)的最大值和最小值。
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms as T
class modelfunc(nn.Module):
    """Compare torch's InstanceNorm2d with a hand-written per-channel L2 norm.

    Both results are rescaled to [0, 1] with a single *global* minimum and
    maximum (shared by every sample and channel) so they can be displayed
    as images.
    """

    def __init__(self, class_num):
        """Build the module.

        Args:
            class_num: Accepted for interface compatibility; not used.
        """
        super(modelfunc, self).__init__()
        self.IN = nn.InstanceNorm2d(3)

    def mySpatiaNorm(self, x):
        """L2-normalise every channel of every sample over its spatial axes.

        Args:
            x: 4-D tensor laid out as (batch, channel, height, width).

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, h, w = x.shape
        # Flatten per (sample, channel) instead of assuming one 3-channel
        # image, so batched inputs are not mixed across samples/channels.
        flat = x.reshape(b, c, -1)
        flat = F.normalize(flat, p=2, dim=2)
        return flat.reshape(b, c, h, w)

    @staticmethod
    def _rescale_global(x, eps=1e-12):
        """Min-max scale ``x`` to [0, 1] using its global min and max."""
        lo = x.min()
        # Clamp the range so a constant input yields zeros rather than NaN.
        span = (x.max() - lo).clamp(min=eps)
        return (x - lo) / span

    def forward(self, x):
        """Return ``(torch_in, my_norm)``, both rescaled to [0, 1].

        Args:
            x: 4-D image tensor (batch, channel, height, width).

        Returns:
            Tuple of two tensors shaped like ``x``: the InstanceNorm2d output
            and the hand-written L2-normalised output, each globally
            min-max scaled.
        """
        x1 = self._rescale_global(self.IN(x))
        x2 = self._rescale_global(self.mySpatiaNorm(x))
        return x1, x2
# Instantiate the model (the constructor argument is accepted but unused).
model_object = modelfunc(3)

# Load the image. The model hard-codes 3 channels, so force RGB: a grayscale,
# CMYK or RGBA file would otherwise produce a tensor with the wrong channel
# count.
with Image.open('123.jpg') as img_file:
    pil_img = img_file.convert('RGB')

# ToTensor gives a (C, H, W) float tensor in [0, 1]; add the batch dimension.
tensor_img = T.ToTensor()(pil_img).unsqueeze(0)
t_img1, t_img2 = model_object(tensor_img)

# Drop the batch dimension again and convert back to PIL for display.
pil_img_new1 = T.ToPILImage()(t_img1.squeeze(0))
pil_img_new2 = T.ToPILImage()(t_img2.squeeze(0))

# Plot the original next to both normalised versions.
fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(131)
ax2 = fig.add_subplot(132)
ax3 = fig.add_subplot(133)
ax1.set_title('Origin Img')
ax1.imshow(pil_img)
ax2.set_title('torch IN')
ax2.imshow(pil_img_new1)
ax3.set_title('my IN')
ax3.imshow(pil_img_new2)
plt.show()
实验结果1
情况2:图片归一化(min-max 缩放到 [0, 1])时,对每个通道分别使用该通道自身的最大值和最小值。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms as T
from torch.autograd import Variable
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import numpy as np
class modelfunc(nn.Module):
    """Compare torch's InstanceNorm2d with a hand-written per-channel L2 norm.

    Both results are rescaled to [0, 1] *per sample and per channel*, i.e.
    every channel of every image uses its own minimum and maximum.
    """

    def __init__(self, class_num):
        """Build the module.

        Args:
            class_num: Accepted for interface compatibility; not used.
        """
        super(modelfunc, self).__init__()
        self.IN = nn.InstanceNorm2d(3)

    def myIN(self, x):
        """L2-normalise every channel of every sample over its spatial axes.

        Args:
            x: 4-D tensor laid out as (batch, channel, height, width).

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, h, w = x.shape
        # Flatten per (sample, channel) instead of assuming one 3-channel
        # image, so batched inputs are not mixed across samples/channels.
        flat = x.reshape(b, c, -1)
        flat = F.normalize(flat, p=2, dim=2)
        return flat.reshape(b, c, h, w)

    @staticmethod
    def _rescale_per_channel(x, eps=1e-12):
        """Min-max scale each (sample, channel) plane of ``x`` to [0, 1]."""
        b, c, h, w = x.shape
        flat = x.reshape(b, c, -1)
        lo = flat.min(dim=2, keepdim=True)[0]
        hi = flat.max(dim=2, keepdim=True)[0]
        # Clamp the range so a constant channel yields zeros rather than NaN.
        span = (hi - lo).clamp(min=eps)
        return ((flat - lo) / span).reshape(b, c, h, w)

    def forward(self, x):
        """Return ``(torch_in, my_norm)``, both rescaled per channel.

        Args:
            x: 4-D image tensor (batch, channel, height, width).

        Returns:
            Tuple of two tensors shaped like ``x``: the InstanceNorm2d output
            and the hand-written L2-normalised output, each min-max scaled
            independently for every channel of every sample.
        """
        x1 = self._rescale_per_channel(self.IN(x))
        x2 = self._rescale_per_channel(self.myIN(x))
        return x1, x2
# Instantiate the model (the constructor argument is accepted but unused).
model_object = modelfunc(3)

# Load the image. The model hard-codes 3 channels, so force RGB: a grayscale,
# CMYK or RGBA file would otherwise produce a tensor with the wrong channel
# count.
with Image.open('123.jpg') as img_file:
    pil_img = img_file.convert('RGB')

# ToTensor gives a (C, H, W) float tensor in [0, 1]; add the batch dimension.
tensor_img = T.ToTensor()(pil_img).unsqueeze(0)
t_img1, t_img2 = model_object(tensor_img)

# Drop the batch dimension again and convert back to PIL for display.
pil_img_new1 = T.ToPILImage()(t_img1.squeeze(0))
pil_img_new2 = T.ToPILImage()(t_img2.squeeze(0))

# Plot the original next to both normalised versions.
fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(131)
ax2 = fig.add_subplot(132)
ax3 = fig.add_subplot(133)
ax1.set_title('Origin Img')
ax1.imshow(pil_img)
ax2.set_title('torch IN')
ax2.imshow(pil_img_new1)
ax3.set_title('my IN')
ax3.imshow(pil_img_new2)
plt.show()
实验结果2