This article walks through Spark Streaming backpressure using the direct Kafka stream (DirectKafkaInputDStream) as the example.
Enabling Spark Streaming backpressure via spark-submit:
--conf spark.streaming.backpressure.enabled=true \ # enable Spark Streaming backpressure
--conf spark.streaming.kafka.maxRatePerPartition=100 \ # maximum number of records each partition may consume per second
--conf spark.streaming.backpressure.initialRate=100 \ # maximum per-partition consumption rate for the very first batch
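The same settings can also be applied programmatically. A minimal sketch (the app name and batch interval here are placeholders, not from the original):
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Minimal sketch: the backpressure settings above, applied via SparkConf.
val conf = new SparkConf()
  .setAppName("backpressure-demo") // placeholder app name
  .set("spark.streaming.backpressure.enabled", "true")
  .set("spark.streaming.kafka.maxRatePerPartition", "100")
  .set("spark.streaming.backpressure.initialRate", "100")
val ssc = new StreamingContext(conf, Seconds(5)) // arbitrary 5s batch interval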
Classes involved
- RateController: the entry point of backpressure. It implements the StreamingListener trait and overrides the onBatchCompleted method.
- RateEstimator: estimates, whenever a batch completes, the rate at which the InputDStream should ingest data; its contract is sketched below.
- PIDRateEstimator: the actual implementation, a PID controller.
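For reference, the RateEstimator contract, abridged from the Spark source:
// Abridged from org.apache.spark.streaming.scheduler.rate.RateEstimator:
// given the latest batch's statistics, optionally return a new ingestion
// rate (records per second) for the stream.
private[streaming] trait RateEstimator extends Serializable {
  def compute(
      time: Long, // end of batch processing, in ms
      elements: Long, // records in the batch
      processingDelay: Long, // processing time, in ms
      schedulingDelay: Long // scheduling delay, in ms
  ): Option[Double]
}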
How does this take effect?
When the program reaches StreamingContext.start(), it calls JobScheduler.start(). There, each input stream builds a RateController appropriate to its source type; the Kafka direct stream builds a DirectKafkaRateController instance. The resulting RateControllers are then registered with the StreamingListenerBus, as the excerpt below shows.
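The registration step, slightly abridged from JobScheduler.start():
// Abridged from org.apache.spark.streaming.scheduler.JobScheduler#start:
// attach the rate controllers of input streams to the listener bus so
// they receive batch-completion events.
for {
  inputDStream <- ssc.graph.getInputStreams
  rateController <- inputDStream.rateController
} ssc.addStreamingListener(rateController)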
- First, look at DirectKafkaInputDStream:
/**
* Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker.
*/
override protected[streaming] val rateController: Option[RateController] = {
// If backpressure is enabled, create a RateEstimator and wrap it in a rate controller
if (RateController.isBackPressureEnabled(ssc.conf)) {
Some(new DirectKafkaRateController(id,
RateEstimator.create(ssc.conf, context.graph.batchDuration)))
} else {
None
}
}
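The enablement check itself is a one-line helper on the RateController companion object (from the Spark source):
// From org.apache.spark.streaming.scheduler.RateController's companion object.
object RateController {
  def isBackPressureEnabled(conf: SparkConf): Boolean =
    conf.getBoolean("spark.streaming.backpressure.enabled", false)
}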
/**
* A RateController to retrieve the rate from RateEstimator.
*/
private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator)
extends RateController(id, estimator) {
override def publish(rate: Long): Unit = ()
}
/**
* Return a new `RateEstimator` based on the value of
* `spark.streaming.backpressure.rateEstimator`.
*
* The only known and acceptable estimator right now is `pid`.
*
* @return An instance of RateEstimator
* @throws IllegalArgumentException if the configured RateEstimator is not `pid`.
*/
def create(conf: SparkConf, batchInterval: Duration): RateEstimator =
conf.get("spark.streaming.backpressure.rateEstimator", "pid") match {
case "pid" =>
val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0)
val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2)
val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0)
val minRate = conf.getDouble("spark.streaming.backpressure.pid.minRate", 100)
new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate)
case estimator =>
throw new IllegalArgumentException(s"Unknown rate estimator: $estimator")
}
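All four gains can be overridden in the SparkConf. A sketch that simply restates the defaults read by create() above (note the derivative gain's key is spelled "derived" in the source):
import org.apache.spark.SparkConf

val tuned = new SparkConf()
  .set("spark.streaming.backpressure.rateEstimator", "pid") // currently the only supported value
  .set("spark.streaming.backpressure.pid.proportional", "1.0")
  .set("spark.streaming.backpressure.pid.integral", "0.2")
  .set("spark.streaming.backpressure.pid.derived", "0.0")
  .set("spark.streaming.backpressure.pid.minRate", "100")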
- So create() returns a PIDRateEstimator, which DirectKafkaInputDStream wraps in a DirectKafkaRateController. Note that DirectKafkaRateController itself does nothing on publish: unlike receiver-based streams there is no receiver to push the rate to, so publish is a no-op and the estimated rate is simply stored inside the controller, to be read back later via getLatestRate().
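Where is the rate stored in the meantime? Inside RateController itself (excerpted from the Spark source):
// Excerpt from org.apache.spark.streaming.scheduler.RateController:
// the most recent estimate lives in a java.util.concurrent.atomic.AtomicLong;
// -1 means "no estimate yet".
private val rateLimit: AtomicLong = new AtomicLong(-1L)

def getLatestRate(): Long = rateLimit.get()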
RateController
override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
val elements = batchCompleted.batchInfo.streamIdToInputInfo
/**
 * batchCompleted: completion info for this streaming batch
 * processingEnd: time at which batch processing finished
 * workDelay: time spent processing this batch
 * waitDelay: scheduling delay (time the batch waited before processing started)
 * elems: number of records consumed in this batch
 */
for {
processingEnd <- batchCompleted.batchInfo.processingEndTime
workDelay <- batchCompleted.batchInfo.processingDelay
waitDelay <- batchCompleted.batchInfo.schedulingDelay
elems <- elements.get(streamUID).map(_.numRecords)
} computeAndPublish(processingEnd, elems, workDelay, waitDelay)
}
/**
* Compute the new rate limit and publish it asynchronously.
*/
private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit =
Future[Unit] {
/**
 * Compute the new rate from this batch's statistics
 */
val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay)
newRate.foreach { s =>
// store the new rate in rateLimit and publish it (a no-op for the direct stream)
rateLimit.set(s.toLong)
publish(getLatestRate())
}
}
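For contrast, in receiver-based streams publish actually pushes the new rate out to the receivers (abridged from ReceiverInputDStream in the Spark source):
// Abridged from org.apache.spark.streaming.dstream.ReceiverInputDStream:
// the receiver-based controller forwards the new rate to the ReceiverTracker,
// which relays it to the running receivers.
private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator)
    extends RateController(id, estimator) {
  override def publish(rate: Long): Unit =
    ssc.scheduler.receiverTracker.sendRateUpdate(id, rate)
}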
PIDRateEstimator
- So how is this new rate actually computed?
def compute(
time: Long, // in milliseconds; end time of batch processing
numElements: Long, // number of records processed
processingDelay: Long, // in milliseconds; time spent processing
schedulingDelay: Long // in milliseconds; time spent waiting to be scheduled
): Option[Double] = {
logTrace(s"\ntime = $time, # records = $numElements, " +
s"processing time = $processingDelay, scheduling delay = $schedulingDelay")
this.synchronized {
// Sanity checks: this batch finished after the previous one, records > 0, processing time > 0
if (time > latestTime && numElements > 0 && processingDelay > 0) {
// in seconds, should be close to batchDuration
// time elapsed between the completion of the last two batches
val delaySinceUpdate = (time - latestTime).toDouble / 1000
// in elements/second
// records / processing time = measured processing rate
val processingRate = numElements.toDouble / processingDelay * 1000
// in elements/second
// previous rate minus measured processing rate = the proportional error
val error = latestRate - processingRate
// (in elements/second)
// schedulingDelay * processingRate / batchInterval = (schedulingDelay / batchInterval) * processingRate,
// i.e. the backlog, in records, caused by scheduling delay (used as the accumulated/integral error)
val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis
// in elements/(second ^ 2)
// (error - previous error) / time between updates = derivative of the error
val dError = (error - latestError) / delaySinceUpdate
// new rate = previous rate - 1.0 * error - 0.2 * historicalError - 0.0 * dError (default gains), floored at minRate
val newRate = (latestRate - proportional * error -
integral * historicalError -
derivative * dError).max(minRate)
logTrace(s"""
| latestRate = $latestRate, error = $error
| latestError = $latestError, historicalError = $historicalError
| delaySinceUpdate = $delaySinceUpdate, dError = $dError
""".stripMargin)
latestTime = time
// On the first run there is no baseline, so record the measurements and emit no estimate
if (firstRun) {
latestRate = processingRate
latestError = 0D
firstRun = false
logTrace("First run, rate estimation skipped")
None
} else {
latestRate = newRate
latestError = error
logTrace(s"New rate = $newRate")
Some(newRate)
}
} else {
logTrace("Rate estimation skipped")
None
}
}
}
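To make the formula concrete, here is a worked example with the default gains (proportional = 1.0, integral = 0.2, derivative = 0.0, minRate = 100); all of the input numbers are invented for illustration:
// Hypothetical inputs: 1000 ms batches, previous rate 1000 rec/s, previous error 0,
// and the latest batch processed 1000 records in 1250 ms with a 500 ms scheduling
// delay, finishing exactly 1 s after the previous batch completed.
val processingRate = 1000.0 / 1250 * 1000 // 800 rec/s actually achieved
val error = 1000.0 - processingRate // 200 rec/s: we were ingesting too fast
val historicalError = 500.0 * processingRate / 1000 // 400 records of backlog
val dError = (error - 0.0) / 1.0 // 200; previous error was 0
val newRate = (1000.0 - 1.0 * error - 0.2 * historicalError - 0.0 * dError).max(100)
// newRate = 1000 - 200 - 80 - 0 = 720 rec/s for the next batch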
Applying the rate
- Back in DirectKafkaInputDStream: how do we make use of the computed rate?
/**
 * Compute the maximum number of messages each partition may consume in the next batch
 */
protected[streaming] def maxMessagesPerPartition(
offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = {
// fetch the latest estimated rate
val estimatedRateLimit = rateController.map(_.getLatestRate())
// calculate a per-partition rate limit based on current lag
val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {
case Some(rate) =>
// lag per partition: latest available offset minus current offset
val lagPerPartition = offsets.map { case (tp, offset) =>
tp -> Math.max(offset - currentOffsets(tp), 0)
}
// total lag across all partitions
val totalLag = lagPerPartition.values.sum
lagPerPartition.map { case (tp, lag) =>
// per-partition cap, from spark.streaming.kafka.maxRatePerPartition
val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp)
// backpressure rate: (partition lag / total lag) * total rate = the rate this partition should consume at
val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
tp -> (if (maxRateLimitPerPartition > 0) {
Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
}
case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) }
}
if (effectiveRateLimitPerPartition.values.sum > 0) {
val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
Some(effectiveRateLimitPerPartition.map {
case (tp, limit) => tp -> (secsPerBatch * limit).toLong
})
} else {
None
}
}
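Again with invented numbers, continuing the PID example above: suppose the estimator returned 720 rec/s, there are two partitions with lags 900 and 300, no per-partition cap is configured, and the batch interval is 1 second. Then:
// Hypothetical numbers; partition names "t-0"/"t-1" are placeholders.
val rate = 720.0 // latest estimated rate, rec/s
val lag = Map("t-0" -> 900L, "t-1" -> 300L) // lag per partition
val totalLag = lag.values.sum // 1200
val secsPerBatch = 1.0 // batch interval in seconds
val maxMessages = lag.map { case (tp, l) =>
  // each partition gets a share of the rate proportional to its lag...
  val backpressureRate = Math.round(l / totalLag.toFloat * rate)
  // ...scaled by the batch duration to get a message count for this batch
  tp -> (secsPerBatch * backpressureRate).toLong
}
// maxMessages == Map(t-0 -> 540, t-1 -> 180)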