简单了解下redis嵌入lua脚本(随便百度扒的):
Redis支持的LUA脚本与其优势
redis嵌入lua官方文档
Redis悲观锁、乐观锁和调用Lua脚本的优缺点
序言
本教程本着人和代码其中一个能跑就行的原则。人菜勿喷,只接受技术性建议。
近期看到了一些关于redis的文章说:哇redis牛皮、redis怎么这么快、哦哟!redis6竟然支持多线程了。。。然后问题来了:多线程?线程安全问题?哦哟,百度百度了解了下,原来如此,此多线程非彼多线程(自己百度去,这里不说这个)。然后就看到了嵌入lua脚本这个骚操作,在继续百度了下,哎哟,不错哟!拿来吧你。
蹭个热度:吴签和某时间管理大师参与多人运动(多线程)时已出现一个很细小(线程安全)的针眼,我的有点大,请你忍耐一下,一夜之间让花季程序猿痛苦流泪。
正文
通过redis嵌入lua脚本,实现简单的限流、黑名单功能。别说这也没有那也没有,个性化功能自行开发。
其中有点小坑,最后再说。
说那么多,不如直接丢代码。Talk is cheap. Show me the code.
测试环境
win11
jdk8
Redis server v=5.0.9
springboot 2.4.7
Show me the code.
先来看看lua脚本
lua脚本存放在项目的resource目录下的lua文件夹下面(路径可以自己改,下面SelfRedisScript .java里面改成对应的就行)
--- lua脚本:限流、黑名单专用,慎改
--- 用于高并发情况下保证redis线程安全
--- 注意:
--- 1、redis反序列化问题
--- 2、完成lua脚本后,请在本地测试无误后再提交代码
--- 3、若lua脚本执行报错,redis不会回滚已经执行的命令
-- 获取传递进来的参数
local countKey = KEYS[1]
if countKey == nil then
return true
end
-- 获取传递进来的阈值
local requestCount = KEYS[2]
-- 获取传递进来的过期时间ttl
local requestTtl = KEYS[3]
-- 获取redis参数
local countVal = redis.call('GET', countKey)
-- 如果不是第一次请求
if countVal then
-- 由于lua脚本接收到参数都会转为String,所以要转成数字类型才能比较
local numCountVal = tonumber(countVal)
-- 如果超过指定阈值,则返回true
if numCountVal >= tonumber(requestCount) then
return true
else
numCountVal = numCountVal + 1
redis.call('SETEX', countKey, requestTtl, numCountVal)
end
else
redis.call('SETEX', countKey, requestTtl, 1)
end
return false
Java代码(SelfRedisScript .java)注入RedisScript
@Component
public class SelfRedisScript {
@Bean("redisScriptBoolean")
public DefaultRedisScript<Boolean> redisScriptBoolean() {
DefaultRedisScript<Boolean> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit_blacklisted.lua")));
redisScript.setResultType(Boolean.class);
return redisScript;
}
}
Java代码(RedisTemplateConfig.java)简单配置RedisTemplate
@EnableCaching
@Configuration
@AutoConfigureBefore(RedisAutoConfiguration.class)
public class RedisTemplateConfig {
@Bean
@Primary
public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
redisTemplate.setKeySerializer(new StringRedisSerializer());
redisTemplate.setHashKeySerializer(new StringRedisSerializer());
redisTemplate.setValueSerializer(new StringRedisSerializer());
redisTemplate.setHashValueSerializer(new StringRedisSerializer());
redisTemplate.setConnectionFactory(redisConnectionFactory);
return redisTemplate;
}
}
准备工作做完后,开始实现简单的限流、黑名单
在过滤器里面实现功能
CommonConstants类中的常量、ApplicationConfig从application.yml从获取值的代码就不贴出来了,
WebUtils.returnResponse单独列出来了,自行修改补充。
RedisConstants常量:/** * 限流机制key前缀 REQUEST_LIMIT:127.0.0.1:/api/test */ String REQUEST_LIMIT = "REQUEST_LIMIT:%s:%s"; /** * 黑名单机制key前缀 REQUEST_LIMIT:127.0.0.1: */ String REQUEST_BLACKLISTED = "REQUEST_BLACKLISTED:%s";
@Slf4j
public class RequestLimitFilter implements Filter {
private final ApplicationConfig applicationConfig;
private Long limitTimeSeconds;
private Integer limitCount;
private Long blacklistedTimeSeconds;
private Integer blacklistedCount;
private List<String> limitIgnores;
private RedisTemplate<String, Object> redisTemplate;
private DefaultRedisScript<Boolean> script;
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;
String requestURI = request.getRequestURI();
if (!WebUtils.uriMatch(this.limitIgnores, requestURI)) {
// 获取ip
String realIp = WebUtils.getIP(request);
// 黑名单限制
String blacklistedKey = String.format(RedisConstants.REQUEST_BLACKLISTED, realIp);
// key为空返回true,超过指定阈值返回true,其他返回false
Boolean blackPass = getPass(blacklistedKey, blacklistedCount, blacklistedTimeSeconds);
if (blackPass) {
WebUtils.returnResponse(response, JSONUtil.toJsonStr(R.failed(StatusCode.BLACKLISTED)));
return;
}
// 限流限制
String limitKey = String.format(RedisConstants.REQUEST_LIMIT, realIp, requestURI);
Boolean limitPass = getPass(limitKey, limitCount, limitTimeSeconds);
if (limitPass) {
WebUtils.returnResponse(response, JSONUtil.toJsonStr(R.failed(StatusCode.LIMITED)));
return;
}
}
filterChain.doFilter(servletRequest, servletResponse);
}
@Override
public void destroy() {
Filter.super.destroy();
}
@Override
public void init(FilterConfig filterConfig) throws ServletException {
// 限流
this.limitCount = ObjectUtil.isNull(applicationConfig.getLimitCount())
? CommonConstants.REQUEST_LIMIT_COUNT : applicationConfig.getLimitCount();
this.limitTimeSeconds = ObjectUtil.isNull(applicationConfig.getLimitTimeSeconds())
? CommonConstants.REQUEST_LIMIT_TIME_SECONDS : applicationConfig.getLimitTimeSeconds();
// 黑名单
this.blacklistedCount = ObjectUtil.isNull(applicationConfig.getBlacklistedCount())
? CommonConstants.REQUEST_BLACKLISTED_COUNT : applicationConfig.getBlacklistedCount();
this.blacklistedTimeSeconds = ObjectUtil.isNull(applicationConfig.getBlacklistedTimeSeconds())
? CommonConstants.REQUEST_BLACKLISTED_TIME_SECONDS : applicationConfig.getBlacklistedTimeSeconds();
// 过滤请求,从application.yml从获取值
this.limitIgnores = IterUtil.isEmpty(applicationConfig.getLimitIgnores())
? Collections.emptyList() : applicationConfig.getLimitIgnores();
// lua
this.redisTemplate = SpringContextHolder.getBean(RedisTemplate.class);
this.script = SpringContextHolder.getBean("redisScriptBoolean");
Filter.super.init(filterConfig);
}
public RequestLimitFilter(ApplicationConfig applicationConfig) {
this.applicationConfig = applicationConfig;
}
/**
* 调用lua脚本,获取执行结果
* @param key 缓存key
* @param count 请求阈值
* @param timeSeconds 拦截时间
* @return 执行结果
*/
private Boolean getPass(String key, Integer count, Long timeSeconds) {
Boolean execute = redisTemplate.execute(script, Arrays.asList(key, String.valueOf(count), String.valueOf(timeSeconds)));
return execute == null ? true : execute;
}
}
// -------------------------------------------WebUtils工具类---------------------------------------------
public void returnResponse(HttpServletResponse response, String data) {
response.setCharacterEncoding("UTF-8");
response.setContentType("text/html; charset=utf-8");
try (PrintWriter writer = response.getWriter()) {
// 通过 PrintWriter 将 data 数据直接 print 回去
writer.print(data);
} catch (IOException ignored) {
}
}
public String getIP(HttpServletRequest request) {
Assert.notNull(request, "HttpServletRequest is null");
String ip = request.getHeader(HEADER_X_REQUESTED_FOR);
if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader(HEADER_X_FORWARDED_FOR);
}
if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader(HEADER_PROXY_CLIENT_IP);
}
if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader(HEADER_WL_PROXY_CLIENT_IP);
}
if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader(HEADER_HTTP_CLIENT_IP);
}
if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getHeader(HEADER_HTTP_X_FORWARDED_FOR);
}
if (StrUtil.isBlank(ip) || UNKNOWN.equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
return StrUtil.isBlank(ip) ? null : ip.split(",")[0];
}
最后注册下RequestLimitFilter.java这个过滤器
@Component
@AllArgsConstructor
public class FilterRegistration {
private final ApplicationConfig applicationConfig;
@Bean
public FilterRegistrationBean<RequestLimitFilter> requestLimitFilter() {
FilterRegistrationBean<RequestLimitFilter> registration = new FilterRegistrationBean<>();
registration.setFilter(new RequestLimitFilter(applicationConfig));
registration.addUrlPatterns("/*");
registration.setName("RequestLimitFilter");
registration.setOrder(1);
return registration;
}
}
展示下成果(计算规则自行调整)
请求即记录
时间段内多次请求达到限流指定的请求阈值
时间段内多次请求已被限流后,继续请求达到黑名单指定的请求阈值
注意事项
- RedisTemplate配置的序列化问题
如果配置的是JdkSerializationRedisSerializer,就需要改成StringRedisSerializer,如果需要两者兼容,那
就再给spring丢一个名为jdkRedisSerializer的Bean,然后在 @Autowired时,添加@Qualifier("jdkRedisSerializer")指定注入Bean@Bean("jdkRedisSerializer") public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) { RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>(); redisTemplate.setKeySerializer(new StringRedisSerializer()); redisTemplate.setHashKeySerializer(new StringRedisSerializer()); redisTemplate.setValueSerializer(new JdkSerializationRedisSerializer()); redisTemplate.setHashValueSerializer(new JdkSerializationRedisSerializer()); redisTemplate.setConnectionFactory(redisConnectionFactory); return redisTemplate; }
- lua脚本执行报错问题
若lua脚本执行报错,redis不会回滚已经执行的命令,所以在完成lua脚本后,请在本地测试无误后再提交代码