This example classifies movie reviews as positive or negative using the text of the review. This is an example of binary classification, an important and widely applicable kind of machine learning problem.
We will use the IMDB dataset, which contains the text of 50,000 movie reviews from the Internet Movie Database. These are split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are balanced, meaning they contain an equal number of positive and negative reviews.
%matplotlib inline
import tensorflow as tf
from tensorflow import keras
import numpy as np
# The IMDB dataset comes packaged with TensorFlow. It has already been preprocessed such that the reviews (sequences of words) have been converted to sequences of integers, where each integer represents a specific word in a dictionary. The argument num_words=10000 below keeps the 10,000 most frequently occurring words in the training data; rarer words are discarded to keep the vocabulary size manageable.
imdb = keras.datasets.imdb
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)
# Each label is an integer value of either 0 or 1, where 0 is a negative review and 1 is a positive review.
print("Training entries: {}, labels: {}".format(len(train_data), len(train_labels)))
# Movie reviews may be different lengths. The following code shows the number of words in the first and second reviews. Since inputs to a neural network must be the same length, we will need to resolve this later.
len(train_data[0]), len(train_data[1])
# It may be useful to know how to convert integers back to text. Here we will create a helper function to query a dictionary object that contains the integer-to-string mapping:
# A dictionary mapping words to an integer index
word_index = imdb.get_word_index()
# The first indices are reserved
word_index = {k:(v+3) for k,v in word_index.items()}
word_index["<PAD>"] = 0
word_index["<START>"] = 1
word_index["<UNK>"] = 2 # unknown
word_index["<UNUSED>"] = 3
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])
def decode_review(text):
    return ' '.join([reverse_word_index.get(i, '?') for i in text])
# Now we can use the decode_review function to display the text of the first review:
decode_review(train_data[0])
# Since inputs to the network must all be the same length, we will use the pad_sequences function to standardize the lengths:
train_data = keras.preprocessing.sequence.pad_sequences(train_data,
                                                        value=word_index["<PAD>"],
                                                        padding='post',
                                                        maxlen=256)
test_data = keras.preprocessing.sequence.pad_sequences(test_data,
                                                       value=word_index["<PAD>"],
                                                       padding='post',
                                                       maxlen=256)
# Now inspect the first review (now padded to length 256):
print(train_data[0])
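# A minimal toy illustration (made-up data, not from the dataset) of what
# pad_sequences does with padding='post': shorter sequences are padded with the
# fill value at the end, and longer ones would be truncated to maxlen.
toy_batch = [[1, 2, 3], [4, 5]]
print(keras.preprocessing.sequence.pad_sequences(toy_batch, value=0, padding='post', maxlen=4))
# [[1 2 3 0]
#  [4 5 0 0]]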
# Build the model
# The input shape is the vocabulary count used for the movie reviews (10,000 words)
vocab_size = 10000
model = keras.Sequential()
model.add(keras.layers.Embedding(vocab_size, 16))
model.add(keras.layers.GlobalAveragePooling1D())
model.add(keras.layers.Dense(16, activation='relu'))
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.summary()
# The first layer is an Embedding layer. This layer takes the integer-encoded vocabulary and looks up the embedding vector for each word index. These vectors are learned as the model trains. The vectors add a dimension to the output array; the resulting dimensions are (batch, sequence, embedding).
# Next, a GlobalAveragePooling1D layer returns a fixed-length output vector for each example by averaging over the sequence dimension. This allows the model to handle input of variable length in the simplest way possible.
# This fixed-length output vector is piped through a fully connected (Dense) layer with 16 hidden units.
# The last layer is densely connected with a single output node. With the sigmoid activation function, its output is a float between 0 and 1, representing a probability or confidence level.
# These choices account for the parameter counts in model.summary() below: 10,000 x 16 = 160,000 for the embedding, 16 x 16 + 16 = 272 and 16 x 1 + 1 = 17 for the two Dense layers, 160,289 in total.
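# A minimal NumPy sketch (an illustration, not part of the original walkthrough) of
# what GlobalAveragePooling1D computes: averaging a (batch, sequence, embedding)
# array over its sequence axis yields one fixed-length vector per sample.
toy_embeddings = np.random.rand(2, 256, 16)  # toy batch: 2 padded reviews, 16-dim embeddings
print(toy_embeddings.mean(axis=1).shape)     # (2, 16)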
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
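# For intuition (a hand-worked check, not from the original tutorial): binary
# cross-entropy for a single example with label y and predicted probability p is
# -(y*log(p) + (1-y)*log(1-p)), so a confident correct prediction costs little.
p, y = 0.9, 1.0
print(-(y * np.log(p) + (1 - y) * np.log(1 - p)))  # ~0.105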
# Create a validation set by setting apart 10,000 examples from the original training data, so we can check accuracy during training on data the model has not seen
x_val = train_data[:10000]
partial_x_train = train_data[10000:]
y_val = train_labels[:10000]
partial_y_train = train_labels[10000:]
# Train the model
# Train the model for 40 epochs in mini-batches of 512 samples. With 15,000 training samples this comes to 30 steps per epoch (ceil(15000/512) = 30), matching the 30/30 in the log below.
history = model.fit(partial_x_train,
                    partial_y_train,
                    epochs=40,
                    batch_size=512,
                    validation_data=(x_val, y_val),
                    verbose=1)
# Evaluate the model; model.evaluate returns the loss and accuracy on the test set
results = model.evaluate(test_data, test_labels, verbose=2)
print(results)
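# Not covered in the walkthrough above, but once trained the model can score new
# (already encoded and padded) reviews. model.predict returns the sigmoid output,
# so values near 1 indicate a predicted positive review. A quick sketch:
probs = model.predict(test_data[:3])
print(probs.ravel(), test_labels[:3])  # predicted probabilities vs. true labels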
# Create a graph of accuracy and loss over time. model.fit() returns a History object whose history attribute is a dictionary with an entry for each monitored metric:
history_dict = history.history
history_dict.keys()
import matplotlib.pyplot as plt
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']
epochs = range(1, len(acc) + 1)
# "bo" stands for "blue dot"
plt.plot(epochs, loss, 'bo', label='Training loss')
# "b" stands for "solid blue line"
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
plt.clf()  # clear the figure
plt.plot(epochs, acc, 'bo', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
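# The two plots show the classic overfitting pattern: training loss keeps falling
# while validation loss bottoms out and then creeps back up. A quick way to locate
# the turning point from the recorded history:
best_epoch = int(np.argmin(val_loss)) + 1
print(best_epoch)  # around epoch 28 in the run logged below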
# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
Output:
Training entries: 25000, labels: 25000
[ 1 14 22 16 43 530 973 1622 1385 65 458 4468 66 3941
4 173 36 256 5 25 100 43 838 112 50 670 2 9
35 480 284 5 150 4 172 112 167 2 336 385 39 4
172 4536 1111 17 546 38 13 447 4 192 50 16 6 147
2025 19 14 22 4 1920 4613 469 4 22 71 87 12 16
43 530 38 76 15 13 1247 4 22 17 515 17 12 16
626 18 2 5 62 386 12 8 316 8 106 5 4 2223
5244 16 480 66 3785 33 4 130 12 16 38 619 5 25
124 51 36 135 48 25 1415 33 6 22 12 215 28 77
52 5 14 407 16 82 2 8 4 107 117 5952 15 256
4 2 7 3766 5 723 36 71 43 530 476 26 400 317
46 7 4 2 1029 13 104 88 4 381 15 297 98 32
2071 56 26 141 6 194 7486 18 4 226 22 21 134 476
26 480 5 144 30 5535 18 51 36 28 224 92 25 104
4 226 65 16 38 1334 88 12 16 283 5 16 4472 113
103 32 15 16 5345 19 178 32 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0]
Model: "sequential_11"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding_1 (Embedding) (None, None, 16) 160000
_________________________________________________________________
global_average_pooling1d (Gl (None, 16) 0
_________________________________________________________________
dense_12 (Dense) (None, 16) 272
_________________________________________________________________
dense_13 (Dense) (None, 1) 17
=================================================================
Total params: 160,289
Trainable params: 160,289
Non-trainable params: 0
_________________________________________________________________
Epoch 1/40
30/30 [==============================] - 1s 27ms/step - loss: 0.6926 - accuracy: 0.5029 - val_loss: 0.6917 - val_accuracy: 0.5111
Epoch 2/40
30/30 [==============================] - 1s 23ms/step - loss: 0.6895 - accuracy: 0.5583 - val_loss: 0.6873 - val_accuracy: 0.6130
Epoch 3/40
30/30 [==============================] - 1s 21ms/step - loss: 0.6821 - accuracy: 0.6482 - val_loss: 0.6774 - val_accuracy: 0.6937
Epoch 4/40
30/30 [==============================] - 1s 19ms/step - loss: 0.6668 - accuracy: 0.7187 - val_loss: 0.6590 - val_accuracy: 0.7010
Epoch 5/40
30/30 [==============================] - 1s 19ms/step - loss: 0.6411 - accuracy: 0.7569 - val_loss: 0.6305 - val_accuracy: 0.7570
Epoch 6/40
30/30 [==============================] - 1s 21ms/step - loss: 0.6053 - accuracy: 0.7867 - val_loss: 0.5939 - val_accuracy: 0.7818
Epoch 7/40
30/30 [==============================] - 1s 21ms/step - loss: 0.5613 - accuracy: 0.8151 - val_loss: 0.5527 - val_accuracy: 0.8023
Epoch 8/40
30/30 [==============================] - 1s 21ms/step - loss: 0.5145 - accuracy: 0.8321 - val_loss: 0.5088 - val_accuracy: 0.8233
Epoch 9/40
30/30 [==============================] - 1s 21ms/step - loss: 0.4682 - accuracy: 0.8497 - val_loss: 0.4697 - val_accuracy: 0.8322
Epoch 10/40
30/30 [==============================] - 1s 21ms/step - loss: 0.4259 - accuracy: 0.8635 - val_loss: 0.4346 - val_accuracy: 0.8442
Epoch 11/40
30/30 [==============================] - 1s 19ms/step - loss: 0.3891 - accuracy: 0.8742 - val_loss: 0.4053 - val_accuracy: 0.8519
Epoch 12/40
30/30 [==============================] - 1s 20ms/step - loss: 0.3576 - accuracy: 0.8821 - val_loss: 0.3818 - val_accuracy: 0.8604
Epoch 13/40
30/30 [==============================] - 1s 19ms/step - loss: 0.3316 - accuracy: 0.8910 - val_loss: 0.3640 - val_accuracy: 0.8633
Epoch 14/40
30/30 [==============================] - 1s 19ms/step - loss: 0.3093 - accuracy: 0.8951 - val_loss: 0.3475 - val_accuracy: 0.8682
Epoch 15/40
30/30 [==============================] - 1s 22ms/step - loss: 0.2898 - accuracy: 0.9015 - val_loss: 0.3352 - val_accuracy: 0.8711
Epoch 16/40
30/30 [==============================] - 1s 20ms/step - loss: 0.2732 - accuracy: 0.9055 - val_loss: 0.3256 - val_accuracy: 0.8743
Epoch 17/40
30/30 [==============================] - 1s 22ms/step - loss: 0.2585 - accuracy: 0.9125 - val_loss: 0.3174 - val_accuracy: 0.8770
Epoch 18/40
30/30 [==============================] - 1s 22ms/step - loss: 0.2451 - accuracy: 0.9162 - val_loss: 0.3103 - val_accuracy: 0.8783
Epoch 19/40
30/30 [==============================] - 1s 21ms/step - loss: 0.2332 - accuracy: 0.9207 - val_loss: 0.3050 - val_accuracy: 0.8796
Epoch 20/40
30/30 [==============================] - 1s 23ms/step - loss: 0.2222 - accuracy: 0.9237 - val_loss: 0.3004 - val_accuracy: 0.8803
Epoch 21/40
30/30 [==============================] - 1s 22ms/step - loss: 0.2119 - accuracy: 0.9280 - val_loss: 0.2964 - val_accuracy: 0.8825
Epoch 22/40
30/30 [==============================] - 1s 21ms/step - loss: 0.2030 - accuracy: 0.9307 - val_loss: 0.2941 - val_accuracy: 0.8829
Epoch 23/40
30/30 [==============================] - 1s 19ms/step - loss: 0.1940 - accuracy: 0.9347 - val_loss: 0.2917 - val_accuracy: 0.8831
Epoch 24/40
30/30 [==============================] - 1s 19ms/step - loss: 0.1855 - accuracy: 0.9384 - val_loss: 0.2894 - val_accuracy: 0.8839
Epoch 25/40
30/30 [==============================] - 1s 19ms/step - loss: 0.1781 - accuracy: 0.9421 - val_loss: 0.2884 - val_accuracy: 0.8849
Epoch 26/40
30/30 [==============================] - 1s 20ms/step - loss: 0.1714 - accuracy: 0.9449 - val_loss: 0.2882 - val_accuracy: 0.8851
Epoch 27/40
30/30 [==============================] - 1s 19ms/step - loss: 0.1645 - accuracy: 0.9481 - val_loss: 0.2871 - val_accuracy: 0.8851
Epoch 28/40
30/30 [==============================] - 1s 19ms/step - loss: 0.1578 - accuracy: 0.9505 - val_loss: 0.2868 - val_accuracy: 0.8859
Epoch 29/40
30/30 [==============================] - 1s 19ms/step - loss: 0.1514 - accuracy: 0.9531 - val_loss: 0.2875 - val_accuracy: 0.8860
Epoch 30/40
30/30 [==============================] - 1s 22ms/step - loss: 0.1460 - accuracy: 0.9548 - val_loss: 0.2882 - val_accuracy: 0.8862
Epoch 31/40
30/30 [==============================] - 1s 21ms/step - loss: 0.1399 - accuracy: 0.9573 - val_loss: 0.2903 - val_accuracy: 0.8842
Epoch 32/40
30/30 [==============================] - 1s 20ms/step - loss: 0.1347 - accuracy: 0.9600 - val_loss: 0.2892 - val_accuracy: 0.8862
Epoch 33/40
30/30 [==============================] - 1s 20ms/step - loss: 0.1298 - accuracy: 0.9627 - val_loss: 0.2907 - val_accuracy: 0.8866
Epoch 34/40
30/30 [==============================] - 1s 21ms/step - loss: 0.1250 - accuracy: 0.9633 - val_loss: 0.2921 - val_accuracy: 0.8865
Epoch 35/40
30/30 [==============================] - 1s 22ms/step - loss: 0.1202 - accuracy: 0.9661 - val_loss: 0.2945 - val_accuracy: 0.8854
Epoch 36/40
30/30 [==============================] - 1s 22ms/step - loss: 0.1156 - accuracy: 0.9673 - val_loss: 0.2965 - val_accuracy: 0.8861
Epoch 37/40
30/30 [==============================] - 1s 23ms/step - loss: 0.1115 - accuracy: 0.9687 - val_loss: 0.2992 - val_accuracy: 0.8844
Epoch 38/40
30/30 [==============================] - 1s 21ms/step - loss: 0.1076 - accuracy: 0.9702 - val_loss: 0.3014 - val_accuracy: 0.8841
Epoch 39/40
30/30 [==============================] - 1s 22ms/step - loss: 0.1035 - accuracy: 0.9721 - val_loss: 0.3043 - val_accuracy: 0.8837
Epoch 40/40
30/30 [==============================] - 1s 22ms/step - loss: 0.0999 - accuracy: 0.9731 - val_loss: 0.3070 - val_accuracy: 0.8831
782/782 - 1s - loss: 0.3252 - accuracy: 0.8727
[0.32521477341651917, 0.8726800084114075]