RPC is arguably the most fundamental building block of a distributed system. This post walks through Spark's internal RPC framework.
RpcEndpoint

This trait represents an RPC endpoint: any class that extends it gains the ability to send and receive RPC messages. Its main methods fall into two groups (a minimal implementation is sketched after the list):

- Message handling
  - `def receive: PartialFunction[Any, Unit]` — a partial function that handles messages sent by other `RpcEndpoint`s; implementations override it to define their own message-handling logic.
  - `def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit]` — much like `receive`, except that after processing a message it can send a reply back through the supplied `RpcCallContext`.
- Callbacks
  - `def onConnected(remoteAddress: RpcAddress): Unit` — invoked when a remote host connects to this `RpcEndpoint`.
  - `onStart`, `onStop`, `onDisconnected`, and similar lifecycle callbacks.
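For orientation, here is a minimal sketch of a custom endpoint. The `Ping` message and the `EchoEndpoint` name are invented for this example, and since `RpcEndpoint` is `private[spark]`, code like this can only live inside Spark's own code base:

```scala
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

// Hypothetical message type, defined only for this example.
case object Ping

class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {

  // One-way messages (sent with `send`) land here; there is no way to reply.
  override def receive: PartialFunction[Any, Unit] = {
    case Ping => println("received Ping, no reply expected")
  }

  // Messages sent with `ask`/`askSync` land here; reply through the context.
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Ping => context.reply("pong")
  }

  override def onStart(): Unit = println("EchoEndpoint started")
}
```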
RpcEndpointRef

An `RpcEndpointRef` represents a connection from the current endpoint to a remote `RpcEndpoint`. To send an RPC message to another host, you first obtain an `RpcEndpointRef` from the remote address, an `RpcAddress` (a case class identifying the remote endpoint), and then send messages through that reference. Its main methods:

- Asynchronous request: `def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T]` — sends an arbitrary message to the remote endpoint and returns a `Future`, from which the result can be read once the remote side replies.
- Synchronous request: `def askSync[T: ClassTag](message: Any, timeout: RpcTimeout): T` — blocks until the result is returned.
- Fire-and-forget: `def send(message: Any): Unit` — sends a one-way message without expecting a reply.
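Assuming `ref` is an `RpcEndpointRef` pointing at the hypothetical `EchoEndpoint` above, the three styles of sending look roughly like this (the no-timeout overloads of `ask`/`askSync` use a default timeout taken from the Spark configuration):

```scala
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

// Asynchronous: returns a Future immediately; it completes when the reply arrives.
val future: Future[String] = ref.ask[String](Ping)
future.foreach(answer => println(s"got $answer"))

// Synchronous: blocks the calling thread until the reply arrives (or times out).
val reply: String = ref.askSync[String](Ping)

// Fire-and-forget: delivered to the endpoint's `receive`; no reply comes back.
ref.send(Ping)
```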
RpcEnv

This trait is arguably the most important one: it holds the information for all registered endpoints and is responsible for dispatching RPC messages. Every `RpcEndpoint` has an `RpcEnv`. To connect to another `RpcEndpoint` and exchange messages with it, an endpoint registers itself with the remote `RpcEnv`; upon receiving the registration, the remote side stores the connection information in its `RpcEnv` object, at which point the two endpoints are connected and can send and receive in both directions. Its main methods (a short usage sketch follows the list):

- Endpoint registration: `def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef` lets an `Endpoint` register itself with the local `RpcEnv` under a name. A single process may host several endpoints, e.g. one receiving heartbeats, one monitoring job status, one listening for messages coming back from `Executor`s, and so on. An `RpcEndpoint` sends messages through its `RpcEnv` to an `RpcEndpointRef`, and the receiving `RpcEnv` dispatches each incoming message to the `RpcEndpoint` registered with it.
- Remote endpoint lookup: `def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef]` (asynchronous) and `def setupEndpointRef(address: RpcAddress, endpointName: String): RpcEndpointRef` (synchronous) return an `RpcEndpointRef` for an endpoint registered under the given name, typically on a remote `RpcEnv`.
- Lifecycle methods: `stop`, `shutdown`, `awaitTermination`.
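As a sketch, assuming Spark 2.x where the `RpcEnv.create` helper (shown later) builds the Netty-backed implementation; the name, host, and port are invented, and `SecurityManager` is also `private[spark]`:

```scala
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcAddress, RpcEnv}

val conf = new SparkConf()
val env = RpcEnv.create("demo", "localhost", 7077, conf, new SecurityManager(conf))

// Register a local endpoint under a name; it is now reachable by that name.
val localRef = env.setupEndpoint("echo", new EchoEndpoint(env))

// Obtain a reference to an endpoint hosted by some other RpcEnv.
val remoteRef = env.setupEndpointRef(RpcAddress("host-a", 7077), "echo")
```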
RpcCallContext

This one will come up again in the walkthrough below; for now, here is the trait:
```scala
private[spark] trait RpcCallContext {

  /**
   * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]]
   * will be called.
   */
  def reply(response: Any): Unit

  /**
   * Report a failure to the sender.
   */
  def sendFailure(e: Throwable): Unit

  /**
   * The sender of this message.
   */
  def senderAddress: RpcAddress
}
```
Spark implements these RPC interfaces on top of Netty. Let's walk through the Netty-based implementation.
NettyRpcEnvFactory
```scala
private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {

  def create(config: RpcEnvConfig): RpcEnv = {
    val sparkConf = config.conf
    // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support
    // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance
    val javaSerializerInstance =
      new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
    val nettyEnv =
      new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
        config.securityManager)
    if (!config.clientMode) {
      val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
        nettyEnv.startServer(config.bindAddress, actualPort)
        (nettyEnv, nettyEnv.address.port)
      }
      try {
        Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, config.name)._1
      } catch {
        case NonFatal(e) =>
          nettyEnv.shutdown()
          throw e
      }
    }
    nettyEnv
  }
}
```
A factory that creates `NettyRpcEnv` objects: it instantiates a `NettyRpcEnv` and, unless running in client mode, starts a Netty server (the `nettyEnv.startServer` method).
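For reference, the usual entry point is `RpcEnv.create`, which in Spark 2.x (abridged here, and simplified to the case where the bind and advertise addresses coincide) just builds an `RpcEnvConfig` and hands it to this factory:

```scala
// From object RpcEnv (Spark 2.x, abridged):
def create(
    name: String,
    host: String,
    port: Int,
    conf: SparkConf,
    securityManager: SecurityManager,
    clientMode: Boolean = false): RpcEnv = {
  val config = RpcEnvConfig(conf, name, host, host, port, securityManager, clientMode)
  new NettyRpcEnvFactory().create(config)
}
```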
NettyRpcEnv

The heart of this class is a `Dispatcher`:
```scala
private[netty] class NettyRpcEnv(
    val conf: SparkConf,
    javaSerializerInstance: JavaSerializerInstance,
    host: String,
    securityManager: SecurityManager) extends RpcEnv(conf) with Logging {

  ...

  private val dispatcher: Dispatcher = new Dispatcher(this)

  ...

  private val transportContext = new TransportContext(transportConf,
    new NettyRpcHandler(dispatcher, this, streamManager))

  ...

  @volatile private var server: TransportServer = _

  private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()

  ...

  def startServer(bindAddress: String, port: Int): Unit = {
    ...
    server = transportContext.createServer(bindAddress, port, bootstraps)
    dispatcher.registerRpcEndpoint(
      RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
  }
}
```
As noted above, the factory calls `startServer`, and that method registers an `RpcEndpointVerifier` with the `dispatcher`. The verifier is itself an `RpcEndpoint`:
```scala
private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher: Dispatcher)
  extends RpcEndpoint {

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RpcEndpointVerifier.CheckExistence(name) => context.reply(dispatcher.verify(name))
  }
}

private[netty] object RpcEndpointVerifier {
  val NAME = "endpoint-verifier"

  /** A message used to ask the remote [[RpcEndpointVerifier]] if an `RpcEndpoint` exists. */
  case class CheckExistence(name: String)
}
```
This is the first `RpcEndpoint` we encounter: whenever it receives a `CheckExistence` message, it calls the dispatcher's `verify` method. Let's take a look at this `dispatcher` object first.
Dispatcher

This object's job is to route incoming RPC messages to the right endpoint. Internally it keeps a `ConcurrentHashMap` of all registered `RpcEndpoint`s:
```scala
private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {

  private class EndpointData(
      val name: String,
      val endpoint: RpcEndpoint,
      val ref: NettyRpcEndpointRef) {
    val inbox = new Inbox(ref, endpoint)
  }

  private val endpoints: ConcurrentMap[String, EndpointData] =
    new ConcurrentHashMap[String, EndpointData]
  private val receivers = new LinkedBlockingQueue[EndpointData]

  ...
}
```
The `registerRpcEndpoint` method mentioned above actually places the `RpcEndpointVerifier` into both of these containers. Other `Endpoint`s then use the `RpcEndpointVerifier` to find out whether they were successfully registered with this `RpcEnv`: a remote `Endpoint` sends a message carrying its own name to the `RpcEndpointVerifier` of this `RpcEnv`, which checks whether the endpoint container holds a matching registration and replies with the result.
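The `verify` method itself is a one-liner; roughly (following the Spark 2.x source), it just consults the map of registered endpoints:

```scala
/** Return true if the endpoint exists, i.e. is registered with this Dispatcher. */
def verify(name: String): Boolean = endpoints.containsKey(name)
```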
NettyRpcEndpointRef

As described earlier, an `RpcEndpointRef` stands for a remote `Endpoint` and can be used to send it RPC messages:
```scala
private[netty] class NettyRpcEndpointRef(
    @transient private val conf: SparkConf,
    private val endpointAddress: RpcEndpointAddress,
    @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) {

  ...

  override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
    nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout)
  }
}
```
Let's follow this into `NettyRpcEnv.ask`:
```scala
private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
  val promise = Promise[Any]()
  val remoteAddr = message.receiver.address

  def onFailure(e: Throwable): Unit = { ... }

  def onSuccess(reply: Any): Unit = reply match { ... }

  try {
    if (remoteAddr == address) {
      val p = Promise[Any]()
      p.future.onComplete {
        case Success(response) => onSuccess(response)
        case Failure(e) => onFailure(e)
      }(ThreadUtils.sameThread)
      dispatcher.postLocalMessage(message, p)
    } else {
      val rpcMessage = RpcOutboxMessage(message.serialize(this),
        onFailure,
        (client, response) => onSuccess(deserialize[Any](client, response)))
      postToOutbox(message.receiver, rpcMessage)
      promise.future.onFailure {
        case _: TimeoutException => rpcMessage.onTimeout()
        case _ =>
      }(ThreadUtils.sameThread)
    }

    val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
      override def run(): Unit = {
        onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " +
          s"in ${timeout.duration}"))
      }
    }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
    promise.future.onComplete { v =>
      timeoutCancelable.cancel(true)
    }(ThreadUtils.sameThread)
  } catch { ... }

  promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
}
```
This method consists of three parts:

1. If the message is addressed to a locally registered `RpcEndpoint`, post it as a local message.
2. If it is addressed to a remote `Endpoint`, wrap it and put it into an `Outbox` to await delivery.
3. Timeout handling: a scheduled task fails the call with an exception once the deadline passes, and a callback added to the declared `Promise` cancels that scheduled task if the RPC completes before the timeout.
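The timeout bookkeeping in part three is a generic pattern worth seeing in isolation. A self-contained sketch in plain Scala (no Spark types; the helper name is made up):

```scala
import java.util.concurrent.{Executors, TimeUnit, TimeoutException}
import scala.concurrent.{Future, Promise}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

val timeoutScheduler = Executors.newSingleThreadScheduledExecutor()

/** Fail `promise` with a TimeoutException unless it completes within `timeout`. */
def withTimeout[T](promise: Promise[T], timeout: FiniteDuration): Future[T] = {
  val timeoutTask = timeoutScheduler.schedule(new Runnable {
    override def run(): Unit =
      promise.tryFailure(new TimeoutException(s"no reply in $timeout"))
  }, timeout.toNanos, TimeUnit.NANOSECONDS)

  // If the RPC completes first, cancel the pending timeout task.
  promise.future.onComplete(_ => timeoutTask.cancel(true))
  promise.future
}
```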
Let's first look at `dispatcher.postLocalMessage`, which wraps up the call information:
```scala
def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
  val rpcCallContext =
    new LocalNettyRpcCallContext(message.senderAddress, p)
  val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
  postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
}
```
It then goes through `dispatcher.postMessage`, which does three things:

1. Fetch the `EndpointData` object for the target endpoint.
2. Post the message to that object's inbox.
3. Put the `EndpointData` object into the `receivers` queue.
```scala
private def postMessage(
    endpointName: String,
    message: InboxMessage,
    callbackIfStopped: (Exception) => Unit): Unit = {
  ...
  val data = endpoints.get(endpointName)
  data.inbox.post(message)
  receivers.offer(data)
  ...
}
```
The inbox holds the messages destined for its `Endpoint`; once a message arrives here, the `Endpoint` has in effect received it. But `post` only appends the message to a queue, so how is it actually delivered to the `Endpoint`?
```scala
private[netty] class Inbox(
    val endpointRef: NettyRpcEndpointRef,
    val endpoint: RpcEndpoint)
  extends Logging {

  inbox =>  // Give this an alias so we can use it more clearly in closures.

  @GuardedBy("this")
  protected val messages = new java.util.LinkedList[InboxMessage]()

  ...

  def post(message: InboxMessage): Unit = inbox.synchronized {
    if (stopped) {
      // We already put "OnStop" into "messages", so we should drop further messages
      onDrop(message)
    } else {
      messages.add(message)
      false // dead code in the quoted source: post returns Unit, so this value is discarded
    }
  }

  ...
}
```
The `Dispatcher` holds a thread pool; each thread repeatedly takes an `EndpointData` from the `receivers` queue and processes the messages stored in its `inbox`:
```scala
private val threadpool: ThreadPoolExecutor = {
  val numThreads = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.numThreads",
    math.max(2, Runtime.getRuntime.availableProcessors()))
  val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
  for (i <- 0 until numThreads) {
    pool.execute(new MessageLoop)
  }
  pool
}

private class MessageLoop extends Runnable {
  override def run(): Unit = {
    try {
      while (true) {
        try {
          val data = receivers.take()
          if (data == PoisonPill) {
            // Put PoisonPill back so that other MessageLoops can see it.
            receivers.offer(PoisonPill)
            return
          }
          data.inbox.process(Dispatcher.this)
        } catch {
          case NonFatal(e) => logError(e.getMessage, e)
        }
      }
    } catch {
      case ie: InterruptedException => // exit
    }
  }
}
```
Now back to the `inbox.process` method:
```scala
def process(dispatcher: Dispatcher): Unit = {
  var message: InboxMessage = null
  inbox.synchronized {
    ...
    message = messages.poll()
    ...
  }
  while (true) {
    safelyCall(endpoint) {
      message match {
        case RpcMessage(_sender, content, context) =>
          try {
            endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
              throw new SparkException(s"Unsupported message $message from ${_sender}")
            })
          } catch { ... }

        case OneWayMessage(_sender, content) =>
          endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
            throw new SparkException(s"Unsupported message $message from ${_sender}")
          })

        case OnStart =>
          endpoint.onStart()
          if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
            inbox.synchronized {
              if (!stopped) {
                enableConcurrent = true
              }
            }
          }

        case OnStop =>
          val activeThreads = inbox.synchronized { inbox.numActiveThreads }
          ...
          dispatcher.removeRpcEndpointRef(endpoint)
          endpoint.onStop()
          ...

        case RemoteProcessConnected(remoteAddress) =>
          endpoint.onConnected(remoteAddress)

        case RemoteProcessDisconnected(remoteAddress) =>
          endpoint.onDisconnected(remoteAddress)

        case RemoteProcessConnectionError(cause, remoteAddress) =>
          endpoint.onNetworkError(cause, remoteAddress)
      }
    }

    inbox.synchronized {
      ...
      message = messages.poll()
      if (message == null) {
        numActiveThreads -= 1
        return
      }
    }
  }
}
```
As you can see, this method keeps polling messages from the `messages` queue until none are left. The message we sent to the local `Endpoint` earlier is an `InboxMessage`; which case of the pattern match does it correspond to?
```scala
private[netty] sealed trait InboxMessage

private[netty] case class OneWayMessage(
    senderAddress: RpcAddress,
    content: Any) extends InboxMessage

private[netty] case class RpcMessage(
    senderAddress: RpcAddress,
    content: Any,
    context: NettyRpcCallContext) extends InboxMessage

private[netty] case object OnStart extends InboxMessage

private[netty] case object OnStop extends InboxMessage
```
The local message we sent is of type `RpcMessage`. Since each `Inbox` is paired one-to-one with an `Endpoint`, `endpoint.receiveAndReply` is invoked directly to handle it, which means the message has by now been delivered to the `Endpoint`. (Compare `RpcEndpointVerifier.receiveAndReply`, one concrete `RpcEndpoint`: in this flow it can be read as a local `RpcEndpoint` asking its local `RpcEnv` to confirm that it was registered successfully.)
Now let's look at sending a message to a remote `RpcEndpoint`. The message is wrapped into an `RpcOutboxMessage`, and `postToOutbox` is called:
```scala
private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
  if (receiver.client != null) {
    message.sendWith(receiver.client)
  } else {
    ...
    val targetOutbox = {
      val outbox = outboxes.get(receiver.address)
      ...
    }
    if (stopped.get) { ... } else {
      targetOutbox.send(message)
    }
  }
}

private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {

  outbox =>  // Give this an alias so we can use it more clearly in closures.

  @GuardedBy("this")
  private val messages = new java.util.LinkedList[OutboxMessage]

  @GuardedBy("this")
  private var client: TransportClient = null

  @GuardedBy("this")
  private var connectFuture: java.util.concurrent.Future[Unit] = null

  def send(message: OutboxMessage): Unit = {
    val dropped = synchronized {
      if (stopped) { ... } else {
        messages.add(message)
        false
      }
    }
    if (dropped) { ... } else {
      drainOutbox()
    }
  }

  ...
}
```
Each `Outbox` contains:

- a queue holding the messages to send
- a `TransportClient` that connects to the remote `RpcEndpoint` and is used to send them

`drainOutbox` does two things (a simplified sketch follows the list):

- check whether a connection to the remote `RpcEndpoint` has been established, and if not, start a thread to establish one
- walk the queue, sending each message to the `TransportServer` of the remote `RpcEnv`
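A condensed sketch of that logic, abridged from the Spark 2.x source (the `draining` guard and network error handling are omitted):

```scala
private def drainOutbox(): Unit = {
  var message: OutboxMessage = null
  synchronized {
    if (stopped || connectFuture != null) return // stopped, or still connecting
    if (client == null) {
      launchConnectTask() // connect on a separate thread; it retries drainOutbox() when done
      return
    }
    message = messages.poll()
    if (message == null) return
  }
  while (message != null) {
    message.sendWith(client)                   // push the bytes through the TransportClient
    message = synchronized { messages.poll() } // keep going until the queue is empty
  }
}
```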
On the remote side, this message is handled by a `NettyRpcHandler`:
```scala
private[netty] class NettyRpcHandler(
    dispatcher: Dispatcher,
    nettyEnv: NettyRpcEnv,
    streamManager: StreamManager) extends RpcHandler with Logging {

  // A variable to track the remote RpcEnv addresses of all clients
  private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]()

  override def receive(
      client: TransportClient,
      message: ByteBuffer,
      callback: RpcResponseCallback): Unit = {
    val messageToDispatch = internalReceive(client, message)
    dispatcher.postRemoteMessage(messageToDispatch, callback)
  }
}

// In Dispatcher:
def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
  val rpcCallContext =
    new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
  val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
  postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
}
```
And so we meet `postMessage` again, only this time it is the `postMessage` of the `Dispatcher` in the remote `RpcEnv`. The message ultimately gets handed to the `RpcEndpoint` registered with that remote `RpcEnv`, and the remote `RpcEndpoint` thereby receives the message sent from the local side. The RPC round trip is complete.
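To tie the whole flow together, here is an end-to-end sketch: two `RpcEnv`s, one hosting an endpoint and one asking it. The names, ports, and message type are invented, and since these APIs are `private[spark]`, such code can only live inside Spark's own source tree:

```scala
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc._

case object Ping // hypothetical message type

class PongEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Ping => context.reply("pong")
  }
}

val conf = new SparkConf()

// "Server" side: creating the RpcEnv starts a TransportServer; then register an endpoint.
val serverEnv = RpcEnv.create("server", "localhost", 52345, conf, new SecurityManager(conf))
serverEnv.setupEndpoint("ping-pong", new PongEndpoint(serverEnv))

// "Client" side: setupEndpointRef internally asks the remote endpoint-verifier
// whether "ping-pong" exists before returning the reference.
val clientEnv = RpcEnv.create("client", "localhost", 52346, conf, new SecurityManager(conf))
val ref = clientEnv.setupEndpointRef(RpcAddress("localhost", 52345), "ping-pong")

// ask -> Outbox -> TransportClient -> remote NettyRpcHandler -> Dispatcher -> Inbox
// -> PongEndpoint.receiveAndReply -> context.reply -> back into our Future.
println(ref.askSync[String](Ping)) // prints "pong"

clientEnv.shutdown()
serverEnv.shutdown()
```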