最近业务中用到了Redisson限流的功能,顺便研究一下底层实现
基于当前使用的版本<version>3.10.7</version>
目前用到的是accqure(),具体逻辑分析见代码中的注释
@Override
public void acquire(long permits) {
// get 同步获取
get(acquireAsync(permits));
}
// RFuture是继承jdk的Future类
@Override
public <V> V get(RFuture<V> future) {
if (!future.isDone()) {
CountDownLatch l = new CountDownLatch(1);
future.onComplete((res, e) -> {
l.countDown();
});
boolean interrupted = false;
while (!future.isDone()) {
try {
// future complete以后解除阻塞
l.await();
} catch (InterruptedException e) {
interrupted = true;
break;
}
}
if (interrupted) {
Thread.currentThread().interrupt();
}
}
// commented out due to blocking issues up to 200 ms per minute for each thread
// future.awaitUninterruptibly();
if (future.isSuccess()) {
return future.getNow();
}
throw convertException(future);
}
@Override
public RFuture<Void> acquireAsync(long permits) {
RPromise<Void> promise = new RedissonPromise<Void>();
// permits 代表要获取的许可数量,一般一次获取一个
// -1代表不设置超时 可以看看tryAcquire带超时设置的重载方法了解此参数
// null代表时间单位 没有设置时间所以单位为空
tryAcquireAsync(permits, -1, null).onComplete((res, e) -> {
if (e != null) {
promise.tryFailure(e);
return;
}
promise.trySuccess(null);
});
return promise;
}
@Override
public RFuture<Boolean> tryAcquireAsync(long permits, long timeout, TimeUnit unit) {
RPromise<Boolean> promise = new RedissonPromise<Boolean>();
long timeoutInMillis = -1;
// 如果有设置超时时间 转换为毫秒 调用真正的执行逻辑
if (timeout > 0) {
timeoutInMillis = unit.toMillis(timeout);
}
tryAcquireAsync(permits, promise, timeoutInMillis);
return promise;
}
private void tryAcquireAsync(long permits, RPromise<Boolean> promise, long timeoutInMillis) {
long s = System.currentTimeMillis();
// 执行lua脚本,并返回Long类型。具体脚本内容往下看
RFuture<Long> future = tryAcquireAsync(RedisCommands.EVAL_LONG, permits);
// delay 代表lua脚本执行返回的值 e代表异常
future.onComplete((delay, e) -> {
if (e != null) {
promise.tryFailure(e);
return;
}
// 返回空就代表 获取许可成功了,为啥空代表成功需要看后面lua脚本
if (delay == null) {
// 给上层返回true
promise.trySuccess(true);
return;
}
// 走到这里表示获取许可失败了,但是获取许可失败了 要继续尝试
// -1 代表不超时的逻辑
// 获取许可失败的时候 返回的值赋给了delay 为啥取名delay因为返回的是等多久才能下一次获取
if (timeoutInMillis == -1) {
commandExecutor.getConnectionManager().getGroup().schedule(() -> {
tryAcquireAsync(permits, promise, timeoutInMillis);
}, delay, TimeUnit.MILLISECONDS);
return;
}
// 走到这里表示或许许可失败,但是设置了超时时间
// 先看看已经花了多久
long el = System.currentTimeMillis() - s;
long remains = timeoutInMillis - el;
// 超时了 给上层false
if (remains <= 0) {
promise.trySuccess(false);
return;
}
// 暂时还没有超时,但是剩余的时间 比lua返回的要等待的时间还要短,那在超时时间内也不会成功
// 给上层false
if (remains < delay) {
commandExecutor.getConnectionManager().getGroup().schedule(() -> {
promise.trySuccess(false);
}, remains, TimeUnit.MILLISECONDS);
} else {
long start = System.currentTimeMillis();
// 等待delay时间后再次尝试获取许可
// 但是尝试之前再做一次超时判断
commandExecutor.getConnectionManager().getGroup().schedule(() -> {
long elapsed = System.currentTimeMillis() - start;
if (remains <= elapsed) {
promise.trySuccess(false);
return;
}
// 递归调用
tryAcquireAsync(permits, promise, remains - elapsed);
}, delay, TimeUnit.MILLISECONDS);
}
});
}
private <T> RFuture<T> tryAcquireAsync(RedisCommand<T> command, Long value) {
// KEYS 有3个 ARGS只有1个
return commandExecutor.evalWriteAsync(getName(), LongCodec.INSTANCE, command,
// 在或许许可之前 校验有没有创建对应的 限流器基本信息 KEYS[1] 对应的是限流器的name
// 往下查看trySetRate(创建限流器的时候会调用)可以看到是怎么设置进去的
// rate是限速比如 每秒100 这里就是100
// interval是时间间隔 按每秒那这里就是1秒对应的毫秒数
// type 0 全局限流 1 按客户端限流
"local rate = redis.call('hget', KEYS[1], 'rate');"
+ "local interval = redis.call('hget', KEYS[1], 'interval');"
+ "local type = redis.call('hget', KEYS[1], 'type');"
+ "assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')"
// valueName 即存放限流值对应的redis key
+ "local valueName = KEYS[2];"
// 1 代表是按客户端分别限流 0 代表的是全局限流
// 按客户端限流的话 redis key还要加上 客户端的id信息,具体看后面KEYS数组中第三个的值
+ "if type == '1' then "
+ "valueName = KEYS[3];"
+ "end;"
+ "local currentValue = redis.call('get', valueName); "
// 如果限流值存在
+ "if currentValue ~= false then "
// 比较当前值够不够要申请的许可数 不够说明达到限流上限了 然后返回ttl 也就是还有多久到期
+ "if tonumber(currentValue) < tonumber(ARGV[1]) then "
+ "return redis.call('pttl', valueName); "
+ "else "
// 如果够,那就减去本次申请的许可数 然后返回空,空就代表成功了
+ "redis.call('decrby', valueName, ARGV[1]); "
+ "return nil; "
+ "end; "
+ "else "
// 如果限流值在redis不存在 那说明是第一次申请许可或者又到了新的1秒 之前的过期了,所以要创建redis值
// 判断申请的许可数是否太大,比如每秒限流100 你传进来101 那肯定申请不下来
+ "assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount could not exceed defined rate'); "
// 初始化限流的值 并设置过期时间
+ "redis.call('set', valueName, rate, 'px', interval); "
// 然后扣减本次申请的许可数
+ "redis.call('decrby', valueName, ARGV[1]); "
+ "return nil; "
+ "end;",
Arrays.<Object>asList(getName(), getValueName(), getClientValueName()),
value, commandExecutor.getConnectionManager().getId().toString());
}
@Override
public boolean trySetRate(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
return get(trySetRateAsync(type, rate, rateInterval, unit));
}
@Override
public RFuture<Boolean> trySetRateAsync(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
return commandExecutor.evalWriteAsync(getName(), LongCodec.INSTANCE, RedisCommands.EVAL_BOOLEAN,
"redis.call('hsetnx', KEYS[1], 'rate', ARGV[1]);"
+ "redis.call('hsetnx', KEYS[1], 'interval', ARGV[2]);"
+ "return redis.call('hsetnx', KEYS[1], 'type', ARGV[3]);",
Collections.<Object>singletonList(getName()), rate, unit.toMillis(rateInterval), type.ordinal());
}
总结下来就是,往redis设置一个限流的数值,超时时间就是限流的时间区间
然后就去查询这个限流的数值,如果没有查到,肯定可以获取许可;如果查到了,要看看有没有超过许可数
如果获取许可成功了就返回nil,上层就知道成功;如果或许失败那就要等超时时间过去再请求,据返回一个delay时间。
但是,这个执行一次没有获取到许可的话,还要重试,所以上一层方法增加了重试的逻辑,重试是靠io.netty.util.concurrent.EventExecutorGroup#schedule(java.lang.Runnable, long, java.util.concurrent.TimeUnit)来实现,EventExecutorGroup实现了java.util.concurrent.ScheduledExecutorService。
如果我们熟悉Redis的话,应该会觉得限流可以用其他的数据结构来实现,比如zset,用zrangescore来获取某一时间窗口内的请求数,然后判断有没有达到限流阈值。
我们升级一下redisson,看看高版本有没有优化。
继续看看版本<version>3.34.0</version>
private CompletableFuture<Boolean> tryAcquireAsync(long permits, long timeoutInMillis) {
long s = System.currentTimeMillis();
RFuture<Long> future = tryAcquireAsync(RedisCommands.EVAL_LONG, permits);
return future.thenCompose(delay -> {
if (delay == null) {
return CompletableFuture.completedFuture(true);
}
if (timeoutInMillis == -1) {
CompletableFuture<Boolean> f = new CompletableFuture<>();
getServiceManager().newTimeout(t -> {
CompletableFuture<Boolean> r = tryAcquireAsync(permits, timeoutInMillis);
commandExecutor.transfer(r, f);
}, delay, TimeUnit.MILLISECONDS);
return f;
}
long el = System.currentTimeMillis() - s;
long remains = timeoutInMillis - el;
if (remains <= 0) {
return CompletableFuture.completedFuture(false);
}
CompletableFuture<Boolean> f = new CompletableFuture<>();
if (remains < delay) {
getServiceManager().newTimeout(t -> {
f.complete(false);
}, remains, TimeUnit.MILLISECONDS);
} else {
long start = System.currentTimeMillis();
getServiceManager().newTimeout(t -> {
long elapsed = System.currentTimeMillis() - start;
if (remains <= elapsed) {
f.complete(false);
return;
}
CompletableFuture<Boolean> r = tryAcquireAsync(permits, remains - elapsed);
commandExecutor.transfer(r, f);
}, delay, TimeUnit.MILLISECONDS);
}
return f;
}).toCompletableFuture();
}
private <T> RFuture<T> tryAcquireAsync(RedisCommand<T> command, Long value) {
byte[] random = getServiceManager().generateIdArray();
return commandExecutor.evalWriteAsync(getRawName(), LongCodec.INSTANCE, command,
"local rate = redis.call('hget', KEYS[1], 'rate');"
+ "local interval = redis.call('hget', KEYS[1], 'interval');"
+ "local type = redis.call('hget', KEYS[1], 'type');"
+ "assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')"
+ "local valueName = KEYS[2];"
+ "local permitsName = KEYS[4];"
+ "if type == '1' then "
+ "valueName = KEYS[3];"
+ "permitsName = KEYS[5];"
+ "end;"
+ "assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount could not exceed defined rate'); "
// 存储当前剩余的许可数
+ "local currentValue = redis.call('get', valueName); "
+ "local res;"
+ "if currentValue ~= false then "
// 窗口滑动,查询已过期的许可
+ "local expiredValues = redis.call('zrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); "
+ "local released = 0; "
+ "for i, v in ipairs(expiredValues) do "
+ "local random, permits = struct.unpack('Bc0I', v);"
// 所有已过期的许可数加起来
+ "released = released + permits;"
+ "end; "
+ "if released > 0 then "
// 删除已过期的许可 避免下一次range查询又查出来了
+ "redis.call('zremrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); "
// 如果当前剩余许可数 + 已过期 > 总限流数。这种一般不存在,除非重新设置了限流速率?
// 如果是这样的话,重新计算一下剩余许可数,用rate - 有效期内已经申请的许可数,这里不是100%准确 zcard是集合计数 一个条目申请的许可数可能是大于1的
// 当前剩余许可数 + 已过期 <= 总限流数 已过期的可以重新在新窗口使用
+ "if tonumber(currentValue) + released > tonumber(rate) then "
+ "currentValue = tonumber(rate) - redis.call('zcard', permitsName); "
+ "else "
+ "currentValue = tonumber(currentValue) + released; "
+ "end; "
+ "redis.call('set', valueName, currentValue);"
+ "end;"
// 如果当前剩余许可数 小于 本次申请的许可数,则申请失败 最后会返回nil
+ "if tonumber(currentValue) < tonumber(ARGV[1]) then "
+ "local firstValue = redis.call('zrange', permitsName, 0, 0, 'withscores'); "
// 计算多长时间后可以重新获取许可 3是什么意思没有看懂
+ "res = 3 + interval - (tonumber(ARGV[2]) - tonumber(firstValue[2]));"
+ "else "
// 申请成功 zset写入许可申请记录
+ "redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); "
+ "redis.call('decrby', valueName, ARGV[1]); "
+ "res = nil; "
+ "end; "
+ "else "
// 初始化限流数量
+ "redis.call('set', valueName, rate); "
// 记录本次申请的许可
+ "redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); "
// 剩余许可数量 = 限流数量 - 本次申请的许可数量
+ "redis.call('decrby', valueName, ARGV[1]); "
+ "res = nil; "
+ "end;"
+ "local ttl = redis.call('pttl', KEYS[1]); "
+ "if ttl > 0 then "
+ "redis.call('pexpire', valueName, ttl); "
+ "redis.call('pexpire', permitsName, ttl); "
+ "end; "
+ "return res;",
Arrays.asList(getRawName(), getValueName(), getClientValueName(), getPermitsName(), getClientPermitsName()),
value, System.currentTimeMillis(), random);
}
新版本的设计思路完全改变了,用k/v存剩余许可数,zset存许可申请明细。
每一次许可申请使用zadd往zset增加一条,用毫秒时间戳做score 用struct.park 压缩成二进制存储申请的许可数。
每一次请求用滑动窗口动态判断当前剩余许可数。
如果达到限流上限,返回ttl,外层采用延迟重试的方式继续请求获取许可,思路和之前版本类似。不同的是,这里使用了netty时间轮
io.netty.util.HashedWheelTimer#newTimeout