人人都能懂的机器学习——用Keras搭建人工神经网络06

编译模型

在创建了模型之后,我们必须要使用compile()方法来指定损失方程和优化器。另外,你还可以指定在训练和评估过程中计算出一系列其他的指标(作为一个列表输入):

model.compile(loss="sparse_categorical_crossentropy",
              optimizer="sgd",
              metrics=["accuracy"])

这里使用下面的代码也能达到完全一样的效果:

model.compile(loss=keras.losses.sparse_categorical_crossentropy,
              optimizer=keras.optimizers.SGD(),
              metrics=[keras.metrics.sparse_categorical_accuracy])

未来我们还将使用更多其他的损失,优化器和指标,如果想看一下完整的列表,可以访问下面的网址:

https://keras.io/losses/

https://keras.io/optimizers/

https://keras.io/metrics/

我们来简单解释一下上面的代码,首先,使用sparse_categorical_crossentropy损失函数是因为我们有离散的标签(对于每个实例,只有一个目标类别,在这个模型中就是0到9),并且类别与类别之间是互斥的。如果我们对每个实例对每个类别都输出一个概率(就像一位有效编码one-hot vector一样,[0., 0., 0., 1., 0., 0.]表示3类别),那我们就用categorical_crossentropy损失就可以了。如果需要做二元分类(可能只有一个或者多个二元标签),那就使用sigmoid函数,而不是softmax激活函数,同时我们使用binary_crossentropy损失。

如果你想将离散的标签转换成one-hot vector使用keras.utils.to_categorical()函数。

至于优化器,sgd代表我们使用随机梯度下降法(Stochastic Gradient Descent)来训练模型。换句话说,Keras会执行之前的文章中所描述的反向传播算法。我们未来还会介绍更多高效的优化器(它们优化了梯度下降的方法,而并不是自动微分)。使用SGD优化器的时候,调整学习速度是十分重要的。所以一般会使用:

optimizer=keras.optimizers.SGD(lr=XXXXXX)

因为sgd的默认学习速度为0.01。

最后,既然这是一个分类器,那么衡量它的准确度accuracy是很有用的。

训练与评估模型

现在我们的模型已经准备好训练了,我们只需调用fit()方法就可以了:

>>> history = model.fit(X_train, y_train, epochs=30,
... validation_data=(X_valid, y_valid))
...
Train on 55000 samples, validate on 5000 samples
Epoch 1/30
55000/55000 [======] - 3s 49us/sample - loss: 0.7218 - accuracy: 0.7660
- val_loss: 0.4973 - val_accuracy: 0.8366
Epoch 2/30
55000/55000 [======] - 2s 45us/sample - loss: 0.4840 - accuracy: 0.8327
- val_loss: 0.4456 - val_accuracy: 0.8480
[...]
Epoch 30/30
55000/55000 [======] - 3s 53us/sample - loss: 0.2252 - accuracy: 0.9192
- val_loss: 0.2999 - val_accuracy: 0.8926

我们输入了特性(X_train)和目标类别(y_train),以及要训练的epoch次数(不然的话epoch默认为1,这肯定无法收敛至一个好的结果)。所谓epoch就是使用训练集中的全部数据对模型进行一次完整的训练,称为'一代训练'。不过这个中文叫法实在拗口,不如还是叫一个epoch。我们还输入了一个验证集(这是可选的)。Keras将在每个epoch结束时度量这些损失和其他指标,这对于查看模型的实际执行情况非常有用。如果模型在训练集上的表现比验证集上的表现好得多,那么模型可能对训练集进行了过拟合(或者可能存在bug,比如说训练集和验证集之间的数据不匹配)。

这样就可以了,我们模型已经训练好了!如果训练集和验证集的数据不匹配预期的输入形状的话,就会得到一个异常。这可能是最常见的错误,我们应该熟悉这个异常提示,比如:如果尝试使用一个包含扁平图像的向量,X_train.reshape(-1,784)。那么就会得到这样的异常提醒:“ValueError: Error when checking input: expected flatten_input to have 3 dimensions, but got array with shape (60000, 784)”。

在训练过程中的每一个epoch,Keras都会展示已经处理的实例数,以及一个处理进度条,每个实例训练的平均时间以及训练集和验证集的损失和准确度(或者还有其他设置好的指标)。你会看到训练损失下降了(这是个好现象),在30个epoch之后,验证集的准确度达到了89.26%,与训练集的准确度没有差距太大,所以看起来没有太大的过拟合发生。

上面的代码使用validation_data参数输入验证集,我们还可以使用validation_split的方式,设置一定比例的训练集用于验证集。比如,validation_split=0.1那么Keras将会使用最后10%的数据(在打乱顺序之前)用于验证。

如果训练集很不平衡,有些类别的实例过多而有些过少,那么应该在fit()方法中设置class_weight参数,这个参数会给实例较少的类别更大的权重,给实例过多的类别较小的权重。Keras将会在计算损失时考虑这些权重。如果需要设置每个实例的权重,可以设置sample_weight参数(它会取代class_weight)。如果有些实例是由专家标记的,而其他实例是使用公共数据平台标记的,那么每个实例的权重可能会很有用:你可能希望给专家标记的数据更多的权重。你也可以在validation_data的tuple中添加第三项,来给验证集设置实例权重。

fit()方法会返回一个History对象,这个对象包含训练参数(history.params),训练经过的epoch(history.history)以及最重要的,在每个epoch结束的时候对训练集和验证集度量的损失和其他指标的结果字典(history.history)。如果你用字典创建了Pandas的DataFrame,并且调用plot()方法,那么就会得到图1.12:

import pandas as pd
import matplotlib.pyplot as plt
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1) # set the vertical range to [0-1]
plt.show()
图1.12 学习曲线:每个epoch的平均训练损失和准确度,每个epoch结束时的平均验证损失和准确度.jpg

可以看到,在训练过程中,训练准确度和验证准确度都在稳步提高,而训练损失和验证损失都在减少。此外,验证曲线与训练曲线接近,这意味着没有太多的过拟合。在这个例子当中,在训练刚开始时模型在验证集上的表现似乎比训练集上要好。但事实并非如此:实际上,验证错误是在每个epoch结束时计算的,而训练错误是在每个epoch期间使用平均值计算的。所以训练曲线应该向左移动半个epoch。这样,你就可以发现训练和验证曲线在训练开始时几乎完全重合。

其实在绘制训练曲线时,也应当将其向左移动半个epoch。

通常情况下,如果训练时间足够长,训练集的表现最终会超过验证集的表现。可以看出这个例子当中,模型还没有完全收敛,验证损失仍然在下降,所以你应该继续训练。这与再次调用fit()方法一样简单,因为Keras会在它停止的地方继续进行训练(最终应该能够达到接近89%的验证准确度)。

如果对模型的表现不满意,那么应该重新调整超参数。首先要检查的是学习率。如果没什么用,尝试换一个优化器(并总是在更改任何超参数后重新调整学习率)。如果表现仍然不是很好,那么可以尝试调整模型的超参数,例如层的数量、每层神经元的数量以及每个隐藏层使用的激活函数。你还可以尝试调整其他超参数,比如批处理大小(可以在fit()方法中使用batch_size参数进行设置,默认值为32)。我们将在未来的文章当中再次讲到超参数调优。一旦您对模型的验证准确度感到满意,就应该在测试集上对其进行评估,以在将模型部署到生产环境之前估计泛化误差。使用evaluate()方法就可以评估测试集的准确度了(它还支持其他参数,比如batch_size以及sample_weight等):

>>> model.evaluate(X_test, y_test)
10000/10000 [==========] - 0s 29us/sample - loss: 0.3340 - accuracy: 0.8851
[0.3339798209667206, 0.8851]

通常模型在测试集中的表现要比验证集中略差一些,这是很正常的。因为超参数是根据验证集的表现来调整的(不过在这个例子里我们没有做任何超参的调整,所以准确度低只是运气不好)。记住,一定要抵制在测试集中调超参的诱惑,否则你会对泛化误差的估计过于乐观。

使用模型进行预测

接下来,我们可以使用模型的predict()方法对新的实例进行预测了。这里我们用测试集的前三个实例进行预测:

>>> X_new = X_test[:3]
>>> y_proba = model.predict(X_new)
>>> y_proba.round(2)
array([[0. , 0. , 0. , 0. , 0. , 0.03, 0. , 0.01, 0. , 0.96],
[0. , 0. , 0.98, 0. , 0.02, 0. , 0. , 0. , 0. , 0. ],
[0. , 1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]],
dtype=float32)

对于每个实例,模型对每个类别都生成了一个概率。例如,对于第一个图像,它估计类别9(短靴)的概率是96%,类别5(凉鞋)的概率是3%,类别7(运动鞋)是1%,其他类别的概率忽略不计。

换句话说,模型“相信”第一个图片是鞋类,最有可能是短靴,但也不完全确定,可能是凉鞋或运动鞋。不过如果你只关心估计概率最高的类别(即使这个概率并不高),那么可以使用predict_classes()方法:

>>> y_pred = model.predict_classes(X_new)
>>> y_pred
array([9, 2, 1])
>>> np.array(class_names)[y_pred]
array(['Ankle boot', 'Pullover', 'Trouser'], dtype='<U11')

那么模型实际上对这三个图像都预测正确了(见图1.13)。

图1.13 测试集前三张图像.jpg

好的,那么现在我们已经学习了如何用Sequential API来搭建、训练、评估和使用分类MLP,那么接下来的文章我们将讲述如何搭建回归MLP以及搭建更加复杂的模型。

敬请期待啦!

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 218,546评论 6 507
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,224评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,911评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,737评论 1 294
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,753评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,598评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,338评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,249评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,696评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,888评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,013评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,731评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,348评论 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,929评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,048评论 1 270
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,203评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,960评论 2 355

推荐阅读更多精彩内容