DeepLearning-风格迁移

背景介绍

不知道大家是否用过prisma,就算没有用过,也一定看见别人用过这个软件,下面是一张这个软件得到的一个效果图


example
example

官方宣传的卖点是一秒钟让你的作品拥有名家风格,什么毕加索,梵高,都不在话下。通过这个效果再将你的照片发到朋友圈,是不是效果爆棚,简直是各种装逼界的一股清流,秒杀各种修图ps好吗。而且可以完美的掩饰掉一些瑕疵,又比ps更自然,更有逼格,是不是很棒。

这个软件将使用的方法发了一篇论文,并且这个软件在发布的时候就取得了上千万的融资,是不是瞬间感觉现在学习了知识也能成为千万富豪了。如今在这个高速发展的时代,知识付费的时代确实已经到来了,所以我们现在努力学习各种知识就是在赚钱啊,有木有。这样大家的学习的时候就能够有着更大的动力了。

这篇论文感兴趣的同学可以去查看一下,里面主要涉及的是卷积神经网络。今天这篇文章要做的是什么呢?我们希望自己能够简单的实现这个风格迁移算法,并且用自己的算法来得到新的风格图片。一想到我们放到朋友圈的照片是自己写的算法来实现的就感觉成就感爆棚,有没有。

环境配置

废话不多说,我们先来看看需要的基本配置。首先需要python环境,建议使用anaconda;然后我们使用的深度学习框架是pytorch,当然你也可以用tensorflow,具体框架的介绍可以去看看之前写的文章,需要安装pytorch和torchvision,这里查看安装帮助;同时需要一些其他的包,如果缺什么就pip安装就好。

这篇文章主要参考于pytorch的官方tutorial,感兴趣的同学可以直接移步至官方教程的地方,这篇文章我会说一些自己的理解,代码部分基本都是参考这个教程,但是我会做一些说明,力求更加清楚。

原理分析

其实要实现的东西很清晰,就是需要将两张图片融合在一起,这个时候就需要定义怎么才算融合在一起。首先需要的就是内容上是相近的,然后风格上是相似的。这样来我们就知道我们需要做的事情是什么了,我们需要计算融合图片和内容图片的相似度,或者说差异性,然后尽可能降低这个差异性;同时我们也需要计算融合图片和风格图片在风格上的差异性,然后也降低这个差异性就可以了。这样我们就能够量化我们的目标了。

对于内容的差异性我们该如何定义呢?其实我们能够很简答的想到就是两张图片每个像素点进行比较,也就是求一下差,因为简单的计算他们之间的差会有正负,所以我们可以加一个平方,使得差全部是正的,也可以加绝对值,但是数学上绝对值会破坏函数的可微性,所以大家都用平方,这个地方不理解也没关系,记住普遍都是使用平方就行了。

对于风格的差异性我们该如何定义呢?这才是一个难点。这也是这篇文章提出的创新点,引入了Gram矩阵计算风格的差异。我尽量不使用数学的语言来解释,而使用通俗的语言。
首先需要的预先知识是卷积网络的知识,这里不细讲了,不了解的同学可以看之前的卷积网络的文章。我们知道一张图片通过卷积网络之后可以的到一个特征图,Gram矩阵就是在这个特征图上面定义出来的。每个特征图的大小一般是 MxNxC 或者是 CxMxN 这种大小,这里C表示的时候厚度,放在前面和后面都可以,MxN 表示的是一个矩阵的大小,其实就是有 C 个 MxN 这样的矩阵叠在一起。

Gram矩阵是如何定义的呢?首先Gram矩阵的大小是有特征图的厚度决定的,等于 CxC,那么每一个Gram矩阵的元素,也就是 Gram(i, j) 等于多少呢?先把特征图中第 i 层和第 j 层取出来,这样就得到了两个 MxN的矩阵,然后将这两个矩阵对应元素相乘然后求和就得到了 Gram(i, j),同理 Gram 的所有元素都可以通过这个方式得到。这样 Gram 中每个元素都可以表示两层特征图的一种组合,就可以定义为它的风格。

然后风格的差异就是两幅图的 Gram 矩阵的差异,就像内容的差异的计算方法一样,计算一下这两个矩阵的差就可以量化风格的差异。

实现

以下的内容都是用pytorch实现的,如果对pytorch不熟悉的同学可以看一下我之前的pytorch介绍文章,看看官方教程,如果不想了解pytorch的同学可以用自己熟悉的框架实现这个算法,理论部分前面已经讲完了。

内容差异的loss定义

class Content_Loss(nn.Module):
    def __init__(self, target, weight):
        super(Content_Loss, self).__init__()
        self.weight = weight
        self.target = target.detach() * self.weight
        # 必须要用detach来分离出target,这时候target不再是一个Variable,这是为了动态计算梯度,否则forward会出错,不能向前传播
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        out = input.clone()
        return out

    def backward(self, retain_variabels=True):
        self.loss.backward(retain_variables=retain_variabels)
        return self.loss

其中有个变量weight,这个是表示权重,内容和风格你可以选择一个权重,比如你想风格上更像,内容上多一点差别没关系,那么内容的权重你可以定义小一点,风格的权重可以定义大一点;反之你可以把风格的权重定义小一点,内容的权重定义大一点。

风格差异的loss定义

Gram 矩阵的定义

class Gram(nn.Module):
    def __init__(self):
        super(Gram, self).__init__()

    def forward(self, input):
        a, b, c, d = input.size()
        feature = input.view(a * b, c * d)
        gram = torch.mm(feature, feature.t())
        gram /= (a * b * c * d)
        return gram

style loss定义

class Style_Loss(nn.Module):
    def __init__(self, target, weight):
        super(Style_Loss, self).__init__()
        self.weight = weight
        self.target = target.detach() * self.weight
        self.gram = Gram()
        self.criterion = nn.MSELoss()

    def forward(self, input):
        G = self.gram(input) * self.weight
        self.loss = self.criterion(G, self.target)
        out = input.clone()
        return out

    def backward(self, retain_variabels=True):
        self.loss.backward(retain_variables=retain_variabels)
        return self.loss

建立模型

使用19层的 vgg 作为提取特征的卷积网络,并且定义哪几层为需要的特征。


vgg = models.vgg19(pretrained=True).features
vgg = vgg.cuda()

content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']


def get_style_model_and_loss(style_img, content_img, cnn=vgg,
                             style_weight=1000,
                             content_weight=1,
                             content_layers=content_layers_default,
                             style_layers=style_layers_default):

    content_loss_list = []
    style_loss_list = []

    model = nn.Sequential()
    model = model.cuda()
    gram = loss.Gram()
    gram = gram.cuda()

    i = 1
    for layer in cnn:
        if isinstance(layer, nn.Conv2d):
            name = 'conv_' + str(i)
            model.add_module(name, layer)

            if name in content_layers_default:
                target = model(content_img)
                content_loss = loss.Content_Loss(target, content_weight)
                model.add_module('content_loss_' + str(i), content_loss)
                content_loss_list.append(content_loss)

            if name in style_layers_default:
                target = model(style_img)
                target = gram(target)
                style_loss = loss.Style_Loss(target, style_weight)
                model.add_module('style_loss_' + str(i), style_loss)
                style_loss_list.append(style_loss)

            i += 1
        if isinstance(layer, nn.MaxPool2d):
            name = 'pool_' + str(i)
            model.add_module(name, layer)

        if isinstance(layer, nn.ReLU):
            name = 'relu' + str(i)
            model.add_module(name, layer)

    return model, style_loss_list, content_loss_list

训练模型

def get_input_param_optimier(input_img):
    """
    input_img is a Variable
    """
    input_param = nn.Parameter(input_img.data)
    optimizer = optim.LBFGS([input_param])
    return input_param, optimizer


def run_style_transfer(content_img, style_img, input_img,
                       num_epoches=300):
    print('Building the style transfer model..')
    model, style_loss_list, content_loss_list = get_style_model_and_loss(
        style_img, content_img
    )
    input_param, optimizer = get_input_param_optimier(input_img)

    print('Opimizing...')
    epoch = [0]
    while epoch[0] < num_epoches:

        def closure():
            input_param.data.clamp_(0, 1)

            model(input_param)
            style_score = 0
            content_score = 0

            optimizer.zero_grad()
            for sl in style_loss_list:
                style_score += sl.backward()
            for cl in content_loss_list:
                content_score += cl.backward()

            epoch[0] += 1
            if epoch[0] % 50 == 0:
                print('run {}'.format(epoch))
                print('Style Loss: {:.4f} Content Loss: {:.4f}'.format(
                    style_score.data[0], content_score.data[0]
                ))
                print()

            return style_score + content_score

        optimizer.step(closure)

        input_param.data.clamp_(0, 1)

    return input_param.data

需要特别注意的是这个模型里面参数不再是网络里面的参数,因为网络使用的是已经预训练好的 vgg 网络,这个算法里面的参数是合成图片里面的每个像素点,我们可以将内容图片直接 copy 成合成图片,然后训练使得他的风格和我们的风格图片相似,同时也可以随机化一张图片作为合成图片,然后训练他使得他与内容图片以及风格图片具有相似性。

实验结果

我们使用的风格图片为

style.png

内容图片为

content.png

得到的合成效果为

demo.png

结语

通过这篇文章,我们利用pytorch实现了基本的风格转移算法,得到的效果也是满意的,所以我们可以把自己的图片通过这个算法做一个风格转移,实现你想要的作品的风格,逼格满满,大家学习之后肯定会有特别大的成就感,在完成项目的同时也学习到了新的知识,同时也会对这个产生更浓厚的感兴趣,兴趣才是各种的动力,比任何鸡汤都有用,希望大家都能够找到自己的兴趣,热爱自己所做的事。


本文代码已经上传到了github

欢迎查看我的知乎专栏,深度炼丹

欢迎访问我的博客

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 204,445评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,889评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 151,047评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,760评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,745评论 5 367
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,638评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 38,011评论 3 398
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,669评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,923评论 1 299
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,655评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,740评论 1 330
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,406评论 4 320
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,995评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,961评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,197评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 45,023评论 2 350
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,483评论 2 342

推荐阅读更多精彩内容