关键代码
第一步：用 tf.Variable 定义需要保存/恢复的变量
第二步：创建保存器 saver = tf.train.Saver()
第三步：在 Session 中调用 saver.save 保存变量，或调用 saver.restore 恢复变量
保存变量
import sys
# Report the interpreter version that the recorded outputs below were
# produced with, so readers can compare against their own environment.
interpreter_version = sys.version
print(interpreter_version)
'''
3.5.3 |Continuum Analytics, Inc.| (default, May 15 2017, 10:43:23) [MSC v.1900 64 bit (AMD64)]
'''
import tensorflow as tf
import numpy as np
import os

# Save W and b to a checkpoint file.
# Remember: when restoring, the variables must be re-declared with the same
# dtype and shape (and name) that they had when they were saved.
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
# tf.initialize_all_variables() is deprecated; use its replacement below.
init = tf.global_variables_initializer()
saver = tf.train.Saver()
# Saver.save does not create the checkpoint's parent directory (TF 1.x
# raises when it is missing), so make sure "save/" exists first.
os.makedirs("save", exist_ok=True)
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "save/save_net.ckpt")
    print("Save to path: ", save_path)
"""
Save to path:  save/save_net.ckpt
"""
提取变量
# Start from a clean graph. If the save snippet already ran in this process,
# re-creating variables named 'weights'/'biases' in the same graph would make
# TF uniquify them to 'weights_1'/'biases_1'. The default Saver (which tracks
# every global variable by name) would then look for keys the checkpoint
# does not contain, and restore would fail.
tf.reset_default_graph()
# Build "containers" for W and b to restore into. Their shape, dtype and name
# must match what was saved. np.arange yields an integer array, hence the
# explicit float32 dtype.
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
# No initializer is needed here: restore() assigns the saved values to the
# variables tracked by the saver.
saver = tf.train.Saver()
with tf.Session() as sess:
    # Load the saved values into W and b.
    saver.restore(sess, "save/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))
"""
weights: [[ 1. 2. 3.]
[ 3. 4. 5.]]
biases: [[ 1. 2. 3.]]
"""
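在变量很多的情况下，给每个变量都加 name 很麻烦，可以用下面这种形式：把变量放进 tf.variable_scope 中，不再逐个指定 name。
注意：不指定 name 时，tf.Variable 会被自动命名为 regression/Variable、regression/Variable_1……，恢复时正是靠这些自动生成的名字与 checkpoint 对应，所以保存和恢复时变量的定义顺序必须完全一致。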
保存
with tf.variable_scope("regression"):
    # Unnamed variables are auto-named "Variable", "Variable_1", ... under the
    # "regression/" prefix. The tuple is built left to right, so W is still
    # created before b, which keeps the auto-generated names stable.
    W, b = (
        tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32),
        tf.Variable([[1, 2, 3]], dtype=tf.float32),
    )
提取
with tf.variable_scope("regression"):
    # Restore-side containers. They use the same scope, the same creation
    # order (W, then b) and the same float32 dtype as when saving, so the
    # auto-generated names and the dtypes line up with the checkpoint.
    # np.arange yields an integer array, hence the explicit dtype.
    # NOTE(review): if the save-side graph is still the default graph,
    # re-entering this scope may give tf.Variable the name prefix
    # "regression_1/". Build this in a fresh graph.
    W, b = (
        tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32),
        tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32),
    )
疑问?
def regression(x):
    """Build a softmax-regression model: y = softmax(x @ W + b).

    Args:
        x: input tensor whose last dimension is 784, as required by the
            matmul with the 784x10 weight matrix.

    Returns:
        A tuple ``(y, [W, b])``. ``y`` is the softmax output over 10 classes.
        ``[W, b]`` lists the model's trainable variables, suitable for
        passing to ``tf.train.Saver``.
    """
    # W is created before b. Neither is named explicitly, so their
    # auto-generated names depend on this order.
    weights = tf.Variable(tf.zeros([784, 10]), dtype=tf.float32)
    biases = tf.Variable(tf.zeros([10]), dtype=tf.float32)
    logits = tf.matmul(x, weights) + biases
    return tf.nn.softmax(logits), [weights, biases]
恢复变量
with tf.variable_scope("regression"):
    # Rebuild the model inside the same scope used at save time.
    y1, variables = model.regression(x)
# Limit the saver to this model's own variables ([W, b]) instead of every
# global variable in the graph.
saver = tf.train.Saver(var_list=variables)
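在恢复变量时，W 和 b 的 dtype 必须与 checkpoint 中保存的一致，变量名（name）也必须能对上，否则 restore 会报错。
例如用 np.arange 这类整数数组作初值时，必须显式写 dtype=tf.float32。不写 name 时，则依赖自动生成的名字（Variable、Variable_1……），需要保证创建顺序与保存时一致。
但是下面这种情况就不用指定 dtype：tf.truncated_normal 和 tf.constant(0.1, ...) 生成的初值本身就是 float32。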
def weight_variable(shape):
    """Create a weight variable of the given shape.

    The initial values are drawn from ``tf.truncated_normal`` with
    stddev 0.1. That op produces float32 by default, so no explicit
    ``dtype`` is needed for the variable to match a float32 checkpoint.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
def bias_variable(shape):
    """Create a bias variable of the given shape, filled with 0.1.

    ``tf.constant(0.1, ...)`` builds a float32 tensor from the Python float,
    so the variable is float32 without an explicit ``dtype``.
    """
    return tf.Variable(tf.constant(0.1, shape=shape))