断点续训
- 读取模型
load_weights(路径文件名)
生成ckpt的同时会生成index文件,可通过该文件是否存在判断是否有预训练模型生成
ckpt_path = "./mnist.ckpt"
if(os.path.exists(ckpt_path + ".index")):
print("--load modle--")
model.load_weights(ckpt_path)
- 保存模型
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath = 路径文件名,
save_weights_only=True/False, #只保留模型参数
save_best_only=True/False #只保留最优模型
)
history = model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])
查看训练参数
- 提取可训练参数
model.trainable_variables - 设置print输出格式
np.set_printoptions(threshold=超过多少省略显示),此处若需要完全实现参数应设置为np.inf(表示无限大)
acc和loss可视化
#该可视化只可视化了当前运行的训练部分
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1,2,1)
plt.plot(acc,label='acc')
plt.plot(val_acc,label='val_acc')
plt.title('acc&&val_cac')
plt.legend()
plt.subplot(1,2,2)
plt.plot(loss,label='loss')
plt.plot(val_loss,label='val_loss')
plt.title('loss&&val_loss')
plt.legend()
plt.show()
前向传播
在训练完成后,使用网络生成预测结果
#复现模型
#加载参数
#前向传播获取结果
result = model.predict(x_predict)
在使用时x_predict需要在图片的原始维度前增加一个维度匹配batch维度