(一)InstanceNorm校正色偏

这里使用 InstanceNorm 校正色偏的效果,与白平衡算法相比还是有很大差距的。

情况1:图片标准化使用全局最大最小值计算。

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms as T

class modelfunc(nn.Module):
    """Compare ``nn.InstanceNorm2d`` with a hand-written per-plane L2 normalization.

    ``forward`` runs both normalizations on the same input and rescales each
    result to ``[0, 1]`` with a single global min/max so it can be shown as
    an image.

    Args:
        class_num: Unused; kept so existing ``modelfunc(3)`` calls keep working.
    """

    def __init__(self, class_num):
        super(modelfunc, self).__init__()
        self.IN = nn.InstanceNorm2d(3)

    @staticmethod
    def _rescale_global(x, eps=1e-12):
        """Min-max rescale ``x`` to ``[0, 1]`` using one min/max over all elements.

        The value range is clamped to at least ``eps`` so that a constant
        tensor maps to zeros instead of NaN (0 / 0).
        """
        lo = x.min()
        span = (x.max() - lo).clamp(min=eps)
        return (x - lo) / span

    def mySpatiaNorm(self, x):
        """L2-normalize every (sample, channel) plane of a 4-D tensor.

        Args:
            x: 4-D tensor; the last two dimensions are flattened and each
                flattened plane is divided by its L2 norm.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, d2, d3 = x.shape
        # Use the real batch/channel sizes instead of a hard-coded (1, 3, -1):
        # the fixed shape merged different samples into one normalization
        # group when b > 1 and broke for channel counts other than 3.
        # reshape (not view) also accepts non-contiguous input.
        flat = x.reshape(b, c, -1)
        flat = F.normalize(flat, p=2, dim=2)
        return flat.reshape(b, c, d2, d3)

    def forward(self, x):
        """Return ``(instance_norm_result, l2_norm_result)``, each rescaled to ``[0, 1]``.

        Both outputs use a global min/max (over every element of the tensor),
        not a per-channel one.
        """
        x1 = self._rescale_global(self.IN(x))
        x2 = self._rescale_global(self.mySpatiaNorm(x))
        return x1, x2
# Instantiate the model (the constructor argument is not used by the model).
model_object = modelfunc(3)


# Load the image. The model is hard-wired to 3 channels, so force RGB:
# a grayscale or CMYK JPEG would otherwise produce a 1- or 4-channel tensor.
pil_img = Image.open('123.jpg').convert('RGB')
# ToTensor output gets a leading batch dimension added for the model.
tensor_img = T.ToTensor()(pil_img).unsqueeze(0)

t_img1, t_img2 = model_object(tensor_img)

# Drop the batch dimension again and convert back to PIL for display.
pil_img_new1 = T.ToPILImage()(t_img1.squeeze(0))
pil_img_new2 = T.ToPILImage()(t_img2.squeeze(0))

# Plot the original and both normalized results side by side.
fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(131)
ax2 = fig.add_subplot(132)
ax3 = fig.add_subplot(133)

ax1.set_title('Origin Img')
ax1.imshow(pil_img)
ax2.set_title('torch IN')
ax2.imshow(pil_img_new1)
ax3.set_title('my IN')
ax3.imshow(pil_img_new2)

plt.show()
实验结果1

情况2:图片标准化使用特定通道最大最小值分别计算。

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms as T
from torch.autograd import Variable
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import numpy as np


class modelfunc(nn.Module):
    """Compare ``nn.InstanceNorm2d`` with a hand-written per-plane L2 normalization.

    ``forward`` runs both normalizations on the same input and rescales each
    (sample, channel) plane of each result to ``[0, 1]`` independently so it
    can be shown as an image.

    Args:
        class_num: Unused; kept so existing ``modelfunc(3)`` calls keep working.
    """

    def __init__(self, class_num):
        super(modelfunc, self).__init__()
        self.IN = nn.InstanceNorm2d(3)

    @staticmethod
    def _rescale_per_channel(x, eps=1e-12):
        """Min-max rescale every (sample, channel) plane of ``x`` to ``[0, 1]``.

        Vectorized replacement for a per-image loop that was hard-coded to
        three channels. The value range is clamped to at least ``eps`` so a
        constant plane maps to zeros instead of NaN (0 * inf).
        """
        b, c = x.shape[:2]
        flat = x.reshape(b, c, -1)
        lo = flat.min(dim=2, keepdim=True)[0]
        hi = flat.max(dim=2, keepdim=True)[0]
        span = (hi - lo).clamp(min=eps)
        return ((flat - lo) / span).reshape(x.shape)

    def myIN(self, x):
        """L2-normalize every (sample, channel) plane of a 4-D tensor.

        Args:
            x: 4-D tensor; the last two dimensions are flattened and each
                flattened plane is divided by its L2 norm.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, d2, d3 = x.shape
        # Use the real batch/channel sizes instead of a hard-coded (1, 3, -1):
        # the fixed shape merged different samples into one normalization
        # group when b > 1 and broke for channel counts other than 3.
        # reshape (not view) also accepts non-contiguous input.
        flat = x.reshape(b, c, -1)
        flat = F.normalize(flat, p=2, dim=2)
        return flat.reshape(b, c, d2, d3)

    def forward(self, x):
        """Return ``(instance_norm_result, l2_norm_result)``.

        Each output is min-max rescaled per (sample, channel) plane, i.e. every
        channel of every image spans ``[0, 1]`` on its own.
        """
        x1 = self._rescale_per_channel(self.IN(x))
        x2 = self._rescale_per_channel(self.myIN(x))
        return x1, x2
# Instantiate the model (the constructor argument is not used by the model).
model_object = modelfunc(3)


# Load the image. The model is hard-wired to 3 channels, so force RGB:
# a grayscale or CMYK JPEG would otherwise produce a 1- or 4-channel tensor.
pil_img = Image.open('123.jpg').convert('RGB')
# ToTensor output gets a leading batch dimension added for the model.
tensor_img = T.ToTensor()(pil_img).unsqueeze(0)

t_img1, t_img2 = model_object(tensor_img)

# Drop the batch dimension again and convert back to PIL for display.
pil_img_new1 = T.ToPILImage()(t_img1.squeeze(0))
pil_img_new2 = T.ToPILImage()(t_img2.squeeze(0))

# Plot the original and both normalized results side by side.
fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(131)
ax2 = fig.add_subplot(132)
ax3 = fig.add_subplot(133)

ax1.set_title('Origin Img')
ax1.imshow(pil_img)
ax2.set_title('torch IN')
ax2.imshow(pil_img_new1)
ax3.set_title('my IN')
ax3.imshow(pil_img_new2)

plt.show()
实验结果2
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。