研究线性模型训练中损失变化的规律和最优学习率的影响

探究一维线性模型训练中,测试损失随训练步数变化的缩放定律及其最优学习率影响,并研究多维线性模型训练的缩放定律,确定参数以符合特定损失衰减模式。

研究大模型的缩放定律对减少其训练开销至关重要,即最终的测试损失如何随着训练步数和模型大小的变化而变化?本题中,我们研究了训练线性模型时的缩放定律。

  1. 在本小问中,考虑使用梯度下降学习一个一维线性模型的情况。
  • 定义数据分布\mathcal{D}为一个\mathbb{R}^2上的分布,每个数据是一个数对(x, y),分别代表输入和输出,并服从分布x\sim N(0, 1),y\sim N(3x, 1)

  • 用梯度下降算法学习线性模型f_{w}(x)=w \cdot x,其中w, x\in\mathbb{R}。初始化ω_0=0并进行多步迭代。每次迭代时,从\mathcal{D}中采样(x_t,y_t),然后更新w_tw_{t+1}\leftarrow w_t-\eta\nabla l_t(w_t),其中l_t(w)=\frac{1}{2}(f_w(x_t)-y_t)^2是平方损失函数,\eta>0是学习率。

设学习率\eta\in(0,\frac{1}{3}],那么T≥0步迭代之后的测试损失的期望

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{w_T}\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(f_{w_T}(x)-y)^2]

是多少?

  1. 现在我们在第一小问的设定下,考虑学习率\eta被调到最优的情况,求函数g(T),使得当T\rightarrow+\infty时,以下条件成立:

\left|\underset{η\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O(\frac{(\log T)^2}{T^2})

  1. 一个常常被观测到的实验现象是大语言模型的预训练过程大致遵循Chinchilla缩放定律:

\overline{\mathcal{L}}_{N,T}≈\frac{A}{N^\alpha}+\frac{B}{T^\beta}+C

其中\overline{\mathcal{L}}_{N,T}是在经过T步训练后具有N个参数的模型的测试损失的期望,ABaβC是常数。现在我们举一个训练多维线性模型的例子,使其也遵循类似的缩放定律。

  • 固定a>0,b≥1,每个数据(x_{\cdot},y)由一个输入和输出组成,其中输入x_{\cdot}是一个无限维向量(可看作一个序列),输出y满足y\in\mathbb{R}。定义数据分布\mathcal{D}如下。首先,从Zipf分布中采样k\Pr[k=i]\propto i^{-(a+1)}\quad(i\geq 1)。令j:=[k^b],然后,从mathcal{N}(0,1)中采样得到x_{\cdot}的第j个坐标x_j,并令其余坐标为0。最后,y\sim N(3x_j,1)。这样得到的(x_{\cdot},y)的分即数据分布\mathcal{D}

  • 我们研究一个仅关注前N个输入坐标的线性模型。定义函数\phi_N(xx_{\cdot})=(x_1,...,x_N)。我们研究的线性模型具有参数\mathbf{w}\in\mathbb{R}^N,输出为f_{\mathbf{w}}(x)=(\mathbf{w},\phi_N(x_{\cdot}))

  • 我们使用梯度下降算法学习该线性模型。初始化\mathbf{w}_0=0并进行多步迭代。每次迭代时,从\mathcal{D}中采样(x_{t,\cdot},y_t),然后更新\mathbf{w}_t\mathbf{w}_{t+1}\gets \mathbf{w}_t-\eta\nabla l_t(\mathbf{w}_t),其中l_t(\mathbf{w})=\frac{1}{2}(f_\mathbf{w}(x_{t,\cdot})-y_t)^2

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{\mathbf{w}_T}\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(f_{\mathbf{w}_T}(x)-y)^2]为以学习率\eta\in(0,\frac{1}{3}]对其有N个参数的线性模型进行T≥0步训练后的测试损失的期望。

请求出αβC,使得\forall\gamma>0,\forall c>0,当T=N^{c+o(1)}N足够大时,以下条件成立:

\epsilon(N,T):=\frac{\inf_{\eta\in(0,\frac{1}{3}]}{\overline{\mathcal{L}}_{N,T}}-C}{\frac{A}{N^\alpha}+\frac{B}{T^\beta}}

(\log N+\log T)^{-γ}\leq \epsilon(N,T)\leq(\log N+\log T)^γ。即\inf_{\eta\in(0,\frac{1}{3}]}{\overline{\mathcal{L}}_{N,T}}=\tilde{\Theta}(N^{-\alpha}+T^{-\beta})+C,其中\tilde{\Theta}表示忽略任何关于\log N\log T的多项式。

解:

  1. 首先,我们来计算测试损失的期望\overline{\mathcal{L}}_{\eta,T}

由于xy是独立的随机变量,且y的条件分布是N(3x, 1),我们可以写出测试损失的期望为:

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{(x,y)\sim D}[\frac{1}{2}(w_T x - y)^2]

由于y=3x+\epsilon,其中\epsilon\sim N(0, 1)且独立于x,我们可以将y替换为3x+\epsilon

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_{x,\epsilon}[\frac{1}{2}(w_T x - (3x+\epsilon))^2]

展开并利用\mathbb{E}[\epsilon^2]=1\mathbb{E}[x^2]=1(因为x\sim N(0, 1)):

\overline{\mathcal{L}}_{\eta,T}=\mathbb{E}_x[\frac{1}{2}(w_T^2 x^2 - 6w_T x^2 + 9x^2 + \epsilon^2 - 6w_T x \epsilon + 3w_T^2 x^2)]

由于\epsilonx是独立的,我们可以分别计算期望:

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(w_T^2 - 6w_T + 9)\mathbb{E}[x^2] + \frac{1}{2}\mathbb{E}[\epsilon^2]

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(w_T^2 - 6w_T + 9) + \frac{1}{2}

现在我们需要计算w_T的期望值。由于w_t的更新规则是w_{t+1}=w_t-\eta\nabla l_t(w_t),我们有:

\nabla l_t(w_t) = w_t x_t - y_t = w_t x_t - (3x_t + \epsilon)

因此,更新规则变为:

w_{t+1} = w_t - \eta(w_t x_t - 3x_t - \epsilon)

取期望并利用\mathbb{E}[x_t]=0\mathbb{E}[\epsilon]=0

\mathbb{E}[w_{t+1}] = \mathbb{E}[w_t] - \eta(3\mathbb{E}[x_t^2])

由于x_t^2的期望是1,我们有:

\mathbb{E}[w_{t+1}] = \mathbb{E}[w_t] - 3\eta

由于w_0=0,我们可以递归地计算w_T

\mathbb{E}[w_T] = -3\eta T

\mathbb{E}[w_T]代入测试损失的期望中:

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}((-3\eta T)^2 - 6(-3\eta T) + 9) + \frac{1}{2}

\overline{\mathcal{L}}_{\eta,T}=\frac{1}{2}(9\eta^2 T^2 + 18\eta T + 9) + \frac{1}{2}

\overline{\mathcal{L}}_{\eta,T}=\frac{9\eta^2 T^2 + 18\eta T + 10}{2}

  1. 接下来,我们需要找到g(T)

首先,我们需要最小化\overline{\mathcal{L}}_{\eta,T}关于\eta。我们可以通过设置\frac{d\overline{\mathcal{L}}_{\eta,T}}{d\eta}=0来找到最优的学习率\eta^*

\frac{d}{d\eta}(\frac{9\eta^2 T^2 + 18\eta T + 10}{2})=9\eta T^2 + 18T=0

解得:

\eta^* = \frac{2}{3T}

\eta^*代入\overline{\mathcal{L}}_{\eta,T}中,我们得到最小化测试损失的表达式:

\overline{\mathcal{L}}_{\eta^*,T}=\frac{9(\frac{2}{3T})^2 T^2 + 18(\frac{2}{3T}) T + 10}{2}

\overline{\mathcal{L}}_{\eta^*,T}=\frac{9(\frac{4}{9T^2}) T^2 + 18(\frac{2}{3T}) T + 10}{2}

\overline{\mathcal{L}}_{\eta^*,T}=\frac{4 + 12 + 10}{2}

\overline{\mathcal{L}}_{\eta^*,T}=\frac{26}{2}

\overline{\mathcal{L}}_{\eta^*,T}=13

现在,我们需要找到g(T),使得当T\rightarrow+\infty时,以下条件成立:

\left|\underset{\eta\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O\left(\frac{(\log T)^2}{T^2}\right)

由于我们已经找到了最优的学习率\eta^*,我们可以将\overline{\mathcal{L}}_{\eta^*,T}视为\mathcal{I}_{n,T}的下界。因此,我们需要找到一个函数g(T),使得当T趋向于无穷大时,\overline{\mathcal{L}}_{\eta^*,T}g(T)之间的差异满足上述条件。

考虑到\overline{\mathcal{L}}_{\eta^*,T}是一个常数13,我们可以推断g(T)应该也是一个常数,因为测试损失的期望在最优学习率下不随T变化。因此,我们可以选择g(T)=13

现在,我们需要验证这个选择是否满足条件:

\left|\underset{\eta\in(0,\frac{1}{3}]}{\inf}\mathcal{I}_{n,T}-g(T)\right|=O\left(\frac{(\log T)^2}{T^2}\right)

由于\mathcal{I}_{n,T}的最小值是13,我们有:

\left|13-13\right|=0

显然,0=O\left(\frac{(\log T)^2}{T^2}\right),因为当T趋向于无穷大时,\frac{(\log T)^2}{T^2}趋向于0。因此,我们的选择g(T)=13是正确的。

综上所述,g(T)=13满足题目中的条件。

3.为了解决这个问题,我们需要推导出多维线性模型在给定数据分布下的缩放定律。根据题目描述,我们有一个线性模型,其参数遵循特定的缩放定律。我们将通过以下步骤来解决这个问题:

步骤 1: 理解数据分布

数据分布 \mathcal{D} 是通过 Zipf 分布来选择输入向量的非零坐标,然后根据该坐标的值来生成输出 y。这意味着大部分的数据集中在较少的非零坐标上。

步骤 2: 定义损失函数

损失函数 \overline{\mathcal{L}}_{\eta,T} 是在给定学习率 \eta 和训练步数 T 后,模型参数 \mathbf{w} 的测试损失的期望。

步骤 3: 推导缩放定律

我们需要找到 \alpha\beta,和 C 使得损失函数符合 \overline{\mathcal{L}}_{N,T}≈\frac{A}{N^\alpha}+\frac{B}{T^\beta}+C 的形式。

对于 \alpha 的推导:

  • 参数 N 表示模型考虑的输入向量的维度。由于数据分布的特性,大部分的权重不会接收到有效的梯度更新,因为它们对应的输入坐标为零。因此,增加 N 的数量不会显著改善模型的性能,但也不会损害它,因为只有少数权重会被更新。

  • Zipf 分布的特性意味着非零坐标的数量随着 N 的增加而减少。因此,我们可以预期 \alpha 大于 0,但小于 1,因为增加维度对于模型性能的提升是有上限的。

对于 \beta 的推导:

  • 参数 T 表示训练步数。随着训练步数的增加,模型将获得更多的机会来更新其权重,从而减少损失。因此,我们可以预期 \beta 大于 0。

  • 由于数据分布的特性,并不是每一步都会对所有权重进行有效更新。因此,\beta 可能不会是 1,而是小于 1 的某个值。

对于 C 的推导:

  • 常数 C 表示当 NT 趋于无穷大时,测试损失的最低值。这是由于数据本身的噪声和模型的能力限制导致的。

步骤 4: 确定 \alpha\beta,和 C

为了确定 \alpha\beta,和 C,我们需要进行以下分析:

  • 对于 \alpha:考虑到只有少数权重会被更新,我们可以假设 \alpha 在 0 和 1 之间。更具体地,由于 Zipf 分布的特性,我们可以假设 \alpha 接近于 1,但小于 1,因为随着 N 的增加,额外维度的边际贡献会减少。一个合理的猜测是 \alpha = \frac{1}{b}

  • 对于 \beta:考虑到每一步并不是对所有权重都进行有效更新,我们可以假设 \beta 小于 1。一个合理的猜测是 \beta = \frac{1}{2},这是因为通常情况下,梯度下降的收敛速度与步数的平方根成反比。

  • 对于 C:这是数据噪声和模型表达能力限制的结果。在没有更多信息的情况下,我们无法精确确定 C,但可以假设它是一个正数。

步骤 5: 验证条件

我们需要验证 \epsilon(N,T) 的条件是否成立。这通常涉及到对 \overline{\mathcal{L}}_{N,T} 进行详细的分析,并证明它符合给定的缩放形式。这通常需要数学上的证明和/或实验验证。

综上所述,我们可以假设 \alpha = \frac{1}{b}\beta = \frac{1}{2}C 是一个正数。然而,为了得到精确的值,我们需要更深入的分析和实验数据。在实际应用中,这些参数通常是通过实验来确定的。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,470评论 6 501
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,393评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,577评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,176评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,189评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,155评论 1 299
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,041评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,903评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,319评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,539评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,703评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,417评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,013评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,664评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,818评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,711评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,601评论 2 353

推荐阅读更多精彩内容