使用Redis的Zset + lua实现滑动窗口限流

实现原理:

定义一个时间窗口,在这个时间窗口里面,对访问的次数做限制。这个窗口随着每次的访问是滑动的,主要是避免固定时间窗口中,访问集中在首尾造成接口访问超过限制。

这里是使用了Redis的Zset数据结构,利用他的score属性,通过每次范围跟时间戳来形成一个有序的队列,根据当前时间计算出时间窗口大小,来判断这个窗口内的访问次数有没有超过限制。

例如:现在限制接口5秒内最多只能访问5次


image.png

java代码:

首先定义一个限流注解:


@Target(ElementType.METHOD)

@Retention(RetentionPolicy.RUNTIME)

@Documented

public @interface RateLimiter {

    /**

    * 限流key

    */

    String key() default "rate:limiter:";

    /**

    * 单位时间限制通过请求数

    */

    long limit() default 1;

    /**

    * 过期时间,单位秒

    */

    long expire() default 5;

    /**

    * 限流提示语

    */

    String message() default "访问过于频繁";

    /**

    * 限流类型

    */

    LimitType limitType() default LimitType.DEFAULT;

}
限流类型:

public enum LimitType {

    /**

    * 默认

    */

    DEFAULT,

    /**

    * 根据IP限流

    */

    IP;

}

限流处理逻辑:


@Slf4j

@Aspect

@Component

public class RateLimiterHandler {

    @Autowired

    private RedisTemplate<String, Object> redisTemplate;

    private DefaultRedisScript<Long> getRedisScript;

    @PostConstruct

    public void init() {

        getRedisScript = new DefaultRedisScript<>();

        getRedisScript.setResultType(Long.class);

        getRedisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/rateLimiter.lua")));

        log.info("[分布式限流处理器]脚本加载完成");

    }

    @Around("@annotation(rateLimiter)")

    public Object around(ProceedingJoinPoint proceedingJoinPoint, RateLimiter rateLimiter) throws Throwable {

        log.debug("[分布式限流处理器]开始执行限流操作");

        // 限流模块key

        MethodSignature signature = (MethodSignature) proceedingJoinPoint.getSignature();

        Method method = signature.getMethod();

        ```

        StringBuilder limitKey = new StringBuilder(rateLimiter.key());

        if (rateLimiter.limitType() == LimitType.IP) {

            limitKey.append(IPUtil.getIpAddress());

        }

        // 目标类、方法

        String className = method.getDeclaringClass().getName();

        String name = method.getName();

        limitKey.append("_").append(className).append("_").append(name);

        // 限流阈值

        long limitCount = rateLimiter.limit();

        // 限流超时时间

        long expireTime = rateLimiter.expire();

        log.debug("[分布式限流处理器]参数值为:method={},limitKey={},limitCount={},limitTimeout={}", name, limitKey, limitCount, expireTime);

        // 执行Lua脚本

        List<String> keyList = new ArrayList<>();

        // 设置key值为注解中的值

        keyList.add(limitKey.toString());

        // 调用脚本并执行

        Long result = redisTemplate.execute(getRedisScript, keyList, expireTime,System.currentTimeMillis(), limitCount);

        log.debug("[分布式限流处理器]限流执行结果-result={}", result);

        if (null != result && result >= limitCount) {

            log.debug("由于超过单位时间={};允许的请求次数={}[触发限流]", expireTime, limitCount);

            // 限流提示语

            String message = rateLimiter.message();

            throw new BizException(message);

        }

        return proceedingJoinPoint.proceed();

    }

}

IPUtils工具类:


@Slf4j

public class IPUtil {

    public static String getIpAddress() {

        HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();

        String ip = null;

        String ipAddresses = request.getHeader("X-Forwarded-For");

        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {

            ipAddresses = request.getHeader("Proxy-Client-IP");

        }

        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {

            ipAddresses = request.getHeader("WL-Proxy-Client-IP");

        }

        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {

            ipAddresses = request.getHeader("HTTP_CLIENT_IP");

        }

        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {

            ipAddresses = request.getHeader("X-Real-IP");

        }

        if (ipAddresses != null && ipAddresses.length() != 0) {

            ip = ipAddresses.split(",")[0];

        }

        if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {

            ip = request.getRemoteAddr();

        }

        if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {

            ipAddresses = request.getRemoteAddr();

            if (ipAddresses.equals("127.0.0.1") || ipAddresses.equals("0:0:0:0:0:0:0:1")) {

                //根据网卡取本机配置的IP

                InetAddress inet = null;

                try {

                    inet = InetAddress.getLocalHost();

                } catch (UnknownHostException e) {

                    log.error("获取ip失败 {}", Arrays.asList(e.getStackTrace()));

                }

                if (inet != null) {

                    ip = inet.getHostAddress();

                } else {

                    ip = "127.0.0.1";

                }

            }

        }

        return ip;

    }

}

lua脚本代码:


--获取KEY

local key = KEYS[1]

--获取ARGV内的参数

-- 缓存时间

local expire = tonumber(ARGV[1])

-- 当前时间

local currentMs = tonumber(ARGV[2])

-- 最大次数

local count = tonumber(ARGV[3])

--窗口开始时间

local windowStartMs = currentMs - expire * 1000;

--获取key的次数

local current = redis.call('zcount', key, windowStartMs, currentMs)

--如果key的次数存在且大于预设值直接返回当前key的次数

if current and tonumber(current) >= count then

    return tonumber(current);

end

-- 清除所有过期成员

redis.call("ZREMRANGEBYSCORE", key, 0, windowStartMs);

-- 添加当前成员

redis.call("zadd", key, tostring(currentMs), currentMs);

redis.call("expire", key, expire);

--返回key的次数

return tonumber(current)

最后在想要限流的接口上添加注解就可以了:


@GetMapping("/test")

@RateLimiter(limit = 5, expire = 5, limitType = LimitType.IP)

public Result test() {

    return Result.success();

}
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容