0003-keras自定义优化器

原文

keras优化器的代码

自定义一个SGD优化器

from keras.legacy import interfaces
from keras.optimizers import Optimizer
from keras import backend as K
class SGD(Optimizer):
    def __init__(self,lr=0.01,**kwargs):
        super(SGD,self).__init__(**kwargs)
        with K.name_scope(self.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            
    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params) #获取梯度
        self.updates = [K.update_add(self.iterations, 1)] # 定义赋值算子集合
        self.weights = [self.iterations] #优化器带来的权重,在保存模型时会被保存
        for p, g in zip(params, grads):
            new_p =p -self.lr*g
            #如果有约束,对参数加上约束
            if getattr(p, 'constraint', None) is not None: 
                new_p = p.constraint(new_p)
            #添加赋值
            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        config = {'lr': float(K.get_value(self.lr))}
        base_config = super(SGD,self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
 

实现“软batch”

假如模型比较庞大,自己的显卡最多也就能跑 batch size=16,但又想起到 batch size=64 的效果,那可以怎么办呢?
每次算 batch size=16,然后把梯度缓存起来,4 个 batch 后才更新参数。也就是说,每个小batch 都算梯度,但每 4 个 batch 才更新一次参数。

class MySGD(Optimizer):
    """
    Keras中简单自定义SGD优化器每隔一定的batch才更新一次参数
    """
    def __init__(self, lr=0.01, steps_per_update=1, **kwargs):
        super(MySGD, self).__init__(**kwargs)
        
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.steps_per_update = steps_per_update #多少batch才更新一次
            
    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        """
        主要的参数更新算法
        """
        shapes = [K.int_shape(p) for p in params]
        sum_grads = [K.zeros(shape) for shape in shapes] # 平均梯度,用来梯度下降
        grads = self.get_gradients(loss, params) # 当前batch梯度
        self.updates = [K.update_add(self.iterations, 1)] # 定义赋值算子集合
        self.weights = [self.iterations] + sum_grads # 优化器带来的权重,在保存模型时会被保存
        for p, g, sg in zip(params,grads,sum_grads):
            #梯度下降
            new_p = p - self.lr * sg / float(self.steps_per_update)
            
            #如果有约束,对参数加上约束
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)
            cond = K.equal(self.iterations % self.steps_per_update, 0)
            
            #满足条件才更新参数
            self.updates.append(K.switch(cond, K.update(p, new_p), p))
            
            #满足条件就要重新累积,不满足条件直接累积
            self.updates.append(K.switch(cond, K.update(sg, g), K.update(sg, sg+g)))
            
        return self.updates
    def get_config(self):
        config = {'lr': float(K.get_value(self.lr)),'steps_per_update': self.steps_per_update}
        base_config = super(MySGD, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

“侵入式”优化器

image.png

其中 ,p 是参数向量,g 是梯度, 表示 p 的第 i 次迭代时的结果。
这个算法需要走两步,大概意思就是普通的梯度下降先走一步(探路),然后根据探路的结果取平均,得到更精准的步伐,等价地可以改写为:
image.png

但是实现这类算法却有个难题,要计算两次梯度,一次对参数 ,另一次对参数。而前面的优化器定义中 get_updates 这个方法却只能执行一步(对应到 tf 框架中,就是执行一步 sess.run,熟悉 tf 的朋友知道单单执行一步 sess.run 很难实现这个需求),因此实现不了这种算法。

class HeunOptimizer:
    """
    自定义Keras的侵入式优化器
    """

    def __init__(self, lr):
        self.lr = lr

    def __call__(self, model):
        """
        需要传入模型,直接修改模型的训练函数,而不按常规流程使用优化器,所以称为“侵入式”
        下面的大部分代码,都是直接抄自keras的源码:
        https://github.com/keras-team/keras/blob/master/keras/engine/training.py#L491
        也就是keras中的_make_train_function函数
        """
        params = model._collected_trainable_weights
        loss = model.total_loss

        inputs = (model._feed_inputs + model._feed_targets + model._feed_sample_weights)
        inputs += [K.learning_phase()]

        with K.name_scope('training'):
            
            with K.name_scope('heun_optimizer'):
                
                old_grads = [[K.zeros(K.int_shape(p)) for p in params]]
                update_functions = []
                
                for i,step in enumerate([self.step1, self.step2]):
                    updates = (model.updates + step(loss, params, old_grads) + model.metrics_updates)
                    #给每一步定义一个K.function
                    updates = K.function(inputs,[model.total_loss]+model.metrics_tensors,updates=updates,name='train_function_%s'%i,**model._function_kwargs)
                    update_functions.append(updates)

                def F(ins):
                    # 将多个K.function封装为一个单独的函数
                    # 一个K.function就是一次sess.run
                    for f in update_functions:
                        _ = f(ins)
                    return _

                # 最后只需要将model的train_function属性改为对应的函数
                model.train_function = F

    def step1(self, loss, params, old_grads):
        ops = []
        grads = K.gradients(loss, params)
        for p,g,og in zip(params, grads, old_grads[0]):
            ops.append(K.update(og, g))
            ops.append(K.update(p, p - self.lr * g))
        return ops

    def step2(self, loss, params, old_grads):
        ops = []
        grads = K.gradients(loss, params)
        for p,g,og in zip(params, grads, old_grads[0]):
            ops.append(K.update(p, p - 0.5 * self.lr * (g - og)))
        return ops
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,711评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,079评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,194评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,089评论 1 286
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,197评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,306评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,338评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,119评论 0 269
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,541评论 1 306
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,846评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,014评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,694评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,322评论 3 318
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,026评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,257评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,863评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,895评论 2 351

推荐阅读更多精彩内容