四、深度学习文本分类TextCNN

原理:

TextCNN出处:论文https://aclanthology.org/D14-1181/

核心论点:

TextCNN

1.Represent sentence with static and non-static channels
2.Convolve with multiple filter widths and feature maps
3.Use max-over-time pooling
4.Use fully connected layer with dropout and softmax output

本文实现:

TextCNN的网络结构:


TextCNN的网络结构

模型构建与训练

定义网络结构

from tensorflow.keras import Input , Model
from tensorflow.keras.layers import Embedding ,Dense , Conv1D , GlobalMaxPooling1D,Concatenate,Dropout

class TextCNN(object):
    def __init__(self,maxlen , max_features , embedding_dims,class_num = 5 , last_activation = 'softmax'):
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dim = embedding_dims
        self.class_num  = class_num
        self.last_activation = last_activation

    def get_model(self):
        input = Input((self.max_len,))
        embdding = Embedding(self.max_features , self.embedding_dims , input_length = self.max_len)(input)
        convs = []
        for kernel_size in [3,4,5]:
            c = Con1D(128,kernel_size,activation = 'relu')(embedding)
            c = GlobalMaxPooling1D(c)
            convs.append(c)
        x = Concatenate()(convs)


        output = Dense(self.class_num , activation = self.last_activation)(x)
        model = Model(inputs = input,outputs = output)
        return model

数据处理与训练

from tensorflow.keras.proprecessing import sequence
import random 
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping,ModelCheckpoint
from tensorflow.keras.utils import to_categorical
from utils import *

#路径配置
data_dir = './processed_data'
vocab_file = './vocab/vocab.txt'
vocab_size = 40000

#神经网络配置
max_features = 40001
maxlen = 100
batch_size = 64
embedding_dims = 50
epochs = 8

print('数据预处理与加载数据...')
#如果不存在词汇表,重建
if not os.path.exists(vocab_file):
    build_vocab(data_dir , vocab_file , vocab_size)
#获得 词汇/类别 与id映射字典
categories , cat_to_id = read_category()
words , word_to_id = read_vocab(vocab_file)

#全部数据
x , y = read_files(data_dir)
data =   list(zip(x,y))
del x , y

#乱序
random.shuffle(data)

#切分训练集与测试集
train_data , test_data = train_test_split(data)
#对文本的词id和类别id进行编码
x_train = encode_sentences([content[0] for content in train_data] , word_to_id)
y_train = to_categorical(encode_cate([content[1] for content in train_data] , cat_to_id))
x_test = encode_sentences([content[0] for content in test_data] , word_to_id)
y_test = to_categorical(encode_cate([content[1] for content in test_data] , cat_to_id))

print('对序列做padding,保证是samples * timestep的维度')
x_train = sequence.pad_sequences(x_train , maxlen = maxlen)
x_test = sequence.pad_sequences(x_test , maxlen = maxlen)
print('x_train shape:' , x_train.shape)
print('x_test_shape:' , x_test.shape)

print('构建模型.....')
model = TextCNN(maxlen,max_features , embedding_dims).get_model()
model.compile('adam' , 'categorical_crossentropy' , metrics = ['accuracy'])

print('训练')
#设定callbacks回调函数
my_callbacks = [
    ModelCheckpoint('./cnn_model.h5',verbose = 1),
    EarlyStopping(monitor = 'val_accuracy' , patience = 2 , mode = 'max')
    ]

#fit拟合数据
history = model.fit(x_train , y_train, batch_size = batch_size ,epochs = epochs , callbacks = my_callbacks,validation_data = (x_test , y_test))

print('对测试集预测....')
result = model.predict(x_test)

训练中间信息输出

import matplotlib.pyplot as plt
plt.switch_bacakend('agg')
%matplotlib inline

fig1 = plt.figure()
plt.plot(history.history['loss'] , 'r' , linewidth = 3.0)
plt.plot(history.history['val_loss'],'b' , linewidth = 3.0)
plt.legend(['Training loss' , 'Validation Loss'] , fontsize = 18)
plt.xlabel('Epochs' , fontsize = 16)
plt.ylabel('Loss' , fontsize = 16)
plt.title('Loss Curves :CNN',fontsize = 16)
fig1.savefig('loss_cnn,png')
plt.show()
fig2=plt.figure()
plt.plot(history.history['acc'],'r',linewidth=3.0)
plt.plot(history.history['val_acc'],'b',linewidth=3.0)
plt.legend(['Training Accuracy', 'Validation Accuracy'],fontsize=18)
plt.xlabel('Epochs ',fontsize=16)
plt.ylabel('Accuracy',fontsize=16)
plt.title('Accuracy Curves : CNN',fontsize=16)
fig2.savefig('accuracy_cnn.png')
plt.show()

模型结构打印

from tensorflow.keras.utils import plot_model
# model.summary()
plot_model(model, show_shapes=True, show_layer_names=True)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,080评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,422评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,630评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,554评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,662评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,856评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,014评论 3 408
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,752评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,212评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,541评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,687评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,347评论 4 331
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 39,973评论 3 315
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,777评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,006评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,406评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,576评论 2 349

推荐阅读更多精彩内容