回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型的.fit()
中,即可在给定的训练阶段调用该函数集中的函数。虽然我们称之为回调“函数”,但事实上Keras的回调函数是一个类,回调函数只是习惯性称呼。下面是较常用的三个回调函数:
早停
tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=**None**, restore_best_weights=**False**)
当被监测的数量不再提升,则停止训练。
参数
monitor: 被监测的数据。
min_delta: 在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。
patience: 没有进步的训练轮数,在这之后训练就会被停止。
verbose: 详细信息模式。
mode: {auto, min, max} 其中之一。 在
min
模式中, 当被监测的数据停止下降,训练就会停止;在max
模式中,当被监测的数据停止上升,训练就会停止;在auto
模式中,方向会自动从被监测的数据的名字中判断出来。baseline: 要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止。
restore_best_weights: 是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重。
学习率衰减
tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)
学习速率定时器。
参数
- schedule: 一个函数,接受轮索引数作为输入(整数,从 0 开始迭代) 然后返回一个学习速率作为输出(浮点数)。
- verbose: 整数。 0:安静,1:更新信息。
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0)
当标准评估停止提升时,降低学习速率。
当学习停止时,模型总是会受益于降低 2-10 倍的学习速率。 这个回调函数监测一个数据并且当这个数据在一定「有耐心」的训练轮之后还没有进步, 那么学习速率就会被降低。
参数
- monitor: 被监测的数据。
- factor: 学习速率被降低的因数。新的学习速率 = 学习速率 * 因数
- patience: 没有进步的训练轮数,在这之后训练速率会被降低。
- verbose: 整数。0:安静,1:更新信息。
-
mode: {auto, min, max} 其中之一。如果是
min
模式,学习速率会被降低如果被监测的数据已经停止下降; 在max
模式,学习塑料会被降低如果被监测的数据已经停止上升; 在auto
模式,方向会被从被监测的数据中自动推断出来。 - min_delta: 对于测量新的最优化的阀值,只关注巨大的改变。
- cooldown: 在学习速率被降低之后,重新恢复正常操作之前等待的训练轮数量。
- min_lr: 学习速率的下边界
训练过程可视化
tf.keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, batch_size=32, write_graph=True, write_grads=False, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None, embeddings_data=None, update_freq='epoch')
Tensorboard 基本可视化。
TensorBoard 是由 Tensorflow 提供的一个可视化工具。
这个回调函数为 Tensorboard 编写一个日志, 这样你可以可视化测试和训练的标准评估的动态图像, 也可以可视化模型中不同层的激活值直方图。
如果你已经使用 pip 安装了 Tensorflow,你应该可以从命令行启动 Tensorflow:
tensorboard --logdir=/full_path_to_your_logs
参数
- log_dir: 用来保存被 TensorBoard 分析的日志文件的文件名。
- histogram_freq: 对于模型中各个层计算激活值和模型权重直方图的频率(训练轮数中)。 如果设置成 0 ,直方图不会被计算。对于直方图可视化的验证数据(或分离数据)一定要明确的指出。
- write_graph: 是否在 TensorBoard 中可视化图像。 如果 write_graph 被设置为 True,日志文件会变得非常大。
-
write_grads: 是否在 TensorBoard 中可视化梯度值直方图。
histogram_freq
必须要大于 0 。 - batch_size: 用以直方图计算的传入神经元网络输入批的大小。
- write_images: 是否在 TensorBoard 中将模型权重以图片可视化。
- embeddings_freq: 被选中的嵌入层会被保存的频率(在训练轮中)。
- embeddings_layer_names: 一个列表,会被监测层的名字。 如果是 None 或空列表,那么所有的嵌入层都会被监测。
- embeddings_metadata: 一个字典,对应层的名字到保存有这个嵌入层元数据文件的名字。 查看 详情 关于元数据的数据格式。 以防同样的元数据被用于所用的嵌入层,字符串可以被传入。
-
embeddings_data: 要嵌入在
embeddings_layer_names
指定的层的数据。 Numpy 数组(如果模型有单个输入)或 Numpy 数组列表(如果模型有多个输入)。 Learn ore about embeddings。 -
update_freq:
'batch'
或'epoch'
或 整数。当使用'batch'
时,在每个 batch 之后将损失和评估值写入到 TensorBoard 中。同样的情况应用到'epoch'
中。如果使用整数,例如10000
,这个回调会在每 10000 个样本之后将损失和评估值写入到 TensorBoard 中。注意,频繁地写入到 TensorBoard 会减缓你的训练。