事由
上周工作中遇到一个bug,现象是一个spark streaming的job会不定期地hang住,不退出也不继续运行。这个job经是用pyspark写的,以kafka为数据源,会在每个batch结束时将统计结果写入mysql。经过排查,我们在driver进程中发现有有若干线程都出于Sl状态(睡眠状态),进而使用gdb调试发现了一处死锁。
这是MySQLdb库旧版本中的一处bug,在此不再赘述,有兴趣的可以看这个issue。不过这倒是提起了我对另外一件事的兴趣,就是driver进程——严格的说应该是driver进程的python子进程——中的这些线程是从哪来的?当然,这些线程的存在很容易理解,我们开启了spark.streaming.concurrentJobs参数,有多个batch可以同时执行,每个线程对应一个batch。但翻遍pyspark的python代码,都没有找到有相关线程启动的地方,于是简单调研了一下pyspark到底是怎么工作的,做个记录。
本文概括
- Py4J的线程模型
- pyspark基本原理(driver端)
- CPython中的deque的线程安全
涉及软件版本
- spark: 2.1.0
- py4j: 0.10.4
Py4J
spark是由scala语言编写的,pyspark并没有像豆瓣开源的dpark用python复刻了spark,而只是提供了一层可以与原生JVM通信的python API,Py4J就是python与JVM之间的这座桥梁。这个库分为Java和Python两部分,基本原理是:
- Java部分,通过
py4j.GatewayServer
监听一个tcp socket(记做server_socket) - Python部分,所有对JVM中对象的访问或者方法的调用,都是通过
py4j.JavaGateway
向上面这个socket完成的。 - 另外,Python部分在创建
JavaGateway
对象时,可以选择同时创建一个CallbackServer
,它会在Python这册监听一个tcp socket(记做callback_socket),用来给Java回调Python代码提供一条渠道。 - Py4J提供了一套文本协议用来在tcp socket间传递命令。
pyspark driver工作流程
- 首先,一个spark job被提交后,如果被判定这是一个python的job,spark driver会找到相应的入口,即
org.apache.spark.deploy.PythonRunner
的main
函数,这个函数中会启动GatewayServer
// Launch a Py4J gateway server for the process to connect to; this will let it see our
// Java system properties and such
val gatewayServer = new py4j.GatewayServer(null, 0)
val thread = new Thread(new Runnable() {
override def run(): Unit = Utils.logUncaughtExceptions {
gatewayServer.start()
}
})
thread.setName("py4j-gateway-init")
thread.setDaemon(true)
thread.start()
- 然后,会创建一个Python子进程来运行我们提交上来的python入口文件,并把刚才
GatewayServer
监听的那个端口写入到子进程的环境变量中去(这样Python才知道要通过那个端口访问JVM)
// Launch Python process
val builder = new ProcessBuilder((Seq(pythonExec, formattedPythonFile) ++ otherArgs).asJava)
val env = builder.environment()
env.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
// pass conf spark.pyspark.python to python process, the only way to pass info to
// python process is through environment variable.
sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
- Python子进程这边,我们是通过pyspark提供的python API编写的这个程序,在创建
SparkContext
(python)时,会初始化_gateway
变量(JavaGateway
对象)和_jvm
变量(JVMView
对象)
@classmethod
def _ensure_initialized(cls, instance=None, gateway=None, conf=None):
"""
Checks whether a SparkContext is initialized or not.
Throws error if a SparkContext is already running.
"""
with SparkContext._lock:
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway(conf)
SparkContext._jvm = SparkContext._gateway.jvm
if instance:
if (SparkContext._active_spark_context and
SparkContext._active_spark_context != instance):
currentMaster = SparkContext._active_spark_context.master
currentAppName = SparkContext._active_spark_context.appName
callsite = SparkContext._active_spark_context._callsite
# Raise error if there is already a running Spark context
raise ValueError(
"Cannot run multiple SparkContexts at once; "
"existing SparkContext(app=%s, master=%s)"
" created by %s at %s:%s "
% (currentAppName, currentMaster,
callsite.function, callsite.file, callsite.linenum))
else:
SparkContext._active_spark_context = instance
其中launch_gateway
函数可见pyspark/java_gateway.py
。
- 上面初始化的这个
_jvm
对象值得一说,在pyspark中很多对JVM的调用其实都是通过它来进行的,比如很多python种对应的spark对象都有一个_jsc
变量,它是JVM中的SparkContext对象在Python中的影子
,它是这么初始化的
def _initialize_context(self, jconf):
"""
Initialize SparkContext in function to allow subclass specific initialization
"""
return self._jvm.JavaSparkContext(jconf)
这里_jvm
为什么能直接调用JavaSparkContext
这个JVM环境中的构造函数呢?我们看JVMView
中的__getattr__
方法:
def __getattr__(self, name):
if name == UserHelpAutoCompletion.KEY:
return UserHelpAutoCompletion()
answer = self._gateway_client.send_command(
proto.REFLECTION_COMMAND_NAME +
proto.REFL_GET_UNKNOWN_SUB_COMMAND_NAME + name + "\n" + self._id +
"\n" + proto.END_COMMAND_PART)
if answer == proto.SUCCESS_PACKAGE:
return JavaPackage(name, self._gateway_client, jvm_id=self._id)
elif answer.startswith(proto.SUCCESS_CLASS):
return JavaClass(
answer[proto.CLASS_FQN_START:], self._gateway_client)
else:
raise Py4JError("{0} does not exist in the JVM".format(name))
self._gateway_client.send_command
其实就是向server_socket
发送访问对象请求的命令了,最后根据响应值生成不同类型的影子
对象,针对我们这里的JavaSparkContext
,就是一个JavaClass
对象。这个系列的类型还包括了JavaMember
,JavaPackage
等等,他们也通过__getattr__
来实现Java对象属性访问以及方法的调用。
- 我们刚才介绍Py4j时说过Python端在创建JavaGateway时,可以选择同时创建一个
CallbackClient
,默认情况下,一个普通的pyspark job是不会启动回调服务的,因为用不着,所有的交互都是Python --> JVM
这种模式的。那什么时候需要呢?streaming job就需要(具体流程我们稍后介绍),这就(终于!)引出了我们今天主要讨论的Py4J线程模型的问题。
Py4J线程模型
我们已经知道了Python与JVM双方向的通信分别是通过server_socket
和callack_socket
来完成的,这两个socket的处理模型都是多线程模型,即,每收到一个连接就启动一个线程来处理。我们只看Python --> JVM
这条通路的情况,另外一边是一样的
Server端(Java)
protected void processSocket(Socket socket) {
try {
this.lock.lock();
if(!this.isShutdown) {
socket.setSoTimeout(this.readTimeout);
Py4JServerConnection gatewayConnection = this.createConnection(this.gateway, socket);
this.connections.add(gatewayConnection);
this.fireConnectionStarted(gatewayConnection);
}
} catch (Exception var6) {
this.fireConnectionError(var6);
} finally {
this.lock.unlock();
}
}
继续看createConnection
:
protected Py4JServerConnection createConnection(Gateway gateway, Socket socket) throws IOException {
GatewayConnection connection = new GatewayConnection(gateway, socket, this.customCommands, this.listeners);
connection.startConnection();
return connection;
}
其中connection.startConnection
其实就是创建了一个新线程,来负责处理这个连接。
Client端(Python)
我们来看GatewayClient
中的send_command
方法:
def send_command(self, command, retry=True, binary=False):
"""Sends a command to the JVM. This method is not intended to be
called directly by Py4J users. It is usually called by
:class:`JavaMember` instances.
:param command: the `string` command to send to the JVM. The command
must follow the Py4J protocol.
:param retry: if `True`, the GatewayClient tries to resend a message
if it fails.
:param binary: if `True`, we won't wait for a Py4J-protocol response
from the other end; we'll just return the raw connection to the
caller. The caller becomes the owner of the connection, and is
responsible for closing the connection (or returning it this
`GatewayClient` pool using `_give_back_connection`).
:rtype: the `string` answer received from the JVM (The answer follows
the Py4J protocol). The guarded `GatewayConnection` is also returned
if `binary` is `True`.
"""
connection = self._get_connection()
try:
response = connection.send_command(command)
if binary:
return response, self._create_connection_guard(connection)
else:
self._give_back_connection(connection)
except Py4JNetworkError as pne:
if connection:
reset = False
if isinstance(pne.cause, socket.timeout):
reset = True
connection.close(reset)
if self._should_retry(retry, connection, pne):
logging.info("Exception while sending command.", exc_info=True)
response = self.send_command(command, binary=binary)
else:
logging.exception(
"Exception while sending command.")
response = proto.ERROR
return response
这里这个self._get_connection
是这么实现的
def _get_connection(self):
if not self.is_connected:
raise Py4JNetworkError("Gateway is not connected.")
try:
connection = self.deque.pop()
except IndexError:
connection = self._create_connection()
return connection
这里使用了一个deque
(也就是Python标准库中的collections.deque
)来维护一个连接池,如果有空闲的连接,就可以直接使用,如果没有,就新建一个连接。现在问题来了,如果deque不是线程安全的,那么这段代码在多线程环境就会有问题。那么deque是不是线程安全的呢?
deque的线程安全
当然是了,Py4J当然不会犯这样的低级错误,我们看标准库的文档:
Deques support thread-safe, memory efficient appends and pops from either side of the deque with approximately the same O(1) performance in either direction.
是线程安全的,不过措辞有点模糊,没有明确指出哪些方法是线程安全的,不过可以明确的是至少append的pop都是。之所以去查一下,是因为我也有点含糊,因为Python标准库还有另外一个Queue.Queue
,在多线程编程中经常使用,肯定是线程安全的,于是很容易误以为deque不是线程安全的,所以我们才要一个新的Queue。这个问题,推荐阅读stackoverflow上Jonathan的这个答案——他的回答不是被采纳的最高票,不过我认为他的回答比高票更有说服力
- 高票答案一直强调说
deque是线程安全的
这个事实是个意外,是CPython中存在GIL造成的,其他Python解释器就不一定遵守。关于这一点我是不认同的,deque在CPython中的实现确实依赖的GIL才变成了线程安全的,但deque的双端append的pop是线程安全的
这件事是白纸黑字写在Python文档中的,其他虚拟机的实现必须遵守,否则就不能称之为合格的Python实现。 - 那为什么还要有一个内部显式用了锁来做线程同步的
Queue.Queue
呢?Jonathan给出的回答是Queue
的put
和get
可以是blocking的,而deque
不行,这样一来,当你需要在多个线程中进行通信时(比如最简单的一个Producer - Consumer模式的实现),Queue
往往是最佳选择。
关于deque是否是线程安全这个问题,我将调研的结果写在了这个知乎问题的答案下Python中的deque是线程安全的吗?,就不在赘述了,这篇文章已经太长了。
关于Py4J线程模型的问题,还可以参考官方文档中的解释。
pyspark streaming与CallbackServer
刚才提到,如果是streaming的job,GatewayServer在初始化时会同时创建一个CallbackServer,提供JVM --> Python
这条通路。
@classmethod
def _ensure_initialized(cls):
SparkContext._ensure_initialized()
gw = SparkContext._gateway
java_import(gw.jvm, "org.apache.spark.streaming.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
# start callback server
# getattr will fallback to JVM, so we cannot test by hasattr()
if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
gw.callback_server_parameters.eager_load = True
gw.callback_server_parameters.daemonize = True
gw.callback_server_parameters.daemonize_connections = True
gw.callback_server_parameters.port = 0
gw.start_callback_server(gw.callback_server_parameters)
cbport = gw._callback_server.server_socket.getsockname()[1]
gw._callback_server.port = cbport
# gateway with real port
gw._python_proxy_port = gw._callback_server.port
# get the GatewayServer object in JVM by ID
jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
# update the port of CallbackClient with real port
jgws.resetCallbackClient(jgws.getCallbackClient().getAddress(), gw._python_proxy_port)
# register serializer for TransformFunction
# it happens before creating SparkContext when loading from checkpointing
cls._transformerSerializer = TransformFunctionSerializer(
SparkContext._active_spark_context, CloudPickleSerializer(), gw)
为什么需要这样呢?一个streaming job通常需要调用foreachRDD
,并提供一个函数,这个函数会在每个batch被回调:
def foreachRDD(self, func):
"""
Apply a function to each RDD in this DStream.
"""
if func.__code__.co_argcount == 1:
old_func = func
func = lambda t, rdd: old_func(rdd)
jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
api = self._ssc._jvm.PythonDStream
api.callForeachRDD(self._jdstream, jfunc)
这里,Python函数func
被封装成了一个TransformFunction
对象,在scala端spark也定义了同样接口一个trait
:
/**
* Interface for Python callback function which is used to transform RDDs
*/
private[python] trait PythonTransformFunction {
def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
/**
* Get the failure, if any, in the last call to `call`.
*
* @return the failure message if there was a failure, or `null` if there was no failure.
*/
def getLastFailure: String
}
这样是Py4J提供的机制,这样就可以让JVM通过这个影子接口
回调Python中的对象了,下面就是scala中的callForeachRDD
函数,它把PythonTransformFunction
又封装了一层成为scala中的TransformFunction
, 但不管如何封装,最后都会调用PythonTransformFunction
接口中的call
方法完成对Python的回调。
/**
* helper function for DStream.foreachRDD(),
* cannot be `foreachRDD`, it will confusing py4j
*/
def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
val func = new TransformFunction((pfunc))
jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
}
所以,终于要回答这个问题了,我们一开始看到的driver中的多个线程是怎么来的?
- python调用
foreachRDD
提供一个TranformFunction
给scala端 - scala端调用自己的
foreachRDD
进行正常的spark streaming作业 - 由于我们开启了
spark.streaming.concurrentJobs
,多个batch可以同时运行,这在scala端是通过线程池来进行的,每个batch都需要回调Python中的TranformFunction
,而按照我们之前介绍的Py4J线程模型,多个并发的回调会发现没有可用的socket连接而生成新的,而在CallbackServer(Python)这端,每个新连接都会创建一个新线程来处理。这样就出现了driver的Python进程中出现多个线程的现象。