访问量的控制比较常见,对外提供的服务,有的需要控制QPS,防止服务宕机;有的需要控制一个时间段的访问数量。
本文基于 Spring Boot,采用切面(AOP)+ Redis 的方式实现访问量控制:
- 在需要进行访问量控制的地方加入注解;
- 在注解操作中,获取当前访问的ip地址,利用redis做计数,超过limit则报错;
- 问题的关键在于:在分布式环境下,多个实例对redis的"读取—判断—写入"操作可能会出现竞争,所以要把这些操作放到一个lua脚本中执行,这样整个脚本的执行是原子性的。
自定义注解:
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import java.lang.annotation.*;
/**
 * Method-level marker declaring how many calls are allowed within a time window.
 *
 * <p>Retained at runtime so it can be read reflectively (for example by an aspect
 * bound with {@code @annotation(...)}).
 *
 * <p>NOTE(review): {@code @Order} is applied here to the annotation type itself;
 * confirm this has the intended effect, since ordering an aspect is normally done
 * by placing {@code @Order} on the aspect class.
 */
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Order(Ordered.HIGHEST_PRECEDENCE)
public @interface RequestLimit {

    /**
     * Maximum number of accesses permitted within the window.
     *
     * @return the access limit; defaults to {@link Integer#MAX_VALUE}
     */
    int count() default Integer.MAX_VALUE;

    /**
     * Length of the window in milliseconds.
     *
     * @return the window length; defaults to one minute (60000 ms)
     */
    long time() default 60_000L;
}
切面操作:
import com.example.common.Constants;
import com.example.common.ErrorCode;
import com.example.exception.BusinessException;
import com.example.utils.IpUtils;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.util.Collections;
/**
 * Aspect that enforces {@link RequestLimit} on annotated methods.
 *
 * <p>Before an annotated method runs, the caller's IP is resolved from the current
 * servlet request and a Redis script is executed with the key
 * {@code Constants.KEY_PREFIX + ip}, the configured limit and the configured window.
 * The script's boolean result decides whether the call may proceed.
 *
 * <p>NOTE(review): the key contains only the IP, so every annotated method shares one
 * counter per client even when their {@code count}/{@code time} values differ, and
 * {@code joinPoint} is currently unused. Confirm whether a per-method key is wanted.
 */
@Aspect
@Component
public class RequestLimitAspect {

    /** Script executed against Redis for each check; injected by Spring. */
    @Autowired
    private DefaultRedisScript<Boolean> redisScript;

    /** Template used to run the script; injected by Spring. */
    @Autowired
    private StringRedisTemplate stringRedisTemplate;

    /** Matches every method annotated with {@code com.example.aspect.RequestLimit}. */
    @Pointcut("@annotation(com.example.aspect.RequestLimit)")
    public void pointcut() {
    }

    /**
     * Checks the caller's counter before the annotated method executes.
     *
     * @param joinPoint    the intercepted join point (not used by the check)
     * @param requestLimit the annotation instance found on the intercepted method
     * @throws BusinessException     with {@code ErrorCode.REQUEST_EXCEED_LIMIT} when the
     *                               script reports that the call is not allowed
     * @throws IllegalStateException when the script execution yields no result
     */
    @Before("pointcut() && @annotation(requestLimit)")
    public void doBefore(JoinPoint joinPoint, RequestLimit requestLimit) {
        ServletRequestAttributes requestAttributes =
                (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (requestAttributes == null) {
            // No request is bound to this thread, so there is no IP to count against.
            return;
        }

        HttpServletRequest httpRequest = requestAttributes.getRequest();
        String ip = IpUtils.getRealIP(httpRequest);
        String key = Constants.KEY_PREFIX + ip;

        Boolean allow = stringRedisTemplate.execute(
                redisScript,
                Collections.singletonList(key),
                String.valueOf(requestLimit.count()),  // ARGV[1]: limit
                String.valueOf(requestLimit.time()));  // ARGV[2]: expire (ms)

        // A Java `assert` is skipped unless the JVM runs with -ea, so a missing result
        // must be checked explicitly instead of surfacing as a NullPointerException
        // when the Boolean is unboxed.
        if (allow == null) {
            throw new IllegalStateException(
                    "Request-limit script returned no result for key " + key);
        }
        if (!Boolean.TRUE.equals(allow)) {
            throw new BusinessException(ErrorCode.REQUEST_EXCEED_LIMIT);
        }
    }
}
其中切面里注入的 redisScript(用于加载并执行lua脚本)由下面的配置类提供:
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.scripting.support.ResourceScriptSource;
/**
 * Spring configuration that exposes the request-limit Lua script as a bean.
 */
@Configuration
public class LuaRedisConfiguration {

    /**
     * Creates the script bean backed by the classpath resource
     * {@code script/requestLimit.lua}, with {@link Boolean} as its result type.
     *
     * @return the configured script object
     */
    @Bean
    public DefaultRedisScript<Boolean> redisScript() {
        ClassPathResource luaFile = new ClassPathResource("script/requestLimit.lua");

        DefaultRedisScript<Boolean> script = new DefaultRedisScript<>();
        script.setResultType(Boolean.class);
        script.setScriptSource(new ResourceScriptSource(luaFile));
        return script;
    }
}
lua脚本:
-- Request-limit counter.
--   KEYS[1]  counter key
--   ARGV[1]  limit:  maximum number of calls allowed per window
--   ARGV[2]  expire: window length in milliseconds
-- Returns 1 when the call is allowed, 0 when the limit has been exceeded.
local key = KEYS[1]
local limit = tonumber(ARGV[1])
-- Passed through to PEXPIRE unchanged, exactly as received.
local expire = ARGV[2]

-- Count this call first; INCR creates the key with value 1 when it does not exist.
-- Every call, including the first of a window, is then compared against the limit,
-- so a limit below 1 rejects the first call instead of letting it through.
local current = redis.call("INCR", key)

-- TTL returns -1 for a key without expiry: either the key was just created by the
-- INCR above, or it lost its TTL some other way. Checking on every call keeps a
-- counter from becoming permanent while it is still under the limit.
if redis.call("TTL", key) == -1 then
  redis.call("PEXPIRE", key, expire)
end

if current <= limit then
  return 1
end
return 0
参考文章:http://www.genxiaogu.com/Springboot-%E9%9B%86%E7%BE%A4QPS%E6%8E%A7%E5%88%B6starter/