从LSTM到GRU基于门控的循环神经网络总结

技术交流QQ群:1027579432,欢迎你的加入!

1.概述

  • 为了改善基本RNN的长期依赖问题,一种方法是引入门控机制来控制信息的累积速度,包括有选择性地加入新的信息,并有选择性遗忘之前累积的信息。下面主要介绍两种基于门控的循环神经网络:长短时记忆网络和门控循环单元网络。因为基本的RNN即\mathbf{h}_{t}=f\left(U \mathbf{h}_{t-1}+W \mathbf{x}_{t}+\mathbf{b}\right),每层的隐状态都是由前一层的隐状态经变换和激活函数得到的,反向传播求导时,最终得到的导数会包含每步梯度的连乘,会导致梯度爆炸或消失。所以,基本的RNN很难处理长期依赖问题,即无法学习到序列中蕴含的间隔时间较长的规律。

2.长短时记忆网络LSTM

  • 2.1长短时记忆网络是基本的循环神经网络的一种变体,可以有效的解决简单RNN的梯度爆炸或消失问题。LSTM网络主要改进在下面两个方面

    • 1.新的内部状态\mathbf{c}_{t}:LSTM网络引入一个新的内部状态\mathbf{c}_{t},专门进行线性的循环信息传递,同时输出信息给隐藏层的外部状态\mathbf{h}_{t}.
      \begin{aligned} \mathbf{c}_{t} &=\mathbf{f}_{t} \odot \mathbf{c}_{t-1}+\mathbf{i}_{t} \odot \tilde{\mathbf{c}}_{t} \\ \mathbf{h}_{t} &=\mathbf{o}_{t} \odot \tanh \left(\mathbf{c}_{t}\right) \end{aligned}
      符号说明:\mathbf{f}_{t}\mathbf{i}_{t}\mathbf{o}_{t}分别代表遗忘门、输入门、输出门用来控制信息传递的路径;⊙表示向量元素的点乘;\mathbf{c}_{t-1}表示上一时刻的记忆单元;\tilde{\mathbf{c}}_{t}表示通过非线性函数得到的候选状态。
      \tilde{\mathbf{c}}_{t}=\tanh \left(W_{c} \mathbf{x}_{t}+U_{c} \mathbf{h}_{t-1}+\mathbf{b}_{c}\right)
      在每个时刻t,LSTM网络的内部状态\mathbf{c}_{t}记录了到当前时刻为止的历史信息。
    • 2.门控机制:LSTM网络引入了门控机制,用来控制信息传递的路径,\mathbf{f}_{t}\mathbf{i}_{t}\mathbf{o}_{t}分别代表遗忘门、输入门、输出门。这里的门概念类似于电路中的逻辑门概念,1表示开放状态,允许信息通过;0表示关闭状态,阻止信息通过。LSTM网络中的门是一个抽象的概念,借助sigmiod函数,使得输出值在(0,1)之间,表示以一定的比例运行信息通过。三个门的作用如下:
      • 遗忘门\mathbf{f}_{t}控制上一时刻的内部状态\mathbf{c}_{t-1}需要遗忘多少信息
      • 输入门\mathbf{i}_{t}控制当前时刻的候选状态\tilde{\mathbf{c}}_{t}有多少信息需要保存
      • 输出门\mathbf{o}_{t}控制当前时刻的内部状态\mathbf{c}_{t}有多少信息需要输出给外部状态\mathbf{h}_{t}
        \mathbf{f}_{t}=0, \mathbf{i}_{t}=1时,记忆单元\mathbf{c}_{t}将历史信息清空,并将候选状态向量\tilde{\mathbf{c}}_{t}写入。但此时记忆单元\mathbf{c}_{t}依然和上一时刻的历史信息相关。当\mathbf{f}_{t}=1, \mathbf{i}_{t}=0时,记忆单元将复制上一时刻的内容,不写入新的信息。三个门的计算公式如下:
        \begin{aligned} \mathbf{i}_{t} &=\sigma\left(W_{i} \mathbf{x}_{t}+U_{i} \mathbf{h}_{t-1}+\mathbf{b}_{i}\right) \\ \mathbf{f}_{t} &=\sigma\left(W_{f} \mathbf{x}_{t}+U_{f} \mathbf{h}_{t-1}+\mathbf{b}_{f}\right) \\ \mathbf{o}_{t} &=\sigma\left(W_{o} \mathbf{x}_{t}+U_{o} \mathbf{h}_{t-1}+\mathbf{b}_{o}\right) \end{aligned}
        其中,激活函数使用sigmoid函数,其输出区间是(0,1),\mathbf{x}_{t}表示当前时刻的输入,\mathbf{h}_{t-1}表示上一时刻的外部状态。
  • 2.2 LSTM网络的循环单元结构如下图所示,计算过程如下:

    • a.利用上一时刻的外部状态\mathbf{h}_{t-1}和当前时刻的输入\mathbf{x}_{t},计算出三个门,已经候选状态\tilde{\mathbf{c}}_{t}
      \begin{aligned} \mathbf{i}_{t} &=\sigma\left(W_{i} \mathbf{x}_{t}+U_{i} \mathbf{h}_{t-1}+\mathbf{b}_{i}\right) \\ \mathbf{f}_{t} &=\sigma\left(W_{f} \mathbf{x}_{t}+U_{f} \mathbf{h}_{t-1}+\mathbf{b}_{f}\right) \\ \mathbf{o}_{t} &=\sigma\left(W_{o} \mathbf{x}_{t}+U_{o} \mathbf{h}_{t-1}+\mathbf{b}_{o}\right) \end{aligned}

    \tilde{\mathbf{c}}_{t}=\tanh \left(W_{c} \mathbf{x}_{t}+U_{c} \mathbf{h}_{t-1}+\mathbf{b}_{c}\right)

    • b.结合遗忘门\mathbf{f}_{t}和输入门\mathbf{i}_{t}来更新记忆单元\mathbf{c}_{t}
      \mathbf{c}_{t}=\mathbf{f}_{t} \odot \mathbf{c}_{t-1}+\mathbf{i}_{t} \odot \tilde{\mathbf{c}}_{t}
    • c.结合输出门\mathbf{o}_{t},将内部状态的信息传递给外部状态\mathbf{h}_{t}
      \mathbf{h}_{t}=\mathbf{o}_{t} \odot \tanh \left(\mathbf{c}_{t}\right)
LSTM Cell

3.门控循环单元网络GRU

  • GRU与LSTM的不同之处在于:GRU不引入额外的记忆单元\mathbf{c}_{t},GRU网络引入一个更新门来控制当前状态需要从历史状态中保留多少信息(不经过非线性变换),以及需要从候选状态中接收多少新的信息。
    \mathbf{h}_{t}=\mathbf{z}_{t} \odot \mathbf{h}_{t-1}+\left(1-\mathbf{z}_{t}\right) \odot g\left(\mathbf{x}_{t}, \mathbf{h}_{t-1} ; \theta\right)
    其中,\mathbf{z}_{t} \in[0,1]为更新门
    \mathbf{z}_{t}=\sigma\left(\mathbf{W}_{z} \mathbf{x}_{t}+\mathbf{U}_{z} \mathbf{h}_{t-1}+\mathbf{b}_{z}\right)
    在GRU网络中,函数g\left(\mathbf{x}_{t}, \mathbf{h}_{t-1} ; \theta\right)定义为:
    \tilde{\mathbf{h}}_{t}=\tanh \left(W_{h} \mathbf{x}_{t}+U_{h}\left(\mathbf{r}_{t} \odot \mathbf{h}_{t-1}\right)+\mathbf{b}_{h}\right)
    上式中的符号说明:\tilde{\mathbf{h}}_{t}表示当前时刻的候选状态,\mathbf{r}_{t} \in[0,1]为重置门,用来控制候选状态\tilde{\mathbf{h}}_{t}的计算是否依赖上一时刻的状态\mathbf{h}_{t-1}
    \mathbf{r}_{t}=\sigma\left(W_{r} \mathbf{x}_{t}+U_{r} \mathbf{h}_{t-1}+\mathbf{b}_{r}\right)
    \mathbf{r}_{t}=0时,候选状态\tilde{\mathbf{h}}_{t}=\tanh \left(W_{c} \mathbf{x}_{t}+\mathbf{b}\right)只和当前输入\mathbf{x}_{t}相关而与历史状态无关。当\mathbf{r}_{t}=1时,候选状态\tilde{\mathbf{h}}_{t}=\tanh \left(W_{h} \mathbf{x}_{t}+U_{h} \mathbf{h}_{t-1}+\mathbf{b}_{h}\right)和当前输入\mathbf{x}_{t}相关,也和历史状态\mathbf{h}_{t-1}相关,此时和简单的RNN是一样的。
    综合上述各式,GRU网络的状态更新方式为:
    \mathbf{h}_{t}=\mathbf{z}_{t} \odot \mathbf{h}_{t-1}+\left(1-\mathbf{z}_{t}\right) \odot \tilde{\mathbf{h}}_{t}
  • 总结:当\mathbf{z}_{t}=0, \mathbf{r}=1时,GRU网络退化为简单的RNN;若\mathbf{z}_{t}=0, \mathbf{r}=0时,当前状态\mathbf{h}_{t}只和当前输入\mathbf{x}_{t}相关,和历史状态\mathbf{h}_{t-1}无关。当\mathbf{z}_{t}=1时,当前状态\mathbf{h}_{t}等于上一时刻状态\mathbf{h}_{t-1}和当前输入\mathbf{x}_{t}无关。
    GRU Cell

3.实战:基于Keras的LSTM和GRU的文本分类

    import random
    import jieba
    import pandas as pd
    import numpy as np
    
    stopwords = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\stopwords.txt", index_col=False, quoting=3, sep="\t", names=["stopword"], encoding="utf-8")
    stopwords = stopwords["stopword"].values
    
    # 加载语料
    laogong_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beilaogongda.csv", encoding="utf-8", sep=",")
    laopo_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beilaopoda.csv", encoding="utf-8", sep=",")
    erzi_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beierzida.csv", encoding="utf-8", sep=",")
    nver_df = pd.read_csv(r"E:\DeepLearning\jupyter_code\dataset\corpus\03_project\beinverda.csv", encoding="utf-8", sep=",")
    
    # 删除语料的nan行
    laogong_df.dropna(inplace=True)
    laopo_df.dropna(inplace=True)
    erzi_df.dropna(inplace=True)
    nver_df.dropna(inplace=True)
    
    # 转换
    laogong = laogong_df.segment.values.tolist()
    laopo = laopo_df.segment.values.tolist()
    erzi = erzi_df.segment.values.tolist()
    nver = nver_df.segment.values.tolist()
    
    # 分词和去掉停用词
    
    ## 定义分词和打标签函数preprocess_text
    def preprocess_text(content_lines, sentences, category):
        # content_lines是上面转换得到的list
        # sentences是空的list,用来存储打上标签后的数据
        # category是类型标签
        for line in content_lines:
            try:
                segs = jieba.lcut(line)
                segs = [v for v in segs if not str(v).isdigit()]  # 除去数字
                segs = list(filter(lambda x: x.strip(), segs))  # 除去左右空格
                segs = list(filter(lambda x: len(x) > 1, segs))  # 除去长度为1的字符
                segs = list(filter(lambda x: x not in stopwords, segs))  # 除去停用词
                sentences.append((" ".join(segs), category))  # 打标签
            except Exception:
                print(line)
                continue
    
    # 调用上面函数,生成训练数据
    sentences = []
    preprocess_text(laogong, sentences, 0)
    preprocess_text(laopo, sentences, 1)
    preprocess_text(erzi, sentences, 2)
    preprocess_text(nver, sentences, 3)
    
    # 先打乱数据,使得数据分布均匀,然后获取特征和标签列表
    random.shuffle(sentences)  # 打乱数据,生成更可靠的训练集
    for sentence in sentences[:10]:    # 输出前10条数据,观察一下
        print(sentence[0], sentence[1])
    
    # 所有特征和对应标签
    all_texts = [sentence[0] for sentence in sentences]
    all_labels = [sentence[1] for sentence in sentences]
    
    
    # 使用LSTM对数据进行分类
    from keras.preprocessing.text import Tokenizer
    from keras.preprocessing.sequence import pad_sequences
    from keras.utils import to_categorical
    from keras.layers import Dense, Input, Flatten, Dropout
    from keras.layers import LSTM, Embedding, GRU
    from keras.models import Sequential
    
    
    # 预定义变量
    MAX_SEQENCE_LENGTH = 100   # 最大序列长度
    EMBEDDING_DIM = 200   # 词嵌入维度
    VALIDATION_SPLIT = 0.16   # 验证集比例
    TEST_SPLIT = 0.2  # 测试集比例
    
    # 使用keras的sequence模块文本序列填充
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(all_texts)
    sequences = tokenizer.texts_to_sequences(all_texts)
    word_index = tokenizer.word_index
    print("Found %s unique tokens." % len(word_index))
    
    
    data = pad_sequences(sequences, maxlen=MAX_SEQENCE_LENGTH)
    labels = to_categorical(np.asarray(all_labels))
    print("data shape:", data.shape)
    print("labels shape:", labels.shape)
    
    # 数据切分
    p1 = int(len(data) * (1 - VALIDATION_SPLIT - TEST_SPLIT))
    p2 = int(len(data) * (1 - TEST_SPLIT))
    
    # 训练集
    x_train = data[:p1]
    y_train = labels[:p1]
    
    # 验证集
    x_val = data[p1:p2]
    y_val = labels[p1:p2]
    
    # 测试集
    x_test = data[p2:]
    y_test = labels[p2:]
    
    # LSTM训练模型
    model = Sequential()
    model.add(Embedding(len(word_index) + 1, EMBEDDING_DIM, input_length=MAX_SEQENCE_LENGTH))
    model.add(LSTM(200, dropout=0.2, recurrent_dropout=0.2))
    model.add(Dropout(0.2))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(labels.shape[1], activation="softmax"))
    model.summary()
    
    # 模型编译
    model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=["acc"])
    print(model.metrics_names)
    
    model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, batch_size=128)
    model.save("lstm.h5")
    # 模型评估
    print(model.evaluate(x_test, y_test))
    
    
    
    # 使用GRU模型
    model = Sequential()
    model.add(Embedding(len(word_index) + 1, EMBEDDING_DIM, input_length=MAX_SEQENCE_LENGTH))
    model.add(GRU(200, dropout=0.2, recurrent_dropout=0.2))
    model.add(Dropout(0.2))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(labels.shape[1], activation="softmax"))
    model.summary()
    
    model.compile(loss="categorical_crossentropy", optimizer="rmsprop", metrics=["acc"])
    print(model.metrics_names)
    
    model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, batch_size=128)
    model.save("gru.h5")
    
    print(model.evaluate(x_test, y_test))

4.本文代码及数据集下载

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,294评论 6 497
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,780评论 3 391
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,001评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,593评论 1 289
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,687评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,679评论 1 294
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,667评论 3 415
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,426评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,872评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,180评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,346评论 1 345
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,019评论 5 340
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,658评论 3 323
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,268评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,495评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,275评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,207评论 2 352

推荐阅读更多精彩内容