1.定义文件的保存路径
ckpt_dir="./ckpt_dir"
ifnotos.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
2.定义一个全局变量
global_step=tf.Variable(0,name='global_step',trainable=False)
这个全局变量是保存文件和提取文件的标识,比如我现在要load什么时候保存的变量
3.定义saver方法
saver=tf.train.Saver()
注意任何变量定义在saver前面的都会被保存,在其后面定义的都不会被保存
4.保存变量
注意看前面定义的变量global_step,第一步给这个变量更新值(epoch),然后再保存。所以这个变量是以后load哪个文件的依据
global_step.assign(i).eval()#set and update(eval) global_step with index, i
saver.save(sess, ckpt_dir+"/model.ckpt",global_step=global_step)
5.load变量
ckpt=tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
print(ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)#restore all variables