在开发后端接口时,对于一些非对外开放的接口,我们总想做得简单一些,不去过多考虑类似DDOS之类的恶意攻击,但当上线之后却发现,时常出现一些奇奇怪怪的问题困扰着后端的同学,其中就包括因用户短时间内对按钮的多次点击导致接口频繁接收到相同参数的请求,进而导致一些令人脑壳疼的数据问题。当然,有经验的同学会说,前端加个防抖动机制就可以了不是吗?没错,是可以的,只是本文不探讨前端的解决方案,而是探讨后端能够实现的解决方案(其实是困于团队没多少前端资源,才会苦尽心思从后端下手T_T)。话说回来,这些由于用户多次点击产生的数据问题,有些在库里面加个唯一约束键就可以解决,有些却需要在接口代码增加一个互斥的逻辑(即一个用户在某一时刻,只能有一个线程能够执行某个接口代码逻辑)。这个互斥的逻辑在分布式的环境中我们一般需要用分布式锁来实现,实现方式有mysql的悲观锁、redis或zookeeper实现的分布式锁等,本文不探讨分布式锁的实现方案,而是在分布式锁的基础上,讲讲如何简单地封装一个我称之为“请求锁”的小工具,方便后端同学开箱即用。
废话不多说,直接上代码:
ApiLock注解
这个注解作用于方法上,用来标记某个接口需要接入请求锁。既然是封装了分布式锁,那么分布式锁的lockKey如何生成?此处默认生成策略为:直接根据所有的请求参数组装后转化成jsonStr,再经过md5得到lockKey,除此之外在注解中支持传入指定的lockKey生成策略类来生成。当然,这种实现方式对一些接口并不适用,例如上传图片、文件的接口,因为这些接口接收的参数对象比较特殊,例如:spring web包下的MultipartFile对象,还有其它一些IO流式的数据对象,这些目前未做兼容(实际上也少有对上传类型的接口做用户防抖的需求)。
注解主要属性有几个:
- waitMills:获取锁等待时间(毫秒)
- expireMills:锁自动过期时间(毫秒)
- requiredRequestAttrs:声明当前接口需从HttpServletRequest上下文中获取的参数(request.getAttribute(attr)),该参数只支持从方法参数HttpServletRequest中获取,如果填入该值则会获取并参与lock key的拼接。
-
requiredHeaders:声明当前接口需从请求头中获取的参数(request.getHeader(header))
,该参数只支持从方法参数HttpServletRequest中获取,如果填入该值则会获取并参与lock key的拼接。 - lockKeyGenerateStrategy:指定lockKey生成策略类型。默认用ApiLockKeyDefaultGenerateStrategy进行生成,可实现策略接口ApiLockKeyGenerateStrategy重写generateKey(Map<String, Object> params)方法自定义生成策略。
package com.xx.api.app.annotation;
import com.xx.api.app.aop.ApiLockKeyGenerateStrategy;
import com.xx.api.app.aop.ApiLockKeyDefaultGenerateStrategy;
import com.xx.starter.plugin.plugins.DistributedLock;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* @author: xiebochang
* @Title: 用来标记接口需要接入请求锁
* (需确保已引入分布式锁的实现)
*
* @Desc
* lockKey默认为所有请求参数jsonStr的md5值(可以根据lockKeyGenerator更改生成策略),如果最终得到的参数为空,则加锁无效
* 下面几种情况会导致参数为空、加锁无效:
* 1、方法签名未申明任何请求参数
* 2、方法申明了参数,但客户端传过来的所有参数均为空值
* 3、方法参数只有一个HttpServletRequest,但未申明{@link #requiredRequestAttrs}、{@link #requiredHeaders()}的其中一个,换句话说还是没有参数作为lockKey
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface ApiLock {
/**
* 获取锁等待时间(毫秒)
*/
long waitMills();
/**
* 锁自动过期时间(毫秒)
*/
long expireMills();
/**
* 声明当前接口需从HttpServletRequest上下文中获取的参数(request.getAttribute(attr))
* (只支持从方法参数HttpServletRequest中获取,如果填入该值则会获取并参与lock key的拼接)
*/
String[] requiredRequestAttrs() default {};
/**
* 声明当前接口需从请求头中获取的参数(request.getHeader(header))
* (只支持从方法参数HttpServletRequest中获取,如果填入该值则会获取并参与lock key的拼接)
*/
String[] requiredHeaders() default {};
/**
* 指定lockKey生成器类型,默认用ApiLockKeyDefaultGenerateStrategy,可自定义实现
*/
Class<? extends ApiLockKeyGenerateStrategy> lockKeyGenerateStrategy() default ApiLockKeyDefaultGenerateStrategy.class;
}
lockKey生成策略接口
实现该接口并重写generateKey方法可以自定义lockKey的生成策略
package com.xx.api.app.aop;
import java.util.Map;
/**
* lockKey生成策略接口
*/
public interface ApiLockKeyGenerateStrategy {
String generateKey(String prefix, Map<String, Object> params);
}
lockKey默认的生成策略
默认采用直接md5的方式得到lockKey
package com.xx.api.app.aop;
import com.alibaba.fastjson.JSON;
import com.xx.common.utils.MD5Util;
import java.util.Map;
/**
* lockKey默认生成策略
*/
public class ApiLockKeyDefaultGenerateStrategy implements ApiLockKeyGenerateStrategy {
@Override
public String generateKey(String prefix, Map<String, Object> methodNotNullArgsMap) {
// 拿到所有参数后,lockKey的生成逻辑可以自定义实现
return prefix + MD5Util.getMD5Str(JSON.toJSONString(methodNotNullArgsMap));
}
}
切面逻辑处理类
切面逻辑中主要切的是所有controller包下的类方法,重点关注方法为around(ProceedingJoinPoint thisJoinPoint)
,逻辑主要分为四步:
- 第一步:判断当前方法是否用
@ApiLock
注解修饰,不是则跳过; - 第二步:如果用
@ApiLock
修饰了,那么获取方法中所有非空的参数,如果参数为空值或NULL则不处理(包括一些空字符串、空对象、空数组都会被过滤掉),此处对空值的过滤是为了防止传参问题导致锁升级为全局锁。比如说一个接口中只需要传一个userid
,且参数又非必传,那么对于未传userid
的请求,这把请求锁就不会生效,否则会导致所有未传参的客户端去竞争同一把锁(md5出来是同一个lockKey),从而导致接口调用被阻塞,换句话说,如果接口允许不传参,那么不应该加这个@ApiLock
注解。参数判空的代码封装在静态内部类RemoveNullEntryUtil
中,用到递归来对可能出现的多层参数进行解析判断,这部分可能有点绕,但思来想去暂时没想出更好的写法; - 第三步:如果参数不为空,则根据指定的lockKey生成策略来生成lockKey;
- 第四步:得到lockKey后,就可以对当前请求进行加锁并执行接口逻辑,而后在
finally
中完成解锁操作,至此就完成所有切面处理的逻辑。需要注意的是,执行切面方法时并非catch异常进行处理,而是直接向外抛出,这是因为外层已定义了全局异常处理器对抛出异常进行处理。
package com.xx.api.app.aop;
import com.xx.api.app.annotation.ApiLock;
import com.xx.common.statics.exceptions.BizException;
import com.xx.common.utils.StringUtil;
import com.xx.starter.plugin.plugins.DistributedLock;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.stereotype.Component;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
* @author: xiebochang
* @Title: ApiLockAspectPoint
* @desc: ApiLock注解切面处理
*/
@Aspect
@Component
public class ApiLockAspectPoint {
private static final Logger LOG = LoggerFactory.getLogger(ApiLockAspectPoint.class);
private static final String EXCLUDE_JAVA_BEAN_FILED = "serialVersionUID";
private Map<Class<? extends ApiLockKeyGenerateStrategy>, ApiLockKeyGenerateStrategy> apiLockKeyGenerateStrategyMap = new ConcurrentHashMap<>();
@Autowired
private DistributedLock redisLock;
// 切所有controller包下的类方法
@Pointcut("execution(* com.xx..*.controller..*.*(..))")
public void execute(){}
@Around("execute()")
public Object around(ProceedingJoinPoint thisJoinPoint) throws Throwable {
// 拿方法上注解,为空则不处理
MethodSignature signature = (MethodSignature) thisJoinPoint.getSignature();
ApiLock annotation = signature.getMethod().getAnnotation(ApiLock.class);
if (annotation == null) {
return thisJoinPoint.proceed();
}
// 拿方法中所有非空的参数,如果参数为空值或NULL则不处理
String methodName = thisJoinPoint.getSignature().getName();
Map<String, Object> methodNotNullArgsMap = getMethodNotNullArgsMap(thisJoinPoint, annotation);
if (null == methodNotNullArgsMap || methodNotNullArgsMap.isEmpty()) {
LOG.error("[{}] ==> there are no non-null arguments!", methodName);
return thisJoinPoint.proceed();
}
// 根据指定策略拿lockKey
Class<? extends ApiLockKeyGenerateStrategy> handlerClass = annotation.lockKeyGenerateStrategy();
ApiLockKeyGenerateStrategy keyGenerator;
if (apiLockKeyGenerateStrategyMap.containsKey(handlerClass)) {
keyGenerator = apiLockKeyGenerateStrategyMap.get(handlerClass);
} else {
keyGenerator = handlerClass.newInstance();
apiLockKeyGenerateStrategyMap.put(handlerClass, keyGenerator);
}
String lockKey = keyGenerator.generateKey(getDefaultPrefix(thisJoinPoint), methodNotNullArgsMap);
// 加锁、执行方法
Object lock = null;
try {
LOG.info("[{}]==> tryApiLock", methodName);
lock = redisLock.tryLock(lockKey, annotation.waitMills(), annotation.expireMills());
if (lock == null) {
throw new BizException("操作过于频繁");
}
LOG.info("[{}]==> api lock process start", methodName);
return thisJoinPoint.proceed();
} finally {
if (lock != null) {
redisLock.unlock(lock);
LOG.info("[{}]==> api lock process completed", methodName);
}
}
}
private String getDefaultPrefix(ProceedingJoinPoint thisJoinPoint) {
return thisJoinPoint.getTarget().getClass().getName() + "#" + thisJoinPoint.getSignature().getName() + "-";
}
/**
* 获取方法参数列表
*
* @param joinPoint
* @param annotation
* @return
* @throws ClassNotFoundException
* @throws NoSuchMethodException
*/
private Map<String, Object> getMethodNotNullArgsMap(ProceedingJoinPoint joinPoint, ApiLock annotation) {
Object[] args = joinPoint.getArgs();
ParameterNameDiscoverer pnd = new DefaultParameterNameDiscoverer();
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
Method method = signature.getMethod();
String[] parameterNames = pnd.getParameterNames(method);
if (parameterNames == null) {
return null;
}
// 将所有参数丢到map中
Map<String, Object> resultParamMap = new HashMap<>();
for (int i = 0; i < parameterNames.length; i++) {
Object obj = args[i];
if (obj == null) {
continue;
}
// 如果参数类型是HttpServletRequest,判断一下是否需获取attr、header
if (obj instanceof HttpServletRequest) {
HttpServletRequest request = (HttpServletRequest) obj;
checkAndFillReqAttrIfNecessary(resultParamMap, annotation, request);
checkAndFillHeaderParamsIfNecessary(resultParamMap, annotation, request);
} else {
resultParamMap.put(parameterNames[i], obj);
}
}
if (resultParamMap.isEmpty()) {
return null;
}
// 考虑参数可能封装多层的情况,以及其中可能出现Entry的key或value为空值,需要remove掉这些Entry
// 防止空参数导致生成重复的key,从而导致升级成一把全局锁
RemoveNullEntryUtil.removeNullEntry(resultParamMap);
return resultParamMap;
}
private void checkAndFillReqAttrIfNecessary(Map<String, Object> paramMap, ApiLock annotation, HttpServletRequest request) {
if (annotation.requiredRequestAttrs() != null) {
String[] requestAttrs = annotation.requiredRequestAttrs();
for (String k : requestAttrs) {
Object val = request.getAttribute(k);
if (null == val) {
LOG.error("==> request attr:{} not exist", k);
continue;
}
paramMap.put(k, val);
}
}
}
private void checkAndFillHeaderParamsIfNecessary(Map<String, Object> paramMap, ApiLock annotation, HttpServletRequest request) {
if (annotation.requiredHeaders() != null) {
String[] headers = annotation.requiredHeaders();
for (String k : headers) {
String h = request.getHeader(k);
if (null == h || "".equals(h)) {
LOG.error("==> header:{} not exist", k);
continue;
}
paramMap.put(k, h);
}
}
}
static class RemoveNullEntryUtil {
/**
* 移除map中空key或者value空值
* @param map
*/
private static void removeNullEntry(Map<String, Object> map){
removeNullOrEmptyKey(map);
removeNullOrEmptyValue(map);
}
/**
* 移除map的key为空值的entry
* @param map
* @return
*/
private static void removeNullOrEmptyKey(Map<String, Object> map){
Set<String> set = map.keySet();
Iterator<String> iterator = set.iterator();
while (iterator.hasNext()) {
Object obj = iterator.next();
if (isObjectNullOrEmpty(obj)) {
iterator.remove();
}
}
}
/**
* 移除map中的value为空的entry
* @param map
* @return
*/
private static void removeNullOrEmptyValue(Map<String, Object> map){
Set<String> set = map.keySet();
Iterator iterator = set.iterator();
while (iterator.hasNext()) {
Object obj = iterator.next();
Object value = map.get(obj);
if (isObjectNullOrEmpty(value)) {
iterator.remove();
}
}
}
private static boolean isObjectNullOrEmpty(Object obj) {
// 参考org.springframework.util.ObjectUtils.isEmpty(java.lang.Object)
if(obj == null){
return true;
}
if (isPrimitive(obj.getClass())) {
return false;
}
if (obj instanceof Optional) {
return !((Optional) obj).isPresent();
}
if (obj instanceof CharSequence) {
return ((CharSequence) obj).length() == 0;
}
if (obj instanceof Collection) {
return ((Collection) obj).isEmpty();
}
if (obj instanceof Map) {
return ((Map) obj).isEmpty();
}
if (obj.getClass().isArray()) {
return Array.getLength(obj) == 0;
}
if (obj.getClass().getPackage().getName().startsWith("java.math")) {
// 能转换为math包下的类对象说明有值
return false;
}
return isEmptyJavaBean(obj);
}
private static boolean isPrimitive(Class<?> clazz) {
try {
if (clazz.isPrimitive()) {
return true;
}
return ((Class<?>) clazz.getField("TYPE").get(null)).isPrimitive();
} catch (IllegalArgumentException | IllegalAccessException | NoSuchFieldException | SecurityException e) {
return false;
}
}
private static boolean isEmptyJavaBean(Object object) {
Class clazz = object.getClass();
Field fields[] = clazz.getDeclaredFields();
boolean flag = true;
for(Field field : fields){
boolean hasChangeAccessFlag = false;
if (!field.isAccessible()) {
field.setAccessible(true);
hasChangeAccessFlag = true;
}
Object fieldValue = null;
String fieldName = field.getName();
if (EXCLUDE_JAVA_BEAN_FILED.equals(fieldName)) {
// 忽略序列号字段
continue;
}
try {
fieldValue = field.get(object);
} catch (IllegalAccessException e) {
// 实际上前面已经确保了属性可以访问,所以不会抛该异常
LOG.error("==> get field err:", e);
}
if (hasChangeAccessFlag) {
field.setAccessible(false);
}
// 递归判断字段是否为空值,有任意一个字段有值则跳出循环
if (!isObjectNullOrEmpty(fieldValue)) {
flag = false;
break;
}
}
return flag;
}
}
}
注解使用示范
@ResponseBody
@GetMapping("/success")
@ApiLock(waitMills = 5000, expireMills = 5000, requiredRequestAttrs = "token", requiredHeaders = "h")
public String success(@RequestParam("id") long id, HttpServletRequest request) {
System.out.println("接口收到" + id + request.getAttribute("token") + request.getHeader("h"));
userService.queryById(1);
return "success";
}