本文可以分为四个部分
- 简单的介绍一下graph的概念;
- 详细的介绍一下Session的概念;
- 介绍Session和graph之间的关系和交互概念
- 介绍使用keras的时候,session和graph是如何指定的
本文假设读者对于tensorflow 的工作流程有一个大致的认知,而非小白。因为在实际使用的过程中,在不同python文件中调来调去,我总是会不清楚session和graph的对应关系,想要弄清楚,因此写了这篇文章。里面的很多内容来自网络的博文,也有部分来自于官网的介绍,代码基本都是经过测试的。如果有理解不正确的地方,非常希望大家指正。
一、 Graph的概念
当tensorflow库被加载时,它会自动创建一个Graph对象,并将其作为默认的数据流图。因此,在Graph.as_default()上下文管理器之外定义的任何op, tensor对象都会自动放置在默认的数据流图中。
#如果希望获得默认的数据流图的句柄,可使用:
default_graph = tf.get_default_graph()
在大多数Tensorflow程序中,只使用默认图就可以了
二、Session的概念
1. 简单的session介绍
最常见的初学者模式中,我们这么使用session:
import tensorflow as tf
import numpy as np
a=tf.constant([[1,2],[3,5]])
b=tf.constant([[1,2],[3,4]])
c=tf.add(a,b)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(c))
d=tf.matmul(a,b)
print(sess.run(d))
也就是说先在默认图中定义操作节点,然后生成一个session,在session中执行初始化以及执行图的流式操作。
2. tf.Session()
,tf.Session().as_default
和tf.get_default_session()
在tensorflow中关于session,还有一个默认session的概念,我们可以通过tf.get_default_session()
来获得当前线程的session,也可以通过tf.Session().as_default()
来指定某个sess为新进程的默认session。那么问题来了,tf.Session()
,tf.Session().as_default
,tf.get_default_session()
三者之间有什么区别和联系呢?
2.1 说说tf.Session()
和tf.Session().as_default
。
tf.Session()创建一个会话,当上下文管理器退出时会话关闭和资源释放自动完成。
tf.Session().as_default()
,看起来好像是创建了一个默认的session,这种session和不使用这种方法的session的区别是会话在上下文管理器退出时会话没有关闭,还可以通过调用会话进行run()和eval()操作。但是一定不要忘记手动关闭这个sess
import tensorflow as tf
a = tf.constant(1.0)
b = tf.constant(2.0)
with tf.Session() as sess:
print(a.eval())
print(b.eval(session=sess))
#------------------------------
#1.0
# RuntimeError: Attempted to use a closed Session.
import tensorflow as tf
a = tf.constant(1.0)
b = tf.constant(2.0)
with tf.Session().as_default() as sess:
print(a.eval())
print(b.eval(session=sess))
# ------------------------------
# 1.0
# 2.0
2.2 说说tf.Session().as_default
和tf.get_default_session()
那么通过tf.Session().as_default
指定的sess就是当前线程的默认sess了吗?不是的。通过一段代码来说明
import tensorflow as tf
with tf.Graph().as_default() as g:
with g.name_scope("myscope") as scope: # 有了这个scope,下面的op的name都是类似myscope/Placeholder这样的前缀
sess = tf.Session(target='', graph=g, config=None) # target表示要连接的tf执行引擎
a = tf.placeholder("float")
b = tf.placeholder("float")
y1 = a*b # 也可以写成a * b
operations = g.get_operations()
print(sess.run(y1, feed_dict={a: 3, b: 3})) # 9.0 feed_dictgraph中的元素和值的映射)
assert tf.get_default_session() is not sess
with sess.as_default(): # 把sess作为默认的session,那么tf.get_default_session就是sess, 否则不是
assert tf.get_default_session() is sess
print("================ default sess str ================")
assert tf.get_default_session() is not sess
print(tf.get_default_session().sess_str)
可以发现,在sess.as_default()
这上下文中,assert tf.get_default_session() is sess
判定正确。但是出了和这个上下文之后,通过assert tf.get_default_session() is not sess
可知,当前线程的默认session并不是sess了。
在前面我们看到,TensorFlow会自动生成一个默认的计算图,如果没有特殊指定,运算会自动加入这个计算图中。TensorFlow中的会话也有类似的机制,但是TensorFlow不会自动生成默认的会话。
在tf的源码中是这么说的
@tf_export("get_default_session")
def get_default_session():
"""Returns the default session for the current thread.
The returned `Session` will be the innermost session on which a
`Session` or `Session.as_default()` context has been entered.
NOTE: The default session is a property of the current thread. If you
create a new thread, and wish to use the default session in that
thread, you must explicitly add a `with sess.as_default():` in that
thread's function.
这里的意思是说,在最外层的tf.get_default_session()
得到的session是当前线程默认的session,是我们创建的tf.Session()上下文的入口,是当前线程的一个属性。当我们创建一个tf.Session(),实际启动了一个新的线程,我们指定的````with sess.as_default()```上下文实际上是一个新的线程并且指定了这个新线程的默认session。
这就解释了了为什么,在某个我们定义的session的上下文环境中,我们使用tf.get_default_session
的判定sess是对的,而离开这个上下文环境,使用tf.get_default_session is not sess
进行判定又是错的了。那么我们之所以可以sess.as_default()上下文管理器外面执行run()和eval(),是因为我们在另一个线程中执行的代码,所以手动关闭sess,也就是关闭这个线程。
三、Session和graph之间的关系和交互概念
最简单的一种情况,我们在默认图中定义tf的操作,然后在一个定义的sess中执行默认图的初始化操作和执行流式操作,得到结果。这种情况只有一个session和一个graph.
实际代码中碰到的情况远比这种教材上的情况复杂。我们会使用多个模型,有多个图,有多个会话。想要使用不同的模型就要将这些模型加载到不同session中,并且声明使用的时候申请是哪个session,从而避免由于session和想使用的模型不匹配导致错误,而使用多个graph就需要为每个graph使用不同的session,但是每个graph也可以在多个session中使用,这个时候就需要在每个session中使用的时候明确使用的graph。
在我们需要构建多个图的时候,我们可以将操作定义在不同图的上下文环境中。
import tensorflow as tf
graph = tf.get_default_graph()
g1 = tf.Graph() # 加载到Session 1的graph
g2 = tf.Graph() # 加载到Session 2的graph
sess1 = tf.Session(graph=g1) # Session1
sess2 = tf.Session(graph=g2) # Session2
# 加载第一个模型
with sess1.as_default():
assert tf.get_default_graph() is graph
with g1.as_default()():
assert tf.get_default_graph() is g1
# 加载第二个模型
with sess2.as_default(): # 1
with g2.as_default():
assert tf.get_default_graph() is g2
with sess1.as_default():
with sess1.graph.as_default(): # 2
assert tf.get_default_graph() is g1
with sess2.as_default():
assert tf.get_default_graph() is graph
with sess2.graph.as_default():
assert tf.get_default_graph() is g2
# 关闭sess
sess1.close()
sess2.close()
在使用as_default使session在离开的时候并不关闭,在后面可以继续使用直到手动关闭,由于有多个graph,所以sess.graph与tf.get_default_value的值是不相等的,因此在进入sess的时候必须sess.graph.as_default()明确什么sess.graph为当前默认graph,否则会报错.
通过tensorflow源码上的注释,我们再来讲讲上面的这段代码。
sess1 = tf.Session(graph=g1)
这个代码指定了创建session的执行的默认图。当在同一段进程中使用超过一个图的时候,建议对每个图使用不同的session,但是每个图可以在多个session中被使用。因此通过这种显式的方式来传递graph参数,可以让代码更加易读。
在sess.as_default的代码注释中,是这么说的
*N.B.* Entering a `with sess.as_default():` block does not affect
the current default graph. If you are using multiple graphs, and
`sess.graph` is different from the value of @{tf.get_default_graph},
you must explicitly enter a `with sess.graph.as_default():` block
to make `sess.graph` the default graph.
当使用with sess.as_default():
进入一个新的线程模块的时候,并不会影响当前的默认图,我们只能通过显式的声明with g1.as_default():
或者with sess1.graph.as_default():
才能进入这个sess指定的graph
四、使用keras.tensorflow_backend
如果用tensorflow写的model,一般来说每个model都有自己的session和graph
但是在keras,会经常忽略掉session和graph,这时候需要添加session和好几个地方加with graph,伪代码如下:
seg_graph = tf.Graph()
sess = tf.Session(graph=seg_graph)
K.set_session(sess)
K.get_session().run(tf.global_variables_initializer())
#保证代码
with seg_graph.as_default():
self.keras_model = self.build(mode=mode, config=config)
#上面一行代码会调用KM.Model
以及这类函数
topology.load_weights_fromXXX()
以及predict函数
在GPU上执行的时候,发现K.get_session().run(tf.global_variables_initializer())
很重要,如果没有这一步,那么在启动的时候,一直会包内存溢出问题,到底是什么原因,暂时还没有弄清楚,先mark一下。