TensorFlow 保存和加载模型

可以在训练期间和训练后保存模型进度。 这意味着模型可以从中断的地方恢复,并避免长时间的训练。 保存也意味着您可以共享您的模型,而其他人可以重新创建您的工作。 在发布研究模型和技术时,大多数机器学习从业者分享:

  1. 用于创建模型的代码
  2. 模型的训练权重或参数

共享此数据有助于其他人了解模型的工作原理,并使用新数据自行尝试。

注意:小心不受信任的代码 - TensorFlow模型是代码。 有关详细信息,请参阅安全使用TensorFlow。

选项

保存TensorFlow模型有多种方法 - 取决于您使用的API。 本指南使用tf.keras,一个高级API,用于在TensorFlow中构建和训练模型。 有关其他方法,请参阅TensorFlow保存和还原指南或保存在急切中。

安装

安装和引用

安装和导入TensorFlow和依赖项,有下面两种方式:

  1. 命令行:pip install -q h5py pyyaml
  2. 在Anaconda Navigator中安装;

下载样本数据集

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras

tf.__version__

'1.11.0'

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

定义模型

让我们构建一个简单的模型,我们将用它来演示保存和加载权重。

# Returns a short sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation=tf.nn.softmax)
  ])
  
  model.compile(optimizer=tf.keras.optimizers.Adam(), 
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'])
  
  return model


# Create a basic model instance
model = create_model()
model.summary()

在训练期间保存检查点

主要用例是在训练期间和训练结束时自动保存检查点。 通过这种方式,您可以使用训练有素的模型,而无需重新训练,或者在您离开的地方接受训练 - 以防止训练过程中断。

tf.keras.callbacks.ModelCheckpoint是执行此任务的回调。 回调需要几个参数来配置检查点。

检查点回调使用情况

训练模型并将模型传递给ModelCheckpoint:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create checkpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, 
 save_weights_only=True,
 verbose=1)

model = create_model()

model.fit(train_images, train_labels,  epochs = 10, 
  validation_data = (test_images,test_labels),
  callbacks = [cp_callback])  # pass callback to training

这将创建一个TensorFlow检查点文件集合,这些文件在每个时期结束时更新:

!ls {checkpoint_dir}

checkpoint cp.ckpt.data-00000-of-00001 cp.ckpt.index

创建一个新的未经训练的模型。 仅从权重还原模型时,必须具有与原始模型具有相同体系结构的模型。 由于它是相同的模型架构,我们可以共享权重,尽管它是模型的不同实例。

现在重建一个新的未经训练的模型,并在测试集上进行评估。 未经训练的模型将在偶然水平上执行(准确度约为10%):

model = create_model()

loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))

然后从检查点加载权重,并重新评估:

model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

1000/1000 [==============================] - 0s 40us/step
Restored model, accuracy: 87.60%

检查点回调选项

回调提供了几个选项,可以为生成的检查点提供唯一的名称,并调整检查点频率。

训练一个新模型,每5个时期保存一次唯一命名的检查点:

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.fit(train_images, train_labels,
  epochs = 50, callbacks = [cp_callback],
  validation_data = (test_images,test_labels),
  verbose=0)

现在,查看生成的检查点并选择最新的检查点:

! ls {checkpoint_dir}

checkpoint cp-0030.ckpt.data-00000-of-00001
cp-0005.ckpt.data-00000-of-00001 cp-0030.ckpt.index
cp-0005.ckpt.index cp-0035.ckpt.data-00000-of-00001
cp-0010.ckpt.data-00000-of-00001 cp-0035.ckpt.index
cp-0010.ckpt.index cp-0040.ckpt.data-00000-of-00001
cp-0015.ckpt.data-00000-of-00001 cp-0040.ckpt.index
cp-0015.ckpt.index cp-0045.ckpt.data-00000-of-00001
cp-0020.ckpt.data-00000-of-00001 cp-0045.ckpt.index
cp-0020.ckpt.index cp-0050.ckpt.data-00000-of-00001
cp-0025.ckpt.data-00000-of-00001 cp-0050.ckpt.index
cp-0025.ckpt.index

latest = tf.train.latest_checkpoint(checkpoint_dir)
latest

'training_2/cp-0050.ckpt'

注意:默认的tensorflow格式仅保存最近的5个检查点。

要测试,请重置模型并加载最新的检查点:

model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

1000/1000 [==============================] - 0s 96us/step
Restored model, accuracy: 86.80%

这些文件是什么?

上述代码将权重存储到检查点格式的文件集合中,这些文件仅包含二进制格式的训练权重。 检查点包含:*一个或多个包含模型权重的分片。 *索引文件,指示哪些权重存储在哪个分片中。

如果您只在一台机器上训练模型,那么您将有一个带有后缀的分片:.data-00000-of-00001

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

保存整个模型

整个模型可以保存到包含权重值,模型配置甚至优化器配置的文件中。 这允许您检查模型并稍后从完全相同的状态恢复培训 - 无需访问原始代码。

在Keras中保存功能齐全的模型非常有用 - 您可以在TensorFlow.js中加载它们,然后在Web浏览器中训练和运行它们。

Keras使用HDF5标准提供基本保存格式。 出于我们的目的,可以将保存的模型视为单个二进制blob。

model = create_model()

model.fit(train_images, train_labels, epochs=5)

# Save entire model to a HDF5 file
model.save('my_model.h5')

Epoch 1/5
1000/1000 [==============================] - 0s 395us/step - loss: 1.1260 - acc: 0.6870
Epoch 2/5
1000/1000 [==============================] - 0s 135us/step - loss: 0.4136 - acc: 0.8760
Epoch 3/5
1000/1000 [==============================] - 0s 138us/step - loss: 0.2811 - acc: 0.9280
Epoch 4/5
1000/1000 [==============================] - 0s 153us/step - loss: 0.2078 - acc: 0.9480
Epoch 5/5
1000/1000 [==============================] - 0s 154us/step - loss: 0.1452 - acc: 0.9750

现在从该文件重新创建模型:

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('my_model.h5')
new_model.summary()

检查其准确性:

loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

这项技术可以保存以下:

  1. 权重值
  2. 模型的配置(架构)
  3. 优化器配置

Keras通过检查架构来保存模型。 目前,它无法保存TensorFlow优化器(来自tf.train)。 使用这些时,您需要在加载后重新编译模型,并且您将失去优化器的状态。

下一步是什么

这是使用tf.keras保存和加载的快速指南。

tf.keras指南显示了有关使用tf.keras保存和加载模型的更多信息。

请参阅在急切执行期间保存以备保存。

“保存和还原”指南包含有关TensorFlow保存的低级详细信息。

完整代码:

from __future__ import absolute_import,division,print_function
import os
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)


# Download dataset
(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1,28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1,28 * 28) / 255.0

# Define a model
# Returns a short sequential model
def create_model():
    model = tf.keras.models.Sequential([
    keras.layers.Dense(512,activation=tf.nn.relu,input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10,activation=tf.nn.softmax)
])

model.compile(optimizer=tf.keras.optimizers.Adam(),
  loss=tf.keras.losses.sparse_categorical_crossentropy,
  metrics=['accuracy'])
return model

# Create a basic model instance
model = create_model()
model.summary()

# Checkpoint callback usage
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create checkpoint callback
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
 save_weights_only=True,
 verbose=1)
model = create_model()
model.fit(train_images,train_labels,epochs=10,
  validation_data=(test_images,test_labels),
  callbacks=[cp_callback]) # pass callback to training

# Create a new, untrained model. 
model = create_model()
loss,acc = model.evaluate(test_images,test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc))

# Load the weights from chekpoint, and re-evaluate.
model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

# Train a new model, and save uniquely named checkpoints once every 5epochs
# include the epoch in the file name. (uses 'str.format')
checkpoint_path = 'training_2/cp-{epoch:04d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1,save_weights_only=True,
    # Save weights, every 5-epochs
    period=5)

model = create_model()
model.fit(train_images,train_labels,
  epochs=50,callbacks = [cp_callback],
  validation_data = (test_images,test_labels),
  verbose=0)


latest = tf.train.latest_checkpoint(checkpoint_dir)
print(latest)


# To test, reset the model and load the latest checkpoint
model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

# Manually save weights
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss, acc = model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))


# Save the entire model
model = create_model()
model.fit(train_images,train_labels,
  epochs=5)
# Save entire model to a HDF5 file
model.save('my_model.h5')

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('my_model.h5')
new_model.summary()

loss, acc = new_model.evaluate(test_images,test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 217,406评论 6 503
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,732评论 3 393
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 163,711评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,380评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,432评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,301评论 1 301
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,145评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,008评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,443评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,649评论 3 334
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,795评论 1 347
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,501评论 5 345
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,119评论 3 328
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,731评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,865评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,899评论 2 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,724评论 2 354

推荐阅读更多精彩内容

  • 近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测。我们使用Tens...
    liuyan731阅读 12,754评论 0 19
  • 在这篇tensorflow教程中,我会解释: 1) Tensorflow的模型(model)长什么样子? 2) 如...
    JunsorPeng阅读 3,419评论 1 6
  • 世界这么大,你应该去看看 今天刚刚高考完的表妹问起大学报考志愿应该怎么填,因为发挥的不太好,家里人都建议她学护理,...
    Miss凌妹妹阅读 455评论 6 4
  • 近海风云烈, 征帆拓远洲。 迷雾隐奇伟, 但去必贤优。
    村客阅读 153评论 0 6
  • 你说你不吃香菜 我说我吃饺子不带汤 我俩的碗里却是漂着香菜的饺子汤 在南方第一次吃热干面,我嫌它太噎人 在北方,后...
    Crazy麻麻阅读 346评论 7 9