实现原理:
定义一个时间窗口,在这个时间窗口里面,对访问的次数做限制。这个窗口随着每次的访问是滑动的,主要是避免固定时间窗口中,访问集中在首尾造成接口访问超过限制。
这里是使用了Redis的Zset数据结构,利用他的score属性,通过每次范围跟时间戳来形成一个有序的队列,根据当前时间计算出时间窗口大小,来判断这个窗口内的访问次数有没有超过限制。
例如:现在限制接口5秒内最多只能访问5次

image.png
java代码:
首先定义一个限流注解:
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
/**
* 限流key
*/
String key() default "rate:limiter:";
/**
* 单位时间限制通过请求数
*/
long limit() default 1;
/**
* 过期时间,单位秒
*/
long expire() default 5;
/**
* 限流提示语
*/
String message() default "访问过于频繁";
/**
* 限流类型
*/
LimitType limitType() default LimitType.DEFAULT;
}
限流类型:
public enum LimitType {
/**
* 默认
*/
DEFAULT,
/**
* 根据IP限流
*/
IP;
}
限流处理逻辑:
@Slf4j
@Aspect
@Component
public class RateLimiterHandler {
@Autowired
private RedisTemplate<String, Object> redisTemplate;
private DefaultRedisScript<Long> getRedisScript;
@PostConstruct
public void init() {
getRedisScript = new DefaultRedisScript<>();
getRedisScript.setResultType(Long.class);
getRedisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/rateLimiter.lua")));
log.info("[分布式限流处理器]脚本加载完成");
}
@Around("@annotation(rateLimiter)")
public Object around(ProceedingJoinPoint proceedingJoinPoint, RateLimiter rateLimiter) throws Throwable {
log.debug("[分布式限流处理器]开始执行限流操作");
// 限流模块key
MethodSignature signature = (MethodSignature) proceedingJoinPoint.getSignature();
Method method = signature.getMethod();
```
StringBuilder limitKey = new StringBuilder(rateLimiter.key());
if (rateLimiter.limitType() == LimitType.IP) {
limitKey.append(IPUtil.getIpAddress());
}
// 目标类、方法
String className = method.getDeclaringClass().getName();
String name = method.getName();
limitKey.append("_").append(className).append("_").append(name);
// 限流阈值
long limitCount = rateLimiter.limit();
// 限流超时时间
long expireTime = rateLimiter.expire();
log.debug("[分布式限流处理器]参数值为:method={},limitKey={},limitCount={},limitTimeout={}", name, limitKey, limitCount, expireTime);
// 执行Lua脚本
List<String> keyList = new ArrayList<>();
// 设置key值为注解中的值
keyList.add(limitKey.toString());
// 调用脚本并执行
Long result = redisTemplate.execute(getRedisScript, keyList, expireTime,System.currentTimeMillis(), limitCount);
log.debug("[分布式限流处理器]限流执行结果-result={}", result);
if (null != result && result >= limitCount) {
log.debug("由于超过单位时间={};允许的请求次数={}[触发限流]", expireTime, limitCount);
// 限流提示语
String message = rateLimiter.message();
throw new BizException(message);
}
return proceedingJoinPoint.proceed();
}
}
IPUtils工具类:
@Slf4j
public class IPUtil {
public static String getIpAddress() {
HttpServletRequest request = ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
String ip = null;
String ipAddresses = request.getHeader("X-Forwarded-For");
if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
ipAddresses = request.getHeader("Proxy-Client-IP");
}
if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
ipAddresses = request.getHeader("WL-Proxy-Client-IP");
}
if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
ipAddresses = request.getHeader("HTTP_CLIENT_IP");
}
if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
ipAddresses = request.getHeader("X-Real-IP");
}
if (ipAddresses != null && ipAddresses.length() != 0) {
ip = ipAddresses.split(",")[0];
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
ip = request.getRemoteAddr();
}
if (ipAddresses == null || ipAddresses.length() == 0 || "unknown".equalsIgnoreCase(ipAddresses)) {
ipAddresses = request.getRemoteAddr();
if (ipAddresses.equals("127.0.0.1") || ipAddresses.equals("0:0:0:0:0:0:0:1")) {
//根据网卡取本机配置的IP
InetAddress inet = null;
try {
inet = InetAddress.getLocalHost();
} catch (UnknownHostException e) {
log.error("获取ip失败 {}", Arrays.asList(e.getStackTrace()));
}
if (inet != null) {
ip = inet.getHostAddress();
} else {
ip = "127.0.0.1";
}
}
}
return ip;
}
}
lua脚本代码:
--获取KEY
local key = KEYS[1]
--获取ARGV内的参数
-- 缓存时间
local expire = tonumber(ARGV[1])
-- 当前时间
local currentMs = tonumber(ARGV[2])
-- 最大次数
local count = tonumber(ARGV[3])
--窗口开始时间
local windowStartMs = currentMs - expire * 1000;
--获取key的次数
local current = redis.call('zcount', key, windowStartMs, currentMs)
--如果key的次数存在且大于预设值直接返回当前key的次数
if current and tonumber(current) >= count then
return tonumber(current);
end
-- 清除所有过期成员
redis.call("ZREMRANGEBYSCORE", key, 0, windowStartMs);
-- 添加当前成员
redis.call("zadd", key, tostring(currentMs), currentMs);
redis.call("expire", key, expire);
--返回key的次数
return tonumber(current)
最后在想要限流的接口上添加注解就可以了:
@GetMapping("/test")
@RateLimiter(limit = 5, expire = 5, limitType = LimitType.IP)
public Result test() {
return Result.success();
}