CapsNet学习笔记

本文有一些公式,由于简书不支持LaTeX公式渲染,公式完整版请移步个人博客
欢迎转载,转载请注明出处(简书地址和个人博客地址均可)

理论学习

胶囊结构

胶囊可以看成一种向量化的神经元。对于单个神经元而言,目前的深度网络中流动的数据均为标量。例如多层感知机的某一个神经元,其输入为若干个标量,输出为一个标量(不考虑批处理);而对于胶囊而言,每个神经元输入为若干个向量,输出为一个向量(不考虑批处理)。前向传播如下所示:

capsule_structure.png

其中$I_i$为第i个输入(向量),$W_i$为第i个权值(矩阵),$U_i$为中间变量(向量),由输入和权值叉乘获得。$c_i$为路由权值(标量),需要注意的是该标量是前向传播过程中决定(使用动态路由算法)的,不是通过反向传播优化的参数。Squash为一种激活函数。前向传播使用公式表示如下所示:

$$U_i = W_i^T \times I_i$$

$$S = \sum \limits_{i = 0}^n c_i \cdot U_i$$

$$Result = Squash(S) = \cfrac{||S||2}{1+||S||2} \cdot \cfrac{S}{||S||}$$

由以上可以看出,胶囊结构中流动的数据类型为向量,其激活函数Squash输入一个向量,输出一个向量。

动态路由算法

动态路由算法适用于确定胶囊结构中$c_i$的算法,其算法伪代码如下所示:

dynamic_route.jpg

首先其输入为$U_{j|i}$为本层的中间变量,其中i为这一层胶囊数量,j为下一层胶囊数量,最终获得的胶囊的输出$v_j$,其步骤描述如下:

  1. 初始化:初始化一个临时变量b,为一个$i \times j$的全为0的矩阵
  2. 获取这一步的连接权值c:$c_i = softmax(b_i)$,将临时变量b通过softmax,保证$c_i$的各分量和为1
  3. 获取这一步的加权和结果S:$s_j = \sum_i c_{ij}u_{j|i}$,按这一步连接权值计算加权和
  4. 非线性激活:$v_j = squash(s_j)$,经过非线性激活函数,获取这一步的胶囊输出
  5. 迭代临时变量:$b_{ij} = b_{ij} + u_{i|j} \cdot v_{j}$,所这一步的输出与中间变量方向相近,增加临时变量b,即增加权值;若这一步输出与中间变量方向相反,减小临时变量b,即减小权值。
  6. 若已经迭代到指定次数,输出$v_j$,否侧跳到步骤2

同时,对于迭代次数j,论文中表示过多的迭代会导致过拟合,实践中建议使用3次迭代。

输出与代价函数

输出层胶囊的输出为向量,该向量的长度即为概率。也就是说,前向传播的结果为输出最长向量的输出胶囊所代表的结果。反向传播时,也需要考虑网络的输出为向量而不是标量,因此原论文中了如下的代价函数(每个输出的代价函数,代价函数为所有输出代价函数的和$L = \sum\limits_{c=0}^n L_c$)

$$L_c = T_c max(0,m^+ - ||V_c||)^2 + \lambda (1 - T_c)max(0,||v_c|| - m^-) ^ 2$$

其中,$T_c$为标量,当分类结果为c时$T_c = 1$,否则$T_c = 0$;$\lambda$为固定值(一般为0.5),用于保证数值稳定性;$m+$和$m-$也为固定值:

  • 对于$T_c = 1$的输出胶囊,当输出向量大于$m^+$时,代价函数为0,否则不为0
  • 对于$T_c = 0$的输出胶囊,当输出向量小于$m^-$时,代价函数为0,否则不为0

整体架构

原论文中使举了一个识别MNIST手写数字数据集的例子,网络架构如下图所示:

capsnet_mnist.jpg
  • 第一层为普通的卷积层,使用9*9卷积,输出通道数为256,输出数据尺寸为20*20*256
  • 第二层为卷积层,该卷积层由平行的32个卷积层组成,每个卷积层对应向量数据中的一个向量。每个卷积层均为9*9*256*8(输入channel为256,输出channel为8)。因此输出为6*6*32*8,即窗口大小为6*6,输出channel为32,每个数据为8个分量的向量。
  • 第三层为胶囊层,行为类似于全连接层。输入为6*6*32=1152个8分量输入向量,输出为10个16分量的向量,对应的有1152*10个权值,每个权值为8*16的矩阵,最终输出为10个16分量的向量
  • 最终输出10个16分量的向量,最终的分类结果是向量长度最大的输出。

代码阅读(PyTorch)

本次代码阅读并不关心具体的实现方式,主要阅读CapsNet的实现思路

前胶囊层(卷积层)

class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9):
        super(PrimaryCaps, self).__init__()

        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0) 
                          for _ in range(num_capsules)])
    
    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)
        u = u.view(x.size(0), 32 * 6 * 6, -1)
        return self.squash(u)
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

重点关注forward前向传播部分:

def forward(self, x):
    u = [capsule(x) for capsule in self.capsules]
    u = torch.stack(u, dim=1)
    u = u.view(x.size(0), 32 * 6 * 6, -1)
    return self.squash(u)

self.capsulesnum_capsules[in_channels,out_channels,kernel_size,kernel_size]的卷积层,对应上文所述的第二层卷积层的操作。注意该部分的输出直接被变为[batch size,1152,8]的形式,且通过squash激活函数挤压输出向量

胶囊层

class DigitCaps(nn.Module):
    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        batch_size = x.size(0)
        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)

        W = torch.cat([self.W] * batch_size, dim=0)
        u_hat = torch.matmul(W, x)

        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
        if USE_CUDA:
            b_ij = b_ij.cuda()

        num_iterations = 3
        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)
            
            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor

获得中间向量

batch_size = x.size(0)
x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)

W = torch.cat([self.W] * batch_size, dim=0)
u_hat = torch.matmul(W, x)

这一部分计算中间向量$U_i$

动态路由

for iteration in range(num_iterations):
    c_ij = F.softmax(b_ij)
    c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

    s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
    v_j = self.squash(s_j)
            
    if iteration < num_iterations - 1:
        a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
        b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

动态路由的结构中:

  • 第1行计算了softmax函数的结果,对用临时变量b
  • 第5行计算加权和
  • 第6行计算当前迭代次数的输出
  • 第9和10行更新临时向量的值

代价函数

def margin_loss(self, x, labels, size_average=True):
    batch_size = x.size(0)
    v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))
    left = F.relu(0.9 - v_c).view(batch_size, -1)
    right = F.relu(v_c - 0.1).view(batch_size, -1)
    loss = labels * left + 0.5 * (1.0 - labels) * right
    loss = loss.sum(dim=1).mean()
    return loss

该函数为代价函数,分别实现了两种情况下($T_c = 0,T_c = 1$)的代价函数。

参考资料

代码来自higgsfield's github

文字资料参考weakish翻译的Max Pechyonkin的博客:

此外还参考:

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,837评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,551评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,417评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,448评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,524评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,554评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,569评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,316评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,766评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,077评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,240评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,912评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,560评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,176评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,425评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,114评论 2 366
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,114评论 2 352

推荐阅读更多精彩内容