NLP: Saving and Loading Models

Import the required components

import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow import train

Build the training and test sets

# Load MNIST; keep only the first 1,000 examples to keep training fast
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

# Flatten the 28x28 images into vectors and scale pixel values to [0, 1]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

A helper function that builds and compiles the model

def create_model():
    model = Sequential([
        Dense(512, activation='relu', input_shape=(784,)),
        Dropout(0.2),
        Dense(10)
    ])
    model.compile(optimizer='adam',
                  loss=SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[SparseCategoricalAccuracy()])
    return model


# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()


Saving the model (here we only save the model's weights)

# Save checkpoints during training

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1)

# Train the model with the new callback; the callback writes the weights
# to checkpoint_path at the end of every epoch
model.fit(train_images, train_labels, epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])

# Evaluate the trained model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Trained model, accuracy: {:5.2f}%".format(100 * acc))

# Build a new, untrained model; its accuracy should be around 10% (chance level)
model2 = create_model()
loss, acc = model2.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
# Untrained model, accuracy: 10.60%


# Load the saved weights into the new model
model2.load_weights(checkpoint_path)
loss, acc = model2.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
# Restored model, accuracy: 87.20%
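
To see what the callback actually wrote, you can list the checkpoint directory; this is just a quick inspection with the standard library:

# The callback produces an index file plus weight data shards under training_1
print(os.listdir(checkpoint_dir))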

Other checkpoint callback options
You can also keep only the best set of weights seen during training, which is very useful for very long training runs; a minimal sketch of that option follows, and the code after it instead writes uniquely named checkpoints at a fixed interval.

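Keeping only the best weights just means enabling save_best_only and choosing a metric to monitor. A minimal sketch, assuming validation data is passed to fit; the path 'training_best/cp-best.ckpt' is only an example:

# Hypothetical path for this sketch
best_checkpoint_path = 'training_best/cp-best.ckpt'

# Overwrite the checkpoint only when the monitored metric (val_loss) improves
best_callback = ModelCheckpoint(filepath=best_checkpoint_path,
                                save_weights_only=True,
                                save_best_only=True,
                                monitor='val_loss',
                                verbose=1)

Passing callbacks=[best_callback] to fit then saves weights only on epochs where val_loss improves. The code below takes the other route and saves a uniquely named checkpoint every few batches:
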
# Checkpoint callback options

checkpoint_path_2 = 'training_2/cp-{epoch:04d}.ckpt'
checkpoint_dir_2 = os.path.dirname(checkpoint_path_2)

batch_size = 32

# save_freq is counted in batches: with 1,000 training examples and batch_size=32
# (~32 batches per epoch), 5 * batch_size = 160 batches is roughly every 5 epochs
cp_callback2 = ModelCheckpoint(filepath=checkpoint_path_2, verbose=1, save_weights_only=True, save_freq=5 * batch_size)
# Alternative: keep only the best weights instead of periodic snapshots
# cp_callback2 = ModelCheckpoint(filepath=checkpoint_path_2, verbose=1, save_weights_only=True, save_best_only=True)
model3 = create_model()

# Save the initial (epoch 0) weights so the checkpoint directory is never empty
model3.save_weights(checkpoint_path_2.format(epoch=0))

# Train, writing a uniquely named checkpoint every 5 * batch_size batches
model3.fit(train_images, train_labels, epochs=50, batch_size=batch_size, callbacks=[cp_callback2], validation_data=(test_images, test_labels), verbose=0)

# Find the most recent checkpoint file in training_2
latest = train.latest_checkpoint(checkpoint_dir_2)
print(latest)

# Build a fresh model and restore the latest weights into it
model4 = create_model()
model4.load_weights(latest)

# Re-evaluate the model
loss, acc = model4.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
# 32/32 - 0s - loss: 0.4787 - sparse_categorical_accuracy: 0.8750 - 95ms/epoch - 3ms/step
# Restored model, accuracy: 87.50%

Finally, saving the weights manually

# Manually save weights

# After training, save the weights manually
model4.save_weights('./checkpoints/my_checkpoint')

model5 = create_model()

# load the weights
model5.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss, acc = model5.evaluate(test_images, test_labels, verbose=2)
print("Restored model5, accuracy: {:5.2f}%".format(100 * acc))

Link to the full code
