Redisson 限流源码学习

最近业务中用到了Redisson限流的功能,顺便研究一下底层实现
基于当前使用的版本<version>3.10.7</version>
目前用到的是accqure(),具体逻辑分析见代码中的注释

    @Override
    public void acquire(long permits) {
        // get 同步获取
        get(acquireAsync(permits));
    }

    // RFuture是继承jdk的Future类
    @Override
    public <V> V get(RFuture<V> future) {
        if (!future.isDone()) {
            CountDownLatch l = new CountDownLatch(1);
            future.onComplete((res, e) -> {
                l.countDown();
            });

            boolean interrupted = false;
            while (!future.isDone()) {
                try {
                    // future complete以后解除阻塞
                    l.await();
                } catch (InterruptedException e) {
                    interrupted = true;
                    break;
                }
            }

            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }

        // commented out due to blocking issues up to 200 ms per minute for each thread
        // future.awaitUninterruptibly();
        if (future.isSuccess()) {
            return future.getNow();
        }

        throw convertException(future);
    }

    @Override
    public RFuture<Void> acquireAsync(long permits) {
        RPromise<Void> promise = new RedissonPromise<Void>();
        // permits 代表要获取的许可数量,一般一次获取一个
        // -1代表不设置超时   可以看看tryAcquire带超时设置的重载方法了解此参数
        // null代表时间单位   没有设置时间所以单位为空
        tryAcquireAsync(permits, -1, null).onComplete((res, e) -> {
            if (e != null) {
                promise.tryFailure(e);
                return;
            }
            
            promise.trySuccess(null);
        });
        return promise;
    }

    @Override
    public RFuture<Boolean> tryAcquireAsync(long permits, long timeout, TimeUnit unit) {
        RPromise<Boolean> promise = new RedissonPromise<Boolean>();
        long timeoutInMillis = -1;
        // 如果有设置超时时间 转换为毫秒 调用真正的执行逻辑
        if (timeout > 0) {
            timeoutInMillis = unit.toMillis(timeout);
        }
        tryAcquireAsync(permits, promise, timeoutInMillis);
        return promise;
    }

    private void tryAcquireAsync(long permits, RPromise<Boolean> promise, long timeoutInMillis) {
        long s = System.currentTimeMillis();
        // 执行lua脚本,并返回Long类型。具体脚本内容往下看
        RFuture<Long> future = tryAcquireAsync(RedisCommands.EVAL_LONG, permits);
        // delay 代表lua脚本执行返回的值   e代表异常
        future.onComplete((delay, e) -> {
            if (e != null) {
                promise.tryFailure(e);
                return;
            }
            // 返回空就代表 获取许可成功了,为啥空代表成功需要看后面lua脚本
            if (delay == null) {
                // 给上层返回true
                promise.trySuccess(true);
                return;
            }
            // 走到这里表示获取许可失败了,但是获取许可失败了 要继续尝试
            // -1 代表不超时的逻辑
            // 获取许可失败的时候  返回的值赋给了delay  为啥取名delay因为返回的是等多久才能下一次获取
            if (timeoutInMillis == -1) {
                commandExecutor.getConnectionManager().getGroup().schedule(() -> {
                    tryAcquireAsync(permits, promise, timeoutInMillis);
                }, delay, TimeUnit.MILLISECONDS);
                return;
            }
            // 走到这里表示或许许可失败,但是设置了超时时间
            // 先看看已经花了多久
            long el = System.currentTimeMillis() - s;
            long remains = timeoutInMillis - el;
            // 超时了  给上层false
            if (remains <= 0) {
                promise.trySuccess(false);
                return;
            }
            // 暂时还没有超时,但是剩余的时间  比lua返回的要等待的时间还要短,那在超时时间内也不会成功
            // 给上层false
            if (remains < delay) {
                commandExecutor.getConnectionManager().getGroup().schedule(() -> {
                    promise.trySuccess(false);
                }, remains, TimeUnit.MILLISECONDS);
            } else {
                long start = System.currentTimeMillis();
                // 等待delay时间后再次尝试获取许可
                // 但是尝试之前再做一次超时判断
                commandExecutor.getConnectionManager().getGroup().schedule(() -> {
                    long elapsed = System.currentTimeMillis() - start;
                    if (remains <= elapsed) {
                        promise.trySuccess(false);
                        return;
                    }
                    // 递归调用
                    tryAcquireAsync(permits, promise, remains - elapsed);
                }, delay, TimeUnit.MILLISECONDS);
            }
        });
    }


    private <T> RFuture<T> tryAcquireAsync(RedisCommand<T> command, Long value) {
        // KEYS 有3个   ARGS只有1个
        return commandExecutor.evalWriteAsync(getName(), LongCodec.INSTANCE, command,
              // 在或许许可之前 校验有没有创建对应的 限流器基本信息  KEYS[1] 对应的是限流器的name
              // 往下查看trySetRate(创建限流器的时候会调用)可以看到是怎么设置进去的
              // rate是限速比如 每秒100 这里就是100  
              // interval是时间间隔 按每秒那这里就是1秒对应的毫秒数
              // type  0 全局限流  1 按客户端限流
                "local rate = redis.call('hget', KEYS[1], 'rate');"
              + "local interval = redis.call('hget', KEYS[1], 'interval');"
              + "local type = redis.call('hget', KEYS[1], 'type');"
              + "assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')"
              // valueName  即存放限流值对应的redis key
              + "local valueName = KEYS[2];"
              // 1 代表是按客户端分别限流   0 代表的是全局限流
              // 按客户端限流的话  redis key还要加上 客户端的id信息,具体看后面KEYS数组中第三个的值
              + "if type == '1' then "
                  + "valueName = KEYS[3];"
              + "end;"
              
              + "local currentValue = redis.call('get', valueName); "
              // 如果限流值存在
              + "if currentValue ~= false then "
                     // 比较当前值够不够要申请的许可数   不够说明达到限流上限了  然后返回ttl 也就是还有多久到期
                     + "if tonumber(currentValue) < tonumber(ARGV[1]) then "
                         + "return redis.call('pttl', valueName); "
                     + "else "
                     // 如果够,那就减去本次申请的许可数  然后返回空,空就代表成功了
                         + "redis.call('decrby', valueName, ARGV[1]); "
                         + "return nil; "
                     + "end; "
              + "else "
                     // 如果限流值在redis不存在  那说明是第一次申请许可或者又到了新的1秒 之前的过期了,所以要创建redis值
                     // 判断申请的许可数是否太大,比如每秒限流100  你传进来101  那肯定申请不下来
                     + "assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount could not exceed defined rate'); "
                     // 初始化限流的值  并设置过期时间
                     + "redis.call('set', valueName, rate, 'px', interval); "
                     // 然后扣减本次申请的许可数
                     + "redis.call('decrby', valueName, ARGV[1]); "
                     + "return nil; "
              + "end;",
                Arrays.<Object>asList(getName(), getValueName(), getClientValueName()), 
                value, commandExecutor.getConnectionManager().getId().toString());
    }

    @Override
    public boolean trySetRate(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
        return get(trySetRateAsync(type, rate, rateInterval, unit));
    }

    @Override
    public RFuture<Boolean> trySetRateAsync(RateType type, long rate, long rateInterval, RateIntervalUnit unit) {
        return commandExecutor.evalWriteAsync(getName(), LongCodec.INSTANCE, RedisCommands.EVAL_BOOLEAN,
                "redis.call('hsetnx', KEYS[1], 'rate', ARGV[1]);"
              + "redis.call('hsetnx', KEYS[1], 'interval', ARGV[2]);"
              + "return redis.call('hsetnx', KEYS[1], 'type', ARGV[3]);",
                Collections.<Object>singletonList(getName()), rate, unit.toMillis(rateInterval), type.ordinal());
    }

总结下来就是,往redis设置一个限流的数值,超时时间就是限流的时间区间
然后就去查询这个限流的数值,如果没有查到,肯定可以获取许可;如果查到了,要看看有没有超过许可数
如果获取许可成功了就返回nil,上层就知道成功;如果或许失败那就要等超时时间过去再请求,据返回一个delay时间。

但是,这个执行一次没有获取到许可的话,还要重试,所以上一层方法增加了重试的逻辑,重试是靠io.netty.util.concurrent.EventExecutorGroup#schedule(java.lang.Runnable, long, java.util.concurrent.TimeUnit)来实现,EventExecutorGroup实现了java.util.concurrent.ScheduledExecutorService。

如果我们熟悉Redis的话,应该会觉得限流可以用其他的数据结构来实现,比如zset,用zrangescore来获取某一时间窗口内的请求数,然后判断有没有达到限流阈值。

我们升级一下redisson,看看高版本有没有优化。
继续看看版本<version>3.34.0</version>

    private CompletableFuture<Boolean> tryAcquireAsync(long permits, long timeoutInMillis) {
        long s = System.currentTimeMillis();
        RFuture<Long> future = tryAcquireAsync(RedisCommands.EVAL_LONG, permits);
        return future.thenCompose(delay -> {
            if (delay == null) {
                return CompletableFuture.completedFuture(true);
            }
            
            if (timeoutInMillis == -1) {
                CompletableFuture<Boolean> f = new CompletableFuture<>();
                getServiceManager().newTimeout(t -> {
                    CompletableFuture<Boolean> r = tryAcquireAsync(permits, timeoutInMillis);
                    commandExecutor.transfer(r, f);
                }, delay, TimeUnit.MILLISECONDS);
                return f;
            }
            
            long el = System.currentTimeMillis() - s;
            long remains = timeoutInMillis - el;
            if (remains <= 0) {
                return CompletableFuture.completedFuture(false);
            }

            CompletableFuture<Boolean> f = new CompletableFuture<>();
            if (remains < delay) {
                getServiceManager().newTimeout(t -> {
                    f.complete(false);
                }, remains, TimeUnit.MILLISECONDS);
            } else {
                long start = System.currentTimeMillis();
                getServiceManager().newTimeout(t -> {
                    long elapsed = System.currentTimeMillis() - start;
                    if (remains <= elapsed) {
                        f.complete(false);
                        return;
                    }

                    CompletableFuture<Boolean> r = tryAcquireAsync(permits, remains - elapsed);
                    commandExecutor.transfer(r, f);
                }, delay, TimeUnit.MILLISECONDS);
            }
            return f;
        }).toCompletableFuture();
    }

    private <T> RFuture<T> tryAcquireAsync(RedisCommand<T> command, Long value) {
        byte[] random = getServiceManager().generateIdArray();

        return commandExecutor.evalWriteAsync(getRawName(), LongCodec.INSTANCE, command,
                "local rate = redis.call('hget', KEYS[1], 'rate');"
              + "local interval = redis.call('hget', KEYS[1], 'interval');"
              + "local type = redis.call('hget', KEYS[1], 'type');"
              + "assert(rate ~= false and interval ~= false and type ~= false, 'RateLimiter is not initialized')"
              
              + "local valueName = KEYS[2];"
              + "local permitsName = KEYS[4];"
              + "if type == '1' then "
                  + "valueName = KEYS[3];"
                  + "permitsName = KEYS[5];"
              + "end;"

              + "assert(tonumber(rate) >= tonumber(ARGV[1]), 'Requested permits amount could not exceed defined rate'); "
              // 存储当前剩余的许可数
              + "local currentValue = redis.call('get', valueName); "
              + "local res;"
              + "if currentValue ~= false then "
                     // 窗口滑动,查询已过期的许可
                     + "local expiredValues = redis.call('zrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); "
                     + "local released = 0; "
                     + "for i, v in ipairs(expiredValues) do "
                          + "local random, permits = struct.unpack('Bc0I', v);"
                          // 所有已过期的许可数加起来
                          + "released = released + permits;"
                     + "end; "

                     + "if released > 0 then "
                          // 删除已过期的许可  避免下一次range查询又查出来了
                          + "redis.call('zremrangebyscore', permitsName, 0, tonumber(ARGV[2]) - interval); "
                          // 如果当前剩余许可数 + 已过期 > 总限流数。这种一般不存在,除非重新设置了限流速率?
                         // 如果是这样的话,重新计算一下剩余许可数,用rate - 有效期内已经申请的许可数,这里不是100%准确   zcard是集合计数   一个条目申请的许可数可能是大于1的
                        // 当前剩余许可数 + 已过期 <= 总限流数   已过期的可以重新在新窗口使用
                          + "if tonumber(currentValue) + released > tonumber(rate) then "
                               + "currentValue = tonumber(rate) - redis.call('zcard', permitsName); "
                          + "else "
                               + "currentValue = tonumber(currentValue) + released; "
                          + "end; "
                          + "redis.call('set', valueName, currentValue);"
                     + "end;"

                     // 如果当前剩余许可数 小于 本次申请的许可数,则申请失败  最后会返回nil
                     + "if tonumber(currentValue) < tonumber(ARGV[1]) then "
                         + "local firstValue = redis.call('zrange', permitsName, 0, 0, 'withscores'); "
                         // 计算多长时间后可以重新获取许可  3是什么意思没有看懂
                         + "res = 3 + interval - (tonumber(ARGV[2]) - tonumber(firstValue[2]));"
                     + "else "
                         // 申请成功  zset写入许可申请记录
                         + "redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); "
                         + "redis.call('decrby', valueName, ARGV[1]); "
                         + "res = nil; "
                     + "end; "
              + "else "
                     // 初始化限流数量
                     + "redis.call('set', valueName, rate); "
                     // 记录本次申请的许可
                     + "redis.call('zadd', permitsName, ARGV[2], struct.pack('Bc0I', string.len(ARGV[3]), ARGV[3], ARGV[1])); "
                     // 剩余许可数量 = 限流数量 - 本次申请的许可数量
                     + "redis.call('decrby', valueName, ARGV[1]); "
                     + "res = nil; "
              + "end;"

              + "local ttl = redis.call('pttl', KEYS[1]); "
              + "if ttl > 0 then "
                  + "redis.call('pexpire', valueName, ttl); "
                  + "redis.call('pexpire', permitsName, ttl); "
              + "end; "
              + "return res;",
                Arrays.asList(getRawName(), getValueName(), getClientValueName(), getPermitsName(), getClientPermitsName()),
                value, System.currentTimeMillis(), random);
    }

新版本的设计思路完全改变了,用k/v存剩余许可数,zset存许可申请明细。
每一次许可申请使用zadd往zset增加一条,用毫秒时间戳做score 用struct.park 压缩成二进制存储申请的许可数。
每一次请求用滑动窗口动态判断当前剩余许可数。
如果达到限流上限,返回ttl,外层采用延迟重试的方式继续请求获取许可,思路和之前版本类似。不同的是,这里使用了netty时间轮
io.netty.util.HashedWheelTimer#newTimeout

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容