0002-Keras custom layers + loading custom layers from a saved model

Define the custom Position_Embedding layer

import keras
from keras import backend as K
import tensorflow as tf

class Position_Embedding(keras.layers.Layer):    
    def __init__(self, size=None, mode='sum', **kwargs):
        self.size = size 
        self.mode = mode
        super(Position_Embedding, self).__init__(**kwargs)        
    def call(self, x):
        if (self.size is None) or (self.mode == 'sum'):
            self.size = int(x.shape[-1])
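        # position_j holds the inverse frequencies 1/10000^(2j/size) of the sinusoidal encoding;
        # position_i holds the index of every timestep, built with a cumulative sum so the layer
        # also works on inputs with a dynamic (unknown) sequence length.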
        position_j = 1. / K.pow(10000., 2 * K.arange(self.size / 2, dtype='float32') / self.size)
        position_j = K.expand_dims(position_j, 0)
        
        position_i = K.cumsum(K.ones_like(x[:,:,0]), 1)-1 
        position_i = K.expand_dims(position_i, 2)
        position_ij = K.dot(position_i, position_j)
        position_ij = K.concatenate([K.cos(position_ij), K.sin(position_ij)], 2)
        if self.mode == 'sum':
            return position_ij + x
        elif self.mode == 'concat':
            return K.concatenate([position_ij, x], 2)        
    def compute_output_shape(self, input_shape):
        if self.mode == 'sum':
            return input_shape
        elif self.mode == 'concat':
            return (input_shape[0], input_shape[1], input_shape[2]+self.size)
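
Note that, unlike the Attention layer below, Position_Embedding does not override get_config. With the default arguments (size=None, mode='sum') the layer still reloads fine, because from_config simply falls back to those defaults, but if you construct it with non-default arguments they are lost across save/load. A minimal sketch of the missing method (an addition, not part of the original code):

    def get_config(self):
        # Sketch: record the constructor arguments so load_model can rebuild the layer
        # with the same size/mode, mirroring Attention.get_config below.
        config = {"size": self.size, "mode": self.mode}
        base_config = super(Position_Embedding, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))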

Define the Attention layer

class Attention(keras.layers.Layer):
    def __init__(self, nb_head, size_per_head, **kwargs):
        self.nb_head = nb_head
        self.size_per_head = size_per_head
        self.output_dim = nb_head*size_per_head
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        self.WQ = self.add_weight(name='WQ',
                                  shape=(int(input_shape[0][-1]), self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.WK = self.add_weight(name='WK',
                                  shape=(int(input_shape[1][-1]), self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        self.WV = self.add_weight(name='WV',
                                  shape=(int(input_shape[2][-1]), self.output_dim),
                                  initializer='glorot_uniform',
                                  trainable=True)
        super(Attention, self).build(input_shape)
  
    def Mask(self, inputs, seq_len, mode='mul'):
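        # seq_len[:, 0] holds the real length of every sequence in the batch.
        # one_hot followed by cumsum yields a mask that is 1 on valid positions and 0 on padding;
        # mode='mul' zeroes the padded positions, mode='add' pushes them towards -inf before a softmax.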
        if seq_len is None:
            return inputs
        else:
            mask = K.one_hot(seq_len[:,0], K.shape(inputs)[1])
            mask = 1 - K.cumsum(mask, 1)
            for _ in range(len(inputs.shape)-2):
                mask = K.expand_dims(mask, 2)
            if mode == 'mul':
                return inputs * mask
            if mode == 'add':
                return inputs - (1 - mask) * 1e12
                
    def call(self, x):
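        # x is either [Q, K, V] or [Q, K, V, Q_len, V_len]; the optional *_len tensors
        # carry the true sequence lengths used to mask padded positions.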
        
        if len(x) == 3:
            Q_seq,K_seq,V_seq = x
            Q_len,V_len = None,None
        elif len(x) == 5:
            Q_seq,K_seq,V_seq,Q_len,V_len = x
            
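        # Project Q/K/V with the learned matrices, split the last dimension into
        # (nb_head, size_per_head) and move the head axis forward so attention runs per head.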
        Q_seq = K.dot(Q_seq, self.WQ)
        Q_seq = K.reshape(Q_seq, (-1, K.shape(Q_seq)[1], self.nb_head, self.size_per_head))
        Q_seq = K.permute_dimensions(Q_seq, (0,2,1,3))
        
        K_seq = K.dot(K_seq, self.WK)
        K_seq = K.reshape(K_seq, (-1, K.shape(K_seq)[1], self.nb_head, self.size_per_head))
        K_seq = K.permute_dimensions(K_seq, (0,2,1,3))
        
        V_seq = K.dot(V_seq, self.WV)
        V_seq = K.reshape(V_seq, (-1, K.shape(V_seq)[1], self.nb_head, self.size_per_head))
        V_seq = K.permute_dimensions(V_seq, (0,2,1,3))
        
        A = K.batch_dot(Q_seq, K_seq, axes=[3,3]) / self.size_per_head**0.5
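        # Move the key/value-length axis to position 1 so the padding mask built from V_len
        # can be applied along it, then permute back before the softmax.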
        A = K.permute_dimensions(A, (0,3,2,1))
        
        A = self.Mask(A, V_len, 'add')
        A = K.permute_dimensions(A, (0,3,2,1))    
        A = K.softmax(A)
        
        
        O_seq = K.batch_dot(A, V_seq, axes=[3,2])
        O_seq = K.permute_dimensions(O_seq, (0,2,1,3))
        O_seq = K.reshape(O_seq, (-1, K.shape(O_seq)[1], self.output_dim))
        O_seq = self.Mask(O_seq, Q_len, 'mul')
        
        return O_seq
        
    def compute_output_shape(self, input_shape):
        return (input_shape[0][0], input_shape[0][1], self.output_dim)     
    
    def get_config(self):        
        config = {"nb_head":self.nb_head,"size_per_head":self.size_per_head}
        base_config = super(Attention,self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
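
For context, here is a minimal usage sketch (not from the original article; the vocabulary size, shapes and file name are placeholders): a toy classifier built from both custom layers and saved to HDF5, which is the kind of file the loading step below reads back.

from keras import models, layers

# Hypothetical toy model: token ids -> embedding -> positional encoding -> self-attention -> classifier
inputs = layers.Input(shape=(None,), dtype='int32')
emb = layers.Embedding(20000, 128)(inputs)
emb = Position_Embedding()(emb)             # mode='sum' keeps the feature dimension at 128
att = Attention(8, 16)([emb, emb, emb])     # self-attention: Q = K = V, output dim 8*16 = 128
pooled = layers.GlobalAveragePooling1D()(att)
outputs = layers.Dense(1, activation='sigmoid')(pooled)

model = models.Model(inputs, outputs)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.save("model.h5")                      # placeholder path; adjust to your environment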

The following shows how the custom layers are loaded back after the model has been saved

from keras import models

model = models.load_model("/data/user//model.h5",
                          custom_objects={'Position_Embedding': Position_Embedding,
                                          "Attention": Attention})
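
The keys in custom_objects must match the class names recorded in the HDF5 config; without them load_model fails with an "Unknown layer" error. An equivalent form, if you prefer a scope (a sketch, assuming the same classes are importable in the loading script), is:

from keras.utils import CustomObjectScope

with CustomObjectScope({'Position_Embedding': Position_Embedding, 'Attention': Attention}):
    model = models.load_model("/data/user//model.h5")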