问题描述
单独运行MNIST或者YOLO程序没有问题,但是如果还要在MNIST程序中运行YOLO就会各种报错。
问题成因分析
在从ckpt中导入运算图总会有一下语句:
import tensorflow as tf
sess = tf.Session()
saver = tf.train.import_meta_graph(
path_MNIST + 'MNIST.ckpt.meta')
saver.restore(sess, path_MNIST + 'MNIST.ckpt')
这样其实是把tf内部一个叫做tf.default_graph()
的变量做了修改,之后所有的图的定义都引用的是tf.default_graph()
里面的定义,所以如果出现:
import tensorflow as tf
import gc
import numpy as np
def MNIST_common(temp_buffer):
# tf.reset_default_graph()
sess = tf.Session()
# 普通回归
saver = tf.train.import_meta_graph(
path_MNIST + 'MNIST.ckpt.meta')
saver.restore(sess, path_MNIST + 'MNIST.ckpt')
feed_x = np.ndarray(shape=(1, 784), dtype=np.float32,
buffer=temp_buffer)
x = tf.get_default_graph().get_tensor_by_name(
'x:0')
predict = sess.run(tf.get_default_graph().get_tensor_by_name(
'y:0'), feed_dict={x: feed_x})
del saver
del x
del sess
gc.collect()
return predict
def MNIST_CNN(temp_buffer):
# tf.reset_default_graph()
sess = tf.Session()
# CNN
print('I am in **----****')
saver = tf.train.import_meta_graph(
path_CNNMNIST + 'MNISTCNN.ckpt.meta')
saver.restore(sess, path_CNNMNIST + 'MNISTCNN.ckpt')
print('----****-----')
feed_x_CNN = np.ndarray(shape=(1, 784), dtype=np.float32,
buffer=temp_buffer)
x = tf.get_default_graph().get_tensor_by_name(
'x-input:0')
keep_prob = tf.get_default_graph().get_tensor_by_name('keep_prob:0')
predict = sess.run(tf.get_default_graph().get_tensor_by_name('y_conv:0'), feed_dict={
x: feed_x_CNN, keep_prob: 0.8})
del saver
del x
del sess
gc.collect()
return predict
MNIST_common(temp_buffer)
MNIST_CNN(temp_buffer)
就会报错。
解决方法
-
tf.reset_default_graph()
在导入新的ckpt
文件前调用该函数就好,这个函数会清除tf.default_graph()
。 - 我尝试过虽然没有成功但是还是比较有启发性的解决方案
del tensorflow
reload(tensorflow)
- 把
tensorflow
从sys.moudles
删除 - 把
tensorflow
从sys.path
删除
会出现很有趣的现象,当我del tf
并且在sys.moudles
中删除tensorflow
时,报错
翻到报错的位置
???
del python
del core
然后我作死把这两行注释了,tensorflow完美运行。
???