新建一个自定义注解 AccessLimit
@Inherited
@Documented
@Target({ElementType.FIELD, ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface AccessLimit {
/**
* 指定second 时间内,最多的请求次数
*/
int count() default 5;
/**
* 指定时间second,redis数据过期时间
*/
int second() default 5;
}
新建一个请求接口限制次数拦截器
public class AccessLimitInterceptor implements HandlerInterceptor {
@Override
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
// handler是否为 HandleMethod 实例
if (handler instanceof HandlerMethod) {
// 强转
HandlerMethod handlerMethod = (HandlerMethod) handler;
// 获取方法
Method method = handlerMethod.getMethod();
// 判断方式是否有AccessLimit注解,有的才需要做限流
if (!method.isAnnotationPresent(AccessLimit.class)) {
return true;
}
// 获取注解上的内容
AccessLimit accessLimit = method.getAnnotation(AccessLimit.class);
if (accessLimit == null) {
return true;
}
// 获取方法注解上的请求次数
int count = accessLimit.count();
// 获取方法注解上的请求时间
Integer second = accessLimit.second();
// 获取ip地址
String ip = IPUtils.getIpAddress(request);
if(StringUtils.isEmpty(ip)){
throw new MyException("获取ip地址为空");
}
// 拼接redis key = IP + Api限流
String key = ip + request.getRequestURI();
// 获取redis的value
Integer maxTimes = null;
String value = RedisUtils.get(key);
if (!StringUtils.isEmpty(value)) {
maxTimes = Integer.valueOf(value);
}
if (maxTimes == null) {
// 如果redis中没有该ip对应的时间则表示第一次调用,保存key到redis
String one = "1";
RedisUtils.setex(key, second, one);
} else if (maxTimes < count) {
// 如果redis中的时间比注解上的时间小则表示可以允许访问,这是修改redis的value时间
RedisUtils.setex(key, second, String.valueOf(maxTimes + 1));
} else {
// 请求过于频繁
throw new MyException("系统繁忙,请稍后重试");
}
}
return true;
}
@Override
public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {
}
@Override
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
}
}
使用IP工具,获取真实ip地址
public class IPUtils {
public static String getIpAddress(HttpServletRequest request) {
String ip = request.getHeader("X-forwarded-for");
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_CLIENT_IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_X_FORWARDED_FOR");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
return ip;
}
}
以上使用redis及异常类,可以自己换成你们已有的
测试类
@ApiImplicitParams({ @ApiImplicitParam(name = "pageNum", value = "当前页码", required = true, dataType = "int"),
@ApiImplicitParam(name = "pageSize", value = "每页显示的条数", required = true, dataType = "int") })
@ApiOperation(value = "分页查询接口入参", notes = "分页查询接口入参", httpMethod = "POST")
@PostMapping(value = "/queryPage")
@AccessLimit(count = 5, second = 10)
public Result<?> queryPage(@RequestBody UserInfo user, @RequestParam("pageNum") int pageNum, @RequestParam("pageSize") int pageSize) {
PageInfo pageInfo = userInfoService.queryPage(user, pageNum, pageSize);
return pageTool.getPageInfo(pageInfo);
}
以上参考自:
https://blog.csdn.net/zxl646801924/article/details/99442258