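This post was migrated from CSDN:
http://blog.csdn.net/mounty_fsc/article/details/51588773
In Solver::ApplyUpdate(), the network weights are updated using the partial derivatives of the loss with respect to the weights (computed during backpropagation) together with the configured learning policy. This completes the current training iteration.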
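1 Model optimization
1.1 Loss function
The loss function $L(W)$ is the empirical loss plus a regularization term, as shown below, where $X^{(i)}$ is an input sample; $f_W$ is the loss on a single sample; $N$ is the number of samples in the mini-batch; and $r(W)$ is the regularization term, weighted by $\lambda$.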
$L(W) \approx \frac{1}{N} \sum_{i=1}^{N} f_W\left(X^{(i)}\right) + \lambda r(W)$
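In Caffe, one training iteration can be divided into three stages:
- Forward pass: computes $f_W$
- Backward pass (backpropagation): computes $\nabla f_W$
- Weight update: computes $\Delta W$ from $\nabla f_W$, $\nabla r(W)$ and related quantities, then updates $W$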
1.2 Stochastic gradient descent
In LeNet, the solver type is SGD (stochastic gradient descent).
SGD updates the weights with the following formulas:
$W_{t+1} = W_t + V_{t+1}$
$V_{t+1} = \mu V_t - \alpha \nabla L(W_t)$
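Here $W_{t+1}$ is the weights at iteration $t+1$; $V_{t+1}$ is the update at iteration $t+1$ (also written $\Delta W_{t+1}$); $\mu$ is the momentum, i.e. the weight given to the previous update; $\alpha$ is the learning rate; and $\nabla L(W_t)$ is the gradient of the loss with respect to the weights.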
2 Code analysis
2.1 ApplyUpdate
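    template <typename Dtype>
    void SGDSolver<Dtype>::ApplyUpdate() {
      // Get the learning rate for this iteration
      Dtype rate = GetLearningRate();
      // Update the weights of every layer of the network.
      // In LeNet only four layers have parameters: `conv1`, `conv2`, `ip1`, `ip2`.
      // Each of them has two parameter blobs, the weights and the bias,
      // so the size of `learnable_params_` is 8.
      for (int param_id = 0; param_id < this->net_->learnable_params().size();
           ++param_id) {
        // Normalization; not needed when iter_size is 1, so LeNet does not need it
        Normalize(param_id);
        // Regularization
        Regularize(param_id);
        // Compute the update value \delta w
        ComputeUpdateValue(param_id, rate);
      }
      // Update the weights
      this->net_->Update();
    }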
Notes:
- The learning parameters for LeNet can be found in lenet_solver.prototxt:

      # The base learning rate, momentum and the weight decay of the network.
      base_lr: 0.01
      momentum: 0.9
      weight_decay: 0.0005
      # The learning rate policy
      lr_policy: "inv"
      gamma: 0.0001
      power: 0.75

- The code of the learning-rate function GetLearningRate is not reproduced here. Its comments (and caffe.proto) list the learning rate policies below. LeNet uses the inv policy, under which the learning rate changes at every iteration.

      // The learning rate decay policy. The currently implemented learning rate
      // policies are as follows:
      //    - fixed: always return base_lr.
      //    - step: return base_lr * gamma ^ (floor(iter / step))
      //    - exp: return base_lr * gamma ^ iter
      //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
      //    - multistep: similar to step but it allows non uniform steps defined by
      //      stepvalue
      //    - poly: the effective learning rate follows a polynomial decay, to be
      //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
      //    - sigmoid: the effective learning rate follows a sigmoid decay
      //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
      //
      // where base_lr, max_iter, gamma, step, stepvalue and power are defined
      // in the solver parameter protocol buffer, and iter is the current iteration.
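As a rough illustration of the inv policy, the sketch below evaluates the formula from the comment, base_lr * (1 + gamma * iter) ^ (- power). `InvLearningRate` is a hypothetical free function written for this post; it is not the actual GetLearningRate implementation.

```cpp
#include <cmath>

// Hypothetical helper mirroring the "inv" formula from the policy comment:
//   base_lr * (1 + gamma * iter) ^ (-power)
double InvLearningRate(double base_lr, double gamma, double power, int iter) {
  return base_lr * std::pow(1.0 + gamma * iter, -power);
}
```

With the LeNet settings (base_lr 0.01, gamma 0.0001, power 0.75) the rate starts at 0.01 at iteration 0 and decays smoothly; at iteration 10000 the base is (1 + 1) = 2, so the rate is 0.01 * 2^(-0.75), roughly 0.0059.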
2.2 Regularize
This function effectively applies the following formula:
$\nabla w_{ij} = \mathrm{decay} \cdot w_{ij} + \nabla w_{ij}$
The code is as follows:
    template <typename Dtype>
    void SGDSolver<Dtype>::Regularize(int param_id) {
      const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
      const vector<float>& net_params_weight_decay =
          this->net_->params_weight_decay();
      Dtype weight_decay = this->param_.weight_decay();
      string regularization_type = this->param_.regularization_type();
      // local_decay = 0.0005 in lenet
      Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
      ...
      if (regularization_type == "L2") {
        // axpy means ax_plus_y. i.e., y = a*x + y
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      }
      ...
    }
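2.3 ComputeUpdateValue
This function effectively applies the following formula:
$\nabla w_{ij} = \mathrm{lr\_rate} \cdot \nabla w_{ij} + \mathrm{momentum} \cdot w'_{ij}$
Here $w'$ is the update value from the previous iteration, which is kept in history_. Note that the result is written to cpu_diff, i.e. the buffer that held the gradient of the loss with respect to the parameter.
The code is as follows: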
    template <typename Dtype>
    void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
      const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
      const vector<float>& net_params_lr = this->net_->params_lr();
      // momentum = 0.9 in lenet
      Dtype momentum = this->param_.momentum();
      // local_rate = lr_mult * global_rate
      // lr_mult is the per-parameter learning rate multiplier,
      // set in lenet_train_test.prototxt
      Dtype local_rate = rate * net_params_lr[param_id];
      // Compute the update to history, then copy it to the parameter diff.
      ...
      // axpby means ax_plus_by. i.e., y = ax + by
      // Compute the new update value \delta w; the result is stored in history_
      caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
          net_params[param_id]->cpu_diff(), momentum,
          history_[param_id]->mutable_cpu_data());
      // Copy the update value \delta w from history_ into the parameter's diff
      caffe_copy(net_params[param_id]->count(),
          history_[param_id]->cpu_data(),
          net_params[param_id]->mutable_cpu_diff());
      ...
    }
2.4 net_->Update
This effectively applies the following formula:
$w_{ij} = w_{ij} + (-1) \cdot \nabla w_{ij}$