行为序列建模:MIMN系列1——原理初探和源码解析

关键词行为序列建模MIMNRNN神经图灵机Attention

内容摘要

  • MIMN原理整体提要解析
  • MIMN源码速览
  • MIMN中参数维护方式总结
  • 在风控场景下,MIMN的训练,部署代码实战

本文主要是MIMN原理迅速扫描,实战部分见行为序列建模:MIMN系列2——消费Kafka实时预测代码实战


研究背景

本文受到字节跳动技术团队的一片博客《行为序列模型在抖音风控中的应用》的启发,在长序列建模中引入MIMN算法(Multi-channel user Interest Memory Network),进一步研究了阿里妈妈MIMN的论文和源码,将该算法成功部署到了风控业务系统,使得模型可以接受任意长度的历史序列对实体进行风险预测,同时引入外部存储记录在此之前所有的记忆状态,当有新的序列元素进入时,读写记录实时预测,简单而言相比于原始的通过滑窗限制序列长度的LSTM算法,MIMN具有两大优势:

  • 历史长序列建模:输入给模型的用户行为序列越长,理论上模型的效果越好,然而传统的RNN对历史长序列表征能力有限,而MIMN将历史信息的表征和Y值解耦,可以根据序列本身记录纯粹的历史所有记忆信息。
  • 实时增量预测:改变了部署方式,传统的RNN在实时预测时面临推理延迟和存储占用大的问题,MIMN采用外部存储记录最新的存量记忆状态,增量部分新来一个行为对接一个MIMN单元,读写修改维护状态,大大降低了在线部署实时推理的延迟和记忆存储的空间占用。

MIMN原理迅速概括

MIMN论文涉及好几个独立的知识点,作者的创新是将这些技术串起来解决了一个实际的问题,其中设计的子模块包括NTM神经图灵机MIU记忆感知单元DIN注意力网络三个知识点,本文对于这三块不做展开,只在整体层面介绍下几大模块的最用,以及内部参数的更新维护方式,原论文地址

(1)模型输入输出介绍

下面先从模型的输入开始了解MIMN,左侧橙色是增存量序列构建记忆的过程,右侧是在线部署时的预测部分。


模型架构

对于增存量记忆构建部分,输入是历史所有序列元素,每个元素包括物品id和物品的其他上下文信息拼接的结果,序列元素输入的目的是维护了一个M矩阵S矩阵,对于每一个用户都有它对应的M和S矩阵,每来一个新的序列元素,都会对M和S进行更新

  • M矩阵:负责对用户原始历史行为序列信息的表征,它通过NTM神经图灵机实现,通过读头和写头对NTM的结果进行更新
  • S矩阵:负责从M矩阵中提取高阶信息,配合目标物品进行DIN Attention从记忆中提取对目标有益的信息,弥补M矩阵的不足,它通过MIU模块实现

对于在线部署部分,输入是目标物品(Target Ad)历史记忆的读输出(Read Head)记忆感知模块和目标物品的Attention输出,以及其他上下文信息(Context Feas),四大输入拼接之后两层全连接在softmax得到0-1的输出,预测用户是否对目标物品有行为交互。

对于序列元素,是由历史到现在所有商品/广告形成的序列,细分的话有三种,一种是历史商品,一种是最后一个商品(或者是当前最新的一次行为商品),一种是目标商品(通过召回得到的候选商品),三个作用如下

  • 历史商品:用于刷存量构成S和M矩阵
  • 最后一个商品:用于在线部署阶段,触发UIC更新用户的S和M矩阵
  • 目标商品:用于在线部署阶段,调用UIC的M矩阵拿到读头输出,以及调用S矩阵进行Attention,从而输入全连接进行ctr预测

搞清楚三种元素的区别基本MIMN大体上吃透一半了。

(2)模型部署介绍

模型部署也分为增存量记忆维护,和线上预测两个部分

模型部署

虚线下面是增存量记忆维护,增量和存量的行为序列产出UIC Server的M矩阵和S矩阵以及其他记忆信息,没来一个新的序列元素就更新UIC的内容,不需要全部从头开始重新计算记忆信息。虚线上的在线预测部分,简单而言就是根据目标物品信息,用户静态信息,再去UIC中拿到无延迟的记忆信息,预测得到用户对目标响应概率。这两个流程是完全解耦的,相当于UIC对实时预测部分是无延迟的,不再像传统RNN那样维护历史序列id,而是维护一个历史到现在为止的记忆矩阵在外部存储即可。


MIMN源码速览

下面进一步了解MIMN都从源码开始,源码地址,源码比较复杂涉及一些其他算法,挑一些重点记录一下。

(1)主模型框架类

模型的主类是Model_MIMN

class Model_MIMN(Model):
    def __init__(self, n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE, BATCH_SIZE, MEMORY_SIZE, SEQ_LEN=400, Mem_Induction=0,
                 Util_Reg=0, use_negsample=False, mask_flag=False):
        super(Model_MIMN, self).__init__(n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE,
                                         BATCH_SIZE, SEQ_LEN, use_negsample, Flag="MIMN")
        self.reg = Util_Reg
...

该类继承Model类,Model类主要包含输入序列id的embedding映射过程和最后的全连接过程,NTM,MIU,DIN Attention全部在子类Model_MIMN中。

class Model(object):
    def __init__(self, n_uid, n_mid, EMBEDDING_DIM, HIDDEN_SIZE, BATCH_SIZE, SEQ_LEN, use_negsample=False, Flag="DNN"):
        self.model_flag = Flag
        self.reg = False
        self.use_negsample = use_negsample
        with tf.name_scope('Inputs'):
        ...
        # Embedding layer
        with tf.name_scope('Embedding_layer'):
        ...
    # 基于之前网络的输出构造最后的全连接层
    def build_fcn_net(self, inp, use_dice=False):
        bn1 = tf.layers.batch_normalization(inputs=inp, name='bn1')
        ...

从功能上来说Model_MIMN的目的就是构造出最后一层全连接的输入inp,inp输入到全连接层,全连接包含batchNorm和两层全连接,和上图灰色的在线预测部分内容一致。

(2)MIMN单元

这是整个代码的核心,先看MIMN单元的实例化

cell = mimn.MIMNCell(controller_units=HIDDEN_SIZE, memory_size=MEMORY_SIZE, memory_vector_dim=2 * EMBEDDING_DIM,
                             read_head_num=1, write_head_num=1,
                             reuse=False, output_dim=HIDDEN_SIZE, clip_value=20, batch_size=BATCH_SIZE,
                             mem_induction=Mem_Induction, util_reg=Util_Reg)

在Model_MIMN中实例化了一个MIMN单元,而每一个序列的输入都会进这个MIMN单元,全局共享这个MIMN单元的模型参数,比如控制器和MIU中的GRU部分。在实例化MIMN单元的时候,这一段代码初始化了S矩阵

        if self.mem_induction > 0:
            self.channel_rnn = single_cell(self.memory_vector_dim)
            # TODO channel_rnn_state是S矩阵 [[256, 32], [256, 32], [256, 32], [256, 32]]
            self.channel_rnn_state = [self.channel_rnn.zero_state(batch_size, tf.float32) for i in range(memory_size)]
            self.channel_rnn_output = [tf.zeros(((batch_size, self.memory_vector_dim))) for i in range(memory_size)]

S矩阵为全0初始化,维度是[memory_size, batch_size, memory_dim],memory_size是记忆矩阵的高,memory_dim是记忆矩阵的宽,每个输入进来的样本都会有有一个自己的S矩阵。

下面初始化M矩阵的状态,当模型才开始训练和用户处于冷启动的时候,状态需要初始化,M矩阵比S矩阵复杂,会多一些相关的变量

state = cell.zero_state(BATCH_SIZE, tf.float32)

注意zero_state将BATCH_SIZE传进去,说明初始化和输入训练的用户数量有关,实际是每个用户都分配了一个初始化状态。举个例子看M矩阵的初始化

M = expand(
                tf.tanh(tf.get_variable('init_M', [self.memory_size, self.memory_vector_dim],
                                        initializer=tf.random_normal_initializer(mean=0.0, stddev=1e-5),
                                        trainable=False)),
                dim=0, N=batch_size)
def expand(x, dim, N):
    return tf.concat([tf.expand_dims(x, dim) for _ in range(N)], axis=dim)

对于每一个输入的用户,给他一个均值是0标准差是1e-5的随机(4,32)的初始化,然后复制batch_size(比如256)的份数,拼接成(256,4,32)的该batch下的init_M矩阵。由此可见虽然每个用户都给到一个单独的初始化M,但是他们初始化的结果是一模一样的,注意该变量trainable=False,不随着损失函数优化迭代。同理创建controller_state
,read_vector,w_list,M,key_M,w_aggre其他NTM需维护的变量,其中w_list包含了读头和写头。

(3)历史序列刷存量构建M和S矩阵

在MIMN单元实例化和MIMN的state初始化后,作者开始将历史200长度的序列灌入MIMN单元,代码如下

        for t in range(SEQ_LEN):
            output, state, temp_output_list = cell(self.item_his_eb[:, t, :], state)
            if mask_flag:
                # TODO mask的作用是修正状态,排除prepare阶段由于padding导致的state变动
                state = clear_mask_state(state, begin_state, begin_channel_rnn_output, self.mask, cell, t)
            # 记录下每个序列元素输出的output和status
            self.mimn_o.append(output)
            self.state_list.append(state)

代码里面通过item_his_eb[:, t, :]切片拿到了对应步长的序列元素,和当前的state一起输入MIMN单元,第一个元素对应的state是cell.zero_state得到的状态,后面的都是在循环中更新最新的state给下一个序列元素使用。注意这个for循环构造了一张tensorflow长图,及从第一个MIMN走到最后一个MIMN的路径,每一个样本,每一个批次进来的时候,都要经过这条路径,互不干扰,代码里面的self.state_list可以打印出来看一下,每一个样本的第一次state都是0初始化,不会存在参数继承的情况。
clear_mask_state函数是避免左边padding为0给state带来影响,代码如下

        def clear_mask_state(state, begin_state, begin_channel_rnn_state, mask, cell, t):
            # TODO mask[:, t] = [256, 1] => [256, 1]
            # TODO 如果mask是0相当于将controller_state重新置为begin_state,全0初始化,否则保持原样不变
            state["controller_state"] = (1 - tf.reshape(mask[:, t], (batch_size, 1))) * begin_state[
                "controller_state"] + tf.reshape(mask[:, t], (batch_size, 1)) * state["controller_state"]
            ...

以controller_state的计算为例,如果mask是0(代表padding了0),则左式保留controller_state打回原样成为begin_state,否则mask是1(代表不padding,是实际的序列元素),则左式删除,右式和state["controller_state"]没有差异保留模型对controller_state的更改。

(4)看看MIMN在做什么

下面深入这个cell(self.item_his_eb[:, t, :], state),看看MIMN在做什么,代码较长,挑提纲挈领的说。先看看这东西输入输出啥

def __call__(self, x, prev_state):
    return read_output, {
                "controller_state": controller_state,
                "read_vector_list": read_vector_list,
                "w_list": w_list,
                "M": M,
                "key_M": key_M,  # TODO key_M用完了之后没有修改
                "w_aggre": w_aggre,
                "sum_aggre": sum_aggre
             }, output_list

输入是当前步长的元素embedding和当前最新的state,输出是读M矩阵的输出,最新的状态,以及读S矩阵的输出,简单说一下三个输出的代码链路

  • 读M矩阵的输出:基于当前输入的序列元素,和上一个状态的读输出,经过NTM的控制器GRU,得到控制器输出,进一步计算得到读写之前记忆矩阵的w权重,通过该权重得到最新的读输出,和控制器输出拼接得到最终的read_output
  • 最新的状态:在读M矩阵的输出的计算过程中,同步记录下变动的state
  • 读S矩阵的输出:通过将当前步长的元素和上一个记忆矩阵输入多通道GRU,得到当前步长的读S矩阵的输出,同时更新S矩阵状态。

总结数据输入MIMN单元之后,输出读M和S矩阵的输出,以及更新M和S矩阵的参数状态,其中读M和S矩阵的输出要输入最后的全连接模型进行ctr预测,更新M和S矩阵的参数状态需要输入给下一个序列元素进行记忆更新来表征用户的行为。

(5)MIMN单元的后处理,构造主模型输入

MIMN的输出需要准备构造为最终主模型的输入的,首先用拥有最新的state的MIMN单元将目标商品灌进来走一边,拿到读输出,来表征原始记忆信息,第二第三全部不要,只要read_out

read_out, _, _ = cell(self.item_eb, state)

然后拿到现在最新的读S矩阵的输出,和目标商品一起输入给DIN Attention,提取高阶特征

        if Mem_Induction == 1:
            channel_memory_tensor = tf.concat(temp_output_list, 1)
            multi_channel_hist = din_attention(self.item_eb, channel_memory_tensor, HIDDEN_SIZE, None, stag='pal')
            # TODO read_out是读取M矩阵输出的结果,multi_channel_hist是读取S矩阵输出的结果,其他都是目标商品自身特征和上下文特征
            inp = tf.concat([self.item_eb, self.item_his_eb_sum, read_out, tf.squeeze(multi_channel_hist),
                             mean_memory * self.item_eb], 1)

最终的inp包含read_out, tf.squeeze(multi_channel_hist)这两大主要特征,以及其他上下文特征。


最终输入构造

在回过头来看图示,很清楚了呀,Target Ad拿到M的Read Head,同时和最新的S一起输入Attention。inp最终输入全连接进行ctr预测。整个代码的概览结束,里面复杂的NTM和DIN Attention先不展开研究。


MIMN参数维护方式总结

作者的代码是训练部分,该代码的目的仅仅是训练出控制器GRU,MIU的GRU,DIN以及其他几个全连接的参数,保存在tensorflow网络中,而S和M矩阵虽然在里面也产出了,但是真正部署上线肯定是重新刷历史存量所有序列得到的,而不是采用padding和截取200的方式,示意图如下

模型参数如何保存

其中NTM的读写w权重直接基于cos相似度计算得到,得到后直接更新M矩阵,不需要保存,其他记忆部分都是保存到外部存储自行维护,而右侧部分全部是tensorflow图来维护,不需要手动维护,在线上环节,读取外部存储拿到记忆参数,输入给tensorflow图即可完成预测。
另外看一下记忆参数是如何初始化,以及如何更新的


参数的初始化和更新方式

其中有的初始化是需要模型学习的,在部署的时候需要在训练的网络中将它恢复出来,否则初始化不一样,有些初始化是0初始化是写死的,相对而言方便一点。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,427评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,551评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 165,747评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,939评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,955评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,737评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,448评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,352评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,834评论 1 317
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,992评论 3 338
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,133评论 1 351
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,815评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,477评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,022评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,147评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,398评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,077评论 2 355

推荐阅读更多精彩内容