恒源云(GPUSHARE)_Child Tuning: 反向传播版的Dropout

文章来源 | 恒源云社区

原文地址 | Child Tuning: 反向传播版的Dropout

原文作者 | Mathor


这篇文章主要是对EMNLP2021上的论文Raise a Child in Large Language Model: Towards Effective and Generalizable Fine-tuning进行讲解。论文标题有些抽象,但是用作者的话来说,这篇论文的思想可以归结为两个词:Child Tuning

虽然这篇文章主要针对NLP任务以及NLP相关的模型,但实际上我看完之后觉得这是一个通用的方法,CV领域也可以使用。具体来说,目前预训练模型的参数非常大,在下游任务中,我们只能用有限的训练集对模型进行微调,有一种螳臂当车的感觉,因此作者提出了一种新的微调方法——Child Tuning。如果用一句话概述其思想那就是:在反向传播过程中,我们不用更新所有的参数,只更新某些参数即可,而这些被更新的参数所对应的网络结构,我们叫做Child Network(子网络)

如上图所示,上面一行是正常的反向传播过程,其中


下标0不是指某一个参数,而是指第0个迭代过程,\eta是学习率。对于下面一行来说,Δw _0有一部分被MASK掉了,导致这里面的梯度为0

其中,M矩阵内的元素非0即1,\odot是矩阵内的元素做对应位置相乘。我们可以用两步来概括Child Tuning的过程:

  1. 在预训练模型中发现并确认Child Network,并生成对应Weights的0-1 MASK
  2. 反向传播计算完梯度后,仅对Child Network中的参数进行更新

所以现在的问题是如何确认Child Network?

HOW TO FIND CHILD NETWORK?
实际上我们并不需要真的找到Child Network,只要确定矩阵M即可。论文提供了两种算法用于生成矩阵M,分别是任务无关算法Child_Tuning_F (F for Task-Free)以及与具体任务相关的算法Child_Tuning_D (D for Task-Drivern)

Child_Tuning_F
任务无关算法的意思是与你具体所做的具体任务没有关系,都可以使用这个算法,是一种通用的方法。具体来说,此时M是根据伯努利分布生成的

其中p_F\in [0,1]是一个超参数,他控制着Child Network的大小,如果p_F=1,则Child Network就是原网络,此时Child Tuning就是Fine Tuning;如果p_F=0,则没有任何参数会被更新。下面是我写的一个简单模拟的代码帮助大家理解

import torch
from torch.distributions.bernoulli import Bernoulli

gradient = torch.randn((3, 4)) # 这里用一个随机生成的矩阵来代表梯度
p_F = 0.2
gradient_mask = Bernoulli(gradient.new_full(size=gradien.size(), fill_value=p_F))
gradient_mask = gradient_mask.sample() / p_F # 除以p_F是为了保证梯度的期望不变
print(gradient_mask)

gradient *= gradient_mask
print(gradient)

Bernoulli是一个类,生成的gradient_mask是一个对象,我们需要调用这个对象的sample()方法才能得到一个矩阵。其中比较重要的一点是虽然我们得到了0-1 MASK,但我们需要将这个MASK内所有的1扩大1/p_F倍以维持梯度的期望值

别的梯度都不在了,活着的梯度要带着其他人的意志坚强的反向传播下去啊!

Child_Tuning_D

考虑到存在不同的下游任务,作者提出一种与具体任务相关的算法Child_Tuning_D,它可以检测出对目标任务最重要的子网络(或者参数)。具体来说,作者采用Fisher信息估计法来寻找与特定下游任务高度相关的参数。形式上,模型参数w的Fisher Information Matrix(FIM)定义如下:

其中,x,y分别是输入和输出,由此我们可以推出第i个参数的Fisher信息如下:

其中,|D|是所有样本的数量。作者认为,参数对目标任务越重要,其Fisher信息越大,因此Child Tuning是由Fisher信息最高的那些参数组成,此时Child Network的比例为

其中|\bar{\mathcal{C}}|表示非子网络,当p_D=1时,Child Tuning就退化为了Fine Tuning。实际上Fisher信息的计算是相当耗时的,如果我们每次反向传播后都去计算一次所有参数的Fisher信息,然后找出最大的前几个是很麻烦的,因此作者提出在真正开始训练之前,我们先对所有样本进行一次完整(一个Epoch)的前向传播和反向传播,此时计算出Fisher信息最高的那些参数,以及此时确定的Child Network以后就不再变化了,就以这一次所选定的为准

下面给出计算Fisher信息的代码

def calculate_fisher():
    gradient_mask, p_F = {}, 0.2
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
    N = len(train_dataloader) # N = |D|
    for name, params in model.named_parameters():
        if 'layer' in name:
            gradient_mask[params] = params.new_zeros(params.size())
    for batch in train_loader:
        outpus = model(**batch)
        loss = outpus['loss'] if isinstance(outpus, dict) else outputs[0]
        loss.backward()

        for name, params in model.named_parameters():
            if 'layer' in name:
                torch.nn.utils.clip_grad_norm(params, 1)
                gradient_mask[params] += (params.grad ** 2) / N
        model.zero_grad()
    
    r = None
    for k, v in gradient_mask.items():
        v = v.view(-1).cpu().numpy() # flatten
        if r is None:
            r = v
        else:
            r = np.append(r, v)
    
    # polar = np.percentile(a, q) # a中有q%的元素小于polar
    polar = np.percentile(r, (1-p_F)*100)
    for k in gradient_mask:
        gradient_mask[k] = gradient_mask[k] >= polar
    print('Polar => {}'.format(polar))

    return gradient_mask

PROOF
如果这篇论文就讲了这些东西,很大概率是中不了EMNLP的,之所以被录用了,我个人觉得和这篇论文里大量的证明有关,作者证明了使用Child Tuning可以帮助模型逃离局部极小值点,接下来我尝试着把论文中的证明部分说清楚

首先我们假设\mathbf{g}^{(i)}是给定样本x^{(i)}时参数w的梯度,并且它服从正态分布g (i) ∼N( \frac{∂w}{∂L} ,σ_g^2 I _{k} ),定义g=∑^{∣B∣}_{i=1}\frac{g^{(i)}}{∣B∣},则有

对于\mathbf{g},我们有

\hat{\mathbf{g}} = \frac{\mathbf{g}}{p}\odot M\,其中pp_D或p_F(看你用的哪种算法),则

上面的公式推导其实并不严格,例如分子的p是从哪来的就没法解释,分子的p只有可能是\mathbb{E}[M]的结果,可是M是个矩阵,矩阵的期望怎么就变成一个数了呢?但要强行解释也可以,因为将M中所有的1加起来除以M内的所有元素似乎也是等于p的设\hat{g_i}, g_i 分别是\hat{\mathbf{g}}, \mathbf{g} ,i维度上的值,那么有\hat{g_i} = \frac{g_i}{p}\odot M_i

因此


最终我们就得到


特别地,当参数w训练到局部极小值点时,\frac{∂w}{∂L} =0,此时E[Δw]=0,Σ[Δw]= \frac{η ^2σ _g^2 I _k}{p^{∣B∣}},我们注意到Σ[Δw]是关于p的一个递减函数,p越大,Σ[Δw]越小,极端情况是p=1,此时Child Tuning退化为Fine Tuning,并且Σ[Δw]最小,相当于它的变化量每次都不大,因此就很难跳出局部极小值点;p越小,Σ[Δw]越大,相当于它的变化量每次都很大,因此比较容易跳出局部极小值点

个人总结

这篇论文刚读的时候觉得很厉害,但实际上了解之后就觉得这其实就是一个反向传播版的Dropout,实际的创新并没有特别大,包括其中提到的Fisher信息也并不是这篇论文提出来的。再就是论文中的实验确实很多,实验结果表明,相比于Fine Tuning大约可以提升1.5~8.6个点不等。最后要说一下这篇论文的公式证明部分,我个人觉得这篇论文的证明其实没有很严谨,例如为什么一个矩阵的期望就变成一个数了。总的来说这个方法可以作为打比赛时候的一个Trick来使用。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 220,548评论 6 513
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 94,069评论 3 396
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 166,985评论 0 357
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 59,305评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 68,324评论 6 397
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 52,030评论 1 308
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,639评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,552评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 46,081评论 1 319
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 38,194评论 3 340
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,327评论 1 352
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 36,004评论 5 347
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,688评论 3 332
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,188评论 0 23
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,307评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,667评论 3 375
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,337评论 2 358

推荐阅读更多精彩内容