keras细谈Compile, Fit, Evaluate, Predict

实际例子代码下载:https://github.com/wennaz/Deep_Learning

3 Compile编译

在训练模型之前,我们需要配置学习过程,这是通过compile方法完成的。

他接收三个参数:优化器 optimizer, 损失函数 loss, 评估标准 metrics


#Compile函数定义:

compile(

    optimizer='', loss=None, metrics=None, loss_weights=None,

    weighted_metrics=None, run_eagerly=None, **kwargs

)

  • 优化器 optimizer:它可以是现有优化器的字符串标识符,如 rmsprop 或 adagrad,也可以是 Optimizer 类的实例。

  • 损失函数 loss:模型试图最小化的目标函数。它可以是现有损失函数的字符串标识符,如 categorical_crossentropy 或 mse,也可以是一个目标函数。

  • 评估标准 metrics:对于任何分类问题,你都希望将其设置为 metrics = ['accuracy']。评估标准可以是现有的标准的字符串标识符,也可以是自定义的评估标准函数。

3.1 优化器 optimizer

优化器 会将计算出的梯度应用于模型的变量,以使 loss 函数最小化。您可以将损失函数想象为一个曲面(见图 3),我们希望通过到处走动找到该曲面的最低点。梯度指向最高速上升的方向,因此我们将沿相反的方向向下移动。我们以迭代方式计算每个批次的损失和梯度,以在训练过程中调整模型。模型会逐渐找到权重和偏差的最佳组合,从而将损失降至最低。损失越低,模型的预测效果就越好。

优化算法在三维空间中随时间推移而变化的可视化效果。

(来源: 斯坦福大学 CS231n 课程,MIT 许可证,Image credit: Alec Radford)

optimizer.gif
Adadelta

一种随机梯度下降方法,它基于每个维度的自适应学习率来解决两个缺点:
整个培训期间学习率的持续下降
需要手动选择的整体学习率

Adagrad

Adagrad是一种优化器,具有特定于参数的学习率,相对于训练期间更新参数的频率进行调整。 参数接收的更新越多,更新越小。

Adam

Adam优化是一种基于随机估计的一阶和二阶矩的随机梯度下降方法。

根据Kingma等人的说法(2014年),该方法“计算效率高,内存需求少,不影响梯度的对角线重缩放,并且非常适合数据/参数较大的问题”。

Adamax

Adamax优化是基于无穷范数的Adam的变体。 默认参数遵循本文提供的参数。 Adammax有时优于Adam,特别是在带有嵌入Embedding的模型中

FTRL

实现FTRL算法的优化程序。请参阅本文的算法。此版本同时支持在线L2(上面的论文中给出了L2损失)和收缩类型L2(这是对损失函数增加L2损失)的支持。

RMSprop

该优化器通常是面对递归神经网络时的一个良好选择。

RMSprop的要旨是:

  • 保持梯度平方的移动(折后)平均值
  • 将梯度除以该平均值的根
  • RMSprop的此实现使用简单动量,而不使用Nesterov动量。
  • 居中版本还保留了梯度的移动平均值,并使用该平均值来估计方差。
  • momentum即动量,它模拟的是物体运动时的惯性,即更新的时候在一定程度上保留之前更新的方向,同时利用当前batch的梯度微调最终的更新方向。这样一来,可以在一定程度上增加稳定性,从而学习地更快,并且还有一定摆脱局部最优的能力
Nesterov

Nesterov 版本 Adam 优化器。 正像 Adam 本质上是 RMSProp 与动量 momentum 的结合, Nadam 是采用 Nesterov momentum 版本的 Adam 优化器。

SGD

随机梯度下降法,支持动量参数,支持学习衰减率,支持Nesterov动量。


3.2 损失函数 losses

损失函数的目的是计算模型在训练过程中寻求最小化。损失函数是模型优化的目标,所以又叫目标函数、优化评分函数。

损失函数分为三大类:概率损失,回归损失,合页损失。

概率损失 回归损失 合页损失
BinaryCrossentropy MeanSquaredError class Hinge class
CategoricalCrossentropy class MeanAbsoluteError class SquaredHinge class
SparseCategoricalCrossentropy class MeanAbsolutePercentageError class CategoricalHinge class
Poisson class CosineSimilarity class hinge function
binary_crossentropy function mean_squared_error function squared_hinge function
categorical_crossentropy function mean_absolute_error function categorical_hinge function
sparse_categorical_crossentropy function mean_absolute_percentage_error function
poisson function mean_squared_logarithmic_error function
KLDivergence class cosine_similarity function
kl_divergence function Huber class
huber function
LogCosh class
log_cosh function
BinaryCrossentropy

计算真实标签和预测标签之间的交叉熵损失。当只有两个标签类别(假定为0和1)时,请使用此交叉熵损失。

CategoricalCrossentropy

计算标签和预测之间的交叉熵损失。有两个或多个标签类别时,请使用此交叉熵损失函数。

CategoricalHinge

计算y_true和y_pred之间的分类合页损失。

CosineSimilarity

计算标签和预测之间的余弦相似度。

请注意,它是介于-1和0之间的负数,其中0表示正交性,而值接近-1则表示更大的相似性。 这使得它在尝试使预测值与目标值之间的接近度最大化的情况下可用作损失函数。

MeanAbsoluteError

计算标签和预测之间的绝对差的平均值。loss = abs(y_true - y_pred)

MeanAbsolutePercentageError

计算y_true和y_pred之间的平均绝对百分比误差。loss = 100 * abs(y_true - y_pred) / y_true

MeanSquaredError

计算标签和预测之间的误差平方的平均值。loss = square(y_true - y_pred)


3.3 评估标准 metrics

评价函数用于评估当前训练模型的性能。性能评估模块提供了一系列用于模型性能评估的函数,这些函数在模型编译时由metrics关键字设置 性能评估函数类似与目标函数, 只不过该性能的评估结果讲不会用于训练。

可以通过字符串来使用域定义的性能评估函数:


model.compile(loss='mean_squared_error',

              optimizer='sgd',

              metrics=['mae', 'acc'])

准确率评价 概率评价 回归评价 分类评价 合页评价
Accuracy class BinaryCrossentropy class MeanSquaredError class AUC class Hinge class
BinaryAccuracy class CategoricalCrossentropy class RootMeanSquaredError class Precision class SquaredHinge class
CategoricalAccuracy class SparseCategoricalCrossentropy class MeanAbsoluteError class TruePositives class CategoricalHinge class
TopKCategoricalAccuracy class KLDivergence class MeanAbsolutePercentageError class TrueNegatives class
SparseTopKCategoricalAccuracy class Poisson class MeanSquaredLogarithmicError class FalsePositives class
CosineSimilarity class FalseNegatives class
LogCoshError class PrecisionAtRecall class
SensitivityAtSpecificity class
SpecificityAtSensitivity class
Accuracy

计算预测等于标签的频率。

该度量创建两个局部变量,总计和计数,用于计算y_pred与y_true匹配的频率。 该频率最终以二进制精度返回:幂运算,将总数除以计数。

BinaryAccuracy

计算预测与二进制标签匹配的频率。该度量创建两个局部变量,总计和计数,用于计算y_pred与y_true匹配的频率。该频率最终以二进制精度返回:幂运算,将总数除以计数。

CategoricalAccuracy

计算预测与one-hot标签匹配的频率。该度量创建两个局部变量,总计和计数,用于计算y_pred与y_true匹配的频率。该频率最终以绝对精度返回:幂等运算,简单地将总数除以计数。

FalseNegatives

计算假阴性的数量。

FalsePositives

计算误报的数量。

Mean

计算给定值的(加权)平均值。

MeanAbsoluteError

计算标签和预测之间的平均绝对误差。

MeanSquaredError

计算y_true和y_pred之间的均方误差。


4 训练模型 fit

为模型训练固定的epochs(数据集上的迭代)。


# fit定义

fit(

    x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,

    validation_split=0.0, validation_data=None, shuffle=True, class_weight=None,

    sample_weight=None, initial_epoch=0, steps_per_epoch=None,

    validation_steps=None, validation_batch_size=None, validation_freq=1,

    max_queue_size=10, workers=1, use_multiprocessing=False

)

  • x:输入数据。如果模型只有一个输入,那么x的类型是numpy array,如果模型有多个输入,那么x的类型应当为list,list的元素是对应于各个输入的numpy array
  • y:标签,numpy array
  • batch_size:整数,指定进行梯度下降时每个batch包含的样本数。训练时一个batch的样本会被计算一次梯度下降,使目标函数优化一步。
  • epochs:整数,训练终止时的epoch值,训练将在达到该epoch值时停止,当没有设置initial_epoch时,它就是训练的总轮数,否则训练的总轮数为epochs - inital_epoch
  • verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
  • callbacks:list,其中的元素是keras.callbacks.Callback的对象。这个list中的回调函数将会在训练过程中的适当时机被调用,参考回调函数
  • validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。注意,validation_split的划分在shuffle之前,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。
  • validation_data:形式为(X,y)的tuple,是指定的验证集。此参数将覆盖validation_spilt。
  • shuffle:布尔值或字符串,一般为布尔值,表示是否在训练过程中随机打乱输入样本的顺序。若为字符串“batch”,则是用来处理HDF5数据的特殊情况,它将在batch内部将数据打乱。
  • class_weight:字典,将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)
  • sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode='temporal'。
  • initial_epoch: 从该参数指定的epoch开始训练,在继续之前的训练时有用。
  • fit函数返回一个History的对象,其History.history属性记录了损失函数和其他指标的数值随epoch变化的情况,如果有验证集的话,也包含了验证集的这些指标变化情况。


5 评估模型 evaluate

在测试模式下返回模型的误差值和评估标准值。计算逐批次进行。


# evaluate定义

evaluate(

    x=None, y=None, batch_size=None, verbose=1, sample_weight=None, steps=None,

    callbacks=None, max_queue_size=10, workers=1, use_multiprocessing=False,

    return_dict=False)

  • x:输入数据。它可能是:Numpy数组(或类似数组的数组)或数组列表(如果模型具有多个输入)。TensorFlow张量或张量列表(如果模型具有多个输入)。

  • y:目标数据。像输入数据x一样,它可以是Numpy数组或TensorFlow张量。它应该与x一致

  • batch_size:整数或无。 每批计算的样本数。 如果未指定,batch_size将默认为32。如果数据是以数据集,生成器或keras.utils.Sequence实例的形式(因为它们生成批次),则不要指定batch_size。

  • sample_weight:测试样本的可选Numpy权重数组,用于加权损失函数。

  • steps:整数或无。宣布评估阶段结束之前的步骤总数(样本批次)。

  • callbacks:评估期间要应用的回调列表。

  • max_queue_size:整数。仅用于generator或keras.utils.Sequence输入。生成器队列的最大大小。如果未指定,max_queue_size将默认为10。

  • workers:整数。仅用于generator或keras.utils.Sequence输入。使用基于进程的线程时,要启动的最大进程数。 如果未指定,worker将默认为1。如果为0,将在主线程上执行生成器。

  • use_multiprocessing:布尔值。仅用于generator或keras.utils.Sequence输入。如果为True,则使用基于进程的线程。 如果未指定,则use_multiprocessing将默认为False。 请注意,由于此实现依赖于多处理,因此不应将不可拾取的参数传递给生成器,因为它们无法轻易传递给子进程

  • return_dict:如果为True,则将损失和指标结果作为dict返回,每个键都是指标的名称。 如果为False,则将它们作为列表返回。


6 预测模型 predict

生成输入样本的输出预测。


# predict定义

predict(

    x, batch_size=None, verbose=0, steps=None,

    callbacks=None, max_queue_size=10,

    workers=1, use_multiprocessing=False

)

  • x:输入样本。它可能是:Numpy数组(或类似数组的数组)或数组列表(如果模型具有多个输入)。TensorFlow张量或张量列表(如果模型具有多个输入)。

  • batch_size:整数或无。 每批样品数。 如果未指定,batch_size将默认为32

  • verbose:详细模式,0或1。

  • steps:宣布预测回合完成之前的步骤总数(样本批次)。

  • callbacks:预测期间要应用的回调列表。

  • max_queue_size:整数。 仅用于generator或keras.utils.Sequence输入。 生成器队列的最大值。如果未指定,max_queue_size将默认为10。

  • workers: 仅用于generator或keras.utils.Sequence输入。 使用基于进程的线程时,要启动的最大进程数。 如果未指定,worker将默认为1。如果为0,将在主线程上执行生成器。

  • use_multiprocessing:布尔值。 仅用于generator或keras.utils.Sequence输入。 如果为True,则使用基于进程的线程。 如果未指定,则use_multiprocessing将默认为False。 请注意,由于此实现依赖于多处理,因此不应将不可拾取的参数传递给生成器,因为它们无法轻易传递给子进程。


7 模型保存

模型保存可以在训练期间和训练之后。这意味着模型可以从中断的地方继续进行,避免了长时间的训练。 保存还意味着您可以共享模型,其他人可以重新创建您的作品。

保存整个模型

您可以调用save_model函数 将整个模型保存到单个工件中。它将包括:

  • 模型的架构/配置

  • 模型的权重值(在训练过程中学习)

  • 模型的编译信息(如果调用了 compile())

  • 优化器及其状态(如果有的话,使您可以从上次中断的位置重新开始训练)


tf.keras.models.save_model(

    model, filepath, overwrite=True, include_optimizer=True, save_format=None,

    signatures=None, options=None

)

例子:


import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(3,)),tf.keras.layers.Softmax()])

model.save('/tmp/model')

model.save("my_h5_model.h5")

保存架构

调用to_json()函数保存指定模型包含的层,以及这些层的连接方式。


from tensorflow import keras

model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])

json_config = model.to_json()

new_model = keras.models.model_from_json(json_config)

保存权重值

save_weights()函数可以选择仅保存和加载模型的权重。


# Runnable example

from tensorflow import keras

sequential_model = keras.Sequential(

    [

        keras.Input(shape=(784,), name="digits"),

        keras.layers.Dense(64, activation="relu", name="dense_1"),

        keras.layers.Dense(64, activation="relu", name="dense_2"),

        keras.layers.Dense(10, name="predictions"),

    ]

)

sequential_model.save_weights("weights.h5")

sequential_model.load_weights("weights.h5")

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,686评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,668评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,160评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,736评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,847评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,043评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,129评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,872评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,318评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,645评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,777评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,470评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,126评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,861评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,095评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,589评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,687评论 2 351

推荐阅读更多精彩内容