一介绍
该部分主要介绍mnist数据集上的神经网络模型,变量管理,模型持久化这几部分。
二 变量管理
Tensorflow通过变量名称获取变量的机制主要是通过tf.get_variable和tf.variable_scope函数来现。 tf.get_variable创建变量时和Variable基本等价
get_variable和Variable不同在于,如果该变量名已经存在的化,会报错,但是Variable却不会报错,所以get_variable要获取变量时,需要通过variable_scope函数来生成一个上下文管理器。
三 模型持久化
模型持久化就是将模型保存,以方便复用。
model.ckpt.meta:保存tensorflow计算图的结构。
model.ckpt:保存了tensorflow程序中每一个变量的取值。
checkpoint: 这个文件中保存了一个目录下所有模型文件列表。
如果不想将tensorflow的网络结构重新一遍的化,可以直接加载,但是麻烦在于获取张量的方式。
为了保存或者加载部分变量,在声明tf.train.Saver类时可以提供一个来指定需要保存或者加载的变量。比如在上面代码的例子中,想加载进v1变量,可以saver = tf.train.Saver([v1])这种方式,但是因为v2没有加载进去,所以会报错v2没有初始化的错误。
重命名加载的变量。
关于重命名的方式很适合上一章节讲述的滑动平均值,每一个变量的滑动平均值是通过影子变量维护的,所以要获取变量的滑动平均值就是获取这个影子变量的取值。
在滑动平均模型中有这个应用,提供了variables_to_restore函数来生成tf.train.Saver所需要的变量重命名字典。
使用tf.train.Saver()会保存运行tensorflow程序所需要的全部信息,然而有时并不需要某些信息,比如在测试或离线预测时,只需要知道神经网络从输入层到输出层即可,不需要变量初始化,模型保存等辅助信息。根据这些需求,tensorflow提供了convert_variables_to_constants函数,通过这个函数可以将计算图中的变量及取值通过常量的方式保存。
可以参考:保存,冻结,读取