前言
提到插件模式,我们可能很陌生,但是配合着Mybatis我们能瞬间想到大名鼎鼎的应用场景——分页插件。在MyBatis中插件模式的应用十分广泛,需要我们对插件模式深入研究,参能摸透其中的奥秘,对于插件不熟悉的同学,希望不要擅自使用插件,插件的执行会更改原有的目标代码的逻辑,可能会产生不确定的问题。我在学习的插件模式的时候,对它无比的亲切,觉得就是我们Spring中的AOP呀。只是又觉得和AOP有些不同,但是不耽误我们用AOP的思想去理解它的原理。
结尾,我再结合PageHelper进行实战,手写一个分页插件,验证对插件模式的学习。
学习目标
- Mybatis的插件是如何实现的?
- Mybatis的插件对于哪些类有效?
- Mybatis插件的应用场景有哪些?
Mybatis插件的简单使用
编写插件
完整代码如下:
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import java.util.*;
import org.junit.jupiter.api.Test;
class PluginTest {
@Intercepts(
@Signature(
type = Map.class,
method = "get",
args = { Object.class }
)
)
public static class HelloWorldMapPlugin implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
return "Hello World,"+invocation.getMethod().invoke(invocation.getTarget(),invocation.getArgs());
}
}
public static class HelloWorldMap<K,V> implements Map<K,V>{
private final Map<K,V> delegate = new HashMap<>();
...
@Override
public V get(Object key) {
return delegate.get(key);
}
@Override
public V put(K key, V value) {
return delegate.put(key,value);
}
...
}
@Test
public void testHelloWorldMapPlugin(){
Map<String,String> map = new HelloWorldMap<>();
map.put("anything","wuxuan.chai");
Map<String,String> mapWrap = (Map<String,String>) new HelloWorldMapPlugin().plugin(map);
assertEquals("Hello World,wuxuan.chai",mapWrap.get("anything"));
}
}
第一步:定义Intercepts,用于描述需要拦截的方法签名,包含拦截的方法信息、参数信息、拦截的借口类信息(一定是接口,jdk代理只支持代理接口的实现类,否则无法代理,后面会做分析)
第二步:定义拦截器HelloWorldMapPlugin,重写Interceptor的intercept方法,这里去实现拦截后,处理目标方法的改写逻辑
第三步:调用HelloWorldMapPlugin的plugin方法完成插件的调用,传入目标类,如果目标类包含第一步中定义的方法签名,则会自动生成代理类,执行Interceptor的intercept方法。
Mybatis插件模式的实现原理
了解插件模式实现原理之前,回顾下jdk动态代理:
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class ProxyTests{
public static interface Animal{
void run(String name);
}
@Test
public void jdkProxyTest(){
Animal animalProxy = (Animal)Proxy.newProxyInstance(Animal.class.getClassLoader(), new Class[]{Animal.class}, new InvocationHandler() {
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
String name = args[0].toString();
System.out.println(name + "在泥潭里跑");
return null;
}
});
Assertions.assertDoesNotThrow(()->animalProxy.run("河马"));
}
}
在jdk动态代理模式中,我们重点需要实现的是InvocationHandler的invoke方法,去代理目标方法。
回到Mybatis的插件实现上来,插件是基于jdk的动态代理实现的。
上面插件的例子中,HelloWorldMapPlugin调用了plugin方法,获得了插件的代理类。来看看这里面主要做了什么:
- 调用 Interceptor#plugin方法
default Object plugin(Object target) {
//插件包装目标对象,生成代理对象
return Plugin.wrap(target, this);
}
传入目标对象的实例,获得目标对象的代理对象。
- 调用Plugin#wrap方法
/**
* 根据拦截器和目标对象,生成代理对象的封装
* @param target 目标对象
* @param interceptor 拦截器
* @return 代理对象
*/
public static Object wrap(Object target, Interceptor interceptor) {
//获取类和代理方法的映射关系,签名过程
Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
//代理的目标类
Class<?> type = target.getClass();
//jdk动态代理基于接口代理,所以要递归找到接口层面,否则直接返回目标类,无法代理
Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
if (interfaces.length > 0) {
return Proxy.newProxyInstance(type.getClassLoader(), interfaces, new Plugin(target, interceptor, signatureMap));
}
return target;
}
传入target的目的:因为target是目标对象,我们需要根据他包装生成对应的代理对象
传入interceptor的目的:根据上面的例子,我们定义plugin的时候会定义Intercepts注解,这里面包含了我们对期望插件拦截的方法签名信息,以及我们重写的intercept方法。这是我们拦截器存在的核心意义,简单是理解就是,“拦截哪些”和“做什么”
- 如何处理Intercepts注解中的方法签名
private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
//获取Intercepts注解的内容
Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
// issue #251
if (interceptsAnnotation == null) {
throw new PluginException(
"No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
}
//得到插件的签名列表,里面描述了关于代理类和方法的定义
Signature[] sigs = interceptsAnnotation.value();
Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
for (Signature sig : sigs) {
//如果一个类映射多个方法,合并到一起
Set<Method> methods = MapUtil.computeIfAbsent(signatureMap, sig.type(), k -> new HashSet<>());
try {
//获取签名中的方法。
/**
* 调用的是 {@code Class.getMethod(String name, Class<?>... parameterTypes)} 获取目标类的代理方法
*/
Method method = sig.type().getMethod(sig.method(), sig.args());
methods.add(method);
} catch (NoSuchMethodException e) {
throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e,
e);
}
}
return signatureMap;
}
这个方法的主要目的是解析Intercepts注解,将拦截器的拦截目标定义给解析合并出来。
- 根据拦截器的定义,匹配目标对象,决定是否代理
private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
Set<Class<?>> interfaces = new HashSet<>();
while (type != null) {
//获取目标类所实现的所有接口,从这里可以看出,Intercepts中的signature的type定义必须是一个interface
for (Class<?> c : type.getInterfaces()) {
//如果接口在签名列表中,则添加到接口列表中,说明这个类符合拦截器拦截(代理的条件)
if (signatureMap.containsKey(c)) {
interfaces.add(c);
}
}
//一次递归,到接口的父类,直到顶级类
type = type.getSuperclass();
}
return interfaces.toArray(new Class<?>[0]);
}
这段代码的意思逻辑看起来比较难理解,仔细品一下,其实不难发现他的奥妙。上面说了,拦截器的主要功能简单概括:“拦截哪些”和“做什么”,这里就是“拦截哪些”的具体体现。如何判断一个类是否会被拦截,这里面根据这个类,判断他是否是我们拦截方法签名中定义的类型的接口实现类(一定是接口实现类),一直递归到类的顶级父类Object为止。
-
判断是否拦截,创建代理对象
如果第4步匹配到符合拦截器定义的类,则通过jdk动态代理创建代理类(第2步中的逻辑)。
Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
if (interfaces.length > 0) {
return Proxy.newProxyInstance(type.getClassLoader(), interfaces, new Plugin(target, interceptor, signatureMap));
}
前面我们回顾了jdk动态代理,知道jdk动态代理的编制点需要自己去实现InvocationHandler,其实Plugin这个类就是InvocationHandler的接口实现类。通过代理类调用目标签名方法,实际上会执行,Plugin#invoke方法
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
try {
//从签名映射中,获取代理的方法
Set<Method> methods = signatureMap.get(method.getDeclaringClass());
if (methods != null && methods.contains(method)) {
//存在代理的方法,则拦截器生效,执行拦截的内容,有些AOP的含义
//@{code Invocation}保存了目标方法的原始信息,包括目标对象,目标方法,目标方法的参数,可以通过Method.invoke(target,args)方法执行目标原始方法
return interceptor.intercept(new Invocation(target, method, args));
}
//如果不存在代理方法,则直接执行目标方法
return method.invoke(target, args);
} catch (Exception e) {
throw ExceptionUtil.unwrapThrowable(e);
}
}
如果执行的方法是拦截器定义的方法,会直接调用插件interceptor.intercept方法,否则调用源方法。
综上,就是Mybatis的插件模式,感觉就是jdk动态代理的一种运用,是动态代理的思想。到此我们学习目标的第一个小目标完成了。
开发Mybatis插件注意些什么
这里主要是解决我们学习目标的第二小目标——Mybatis的插件对于哪些类有效
上面我们或多或少的间接提示了,Mybatis插件模式的实现原理中的第4步中,我提到了拦截方法签名中定义的类型的接口实现类(
一定是接口实现类
)
,也就是说,我们的注解Intercepts的signature注解中type只能是接口。在这里给大家举反例,证实结论:
public class NonInterfacePluginTests{
public static class User {
public String getName() {
return "name";
}
}
@Intercepts(
{
@Signature(type = User.class, method = "getName", args = {})
}
)
public static class UserPlugin implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
return "haha";
}
}
/**
* 反例,当签名的对象没实现接口时,无法代理
*/
@Test
public void testUserPlugin() {
User user = new User();
User userWrap = (User) new UserPlugin().plugin(user);
assertNotEquals("name", userWrap.getName());
}
}
这个例子中插件中定义的方法签名的类型是User,User不是一个接口,通过执行发现拦截器失效。除此之外,jdk动态代理无法代理私有化方内部类或者方法,所以类和方法的访问级别要公开。
Mybatis插件的应用场景有哪些
Mybatis内部的使用
源码中org.apache.ibatis.session.Configuration初始化的时候,加载了mybatis-config.xml后,会将xml中定义的plugin注册到interceptorChain中。
-
ParameterHandler
JDBC预执行的过程中的参数处理,可以增强参数的设置
public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject,
BoundSql boundSql) {
ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement,
parameterObject, boundSql);
return (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
}
-
ResultSetHandler
JDBC执行的结果ResultSet的处理
public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds,
ParameterHandler parameterHandler, ResultHandler resultHandler, BoundSql boundSql) {
ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler,
resultHandler, boundSql, rowBounds);
return (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
}
-
StatementHandler
JDBC的CRUD操作处理
public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement,
Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject,
rowBounds, resultHandler, boundSql);
return (StatementHandler) interceptorChain.pluginAll(statementHandler);
}
-
Executor
Mybatis的数据操作的执行器,包含了CRUD与事务的处理,连接的处理
public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
executorType = executorType == null ? defaultExecutorType : executorType;
Executor executor;
if (ExecutorType.BATCH == executorType) {
executor = new BatchExecutor(this, transaction);
} else if (ExecutorType.REUSE == executorType) {
executor = new ReuseExecutor(this, transaction);
} else {
executor = new SimpleExecutor(this, transaction);
}
if (cacheEnabled) {
executor = new CachingExecutor(executor);
}
return (Executor) interceptorChain.pluginAll(executor);
}
手写一个Mybatis的分页插件
Mybatis-plus和PageHelper是我们日常开发中常用的Mybatis增强工具,他们给我们提供了通用的CRUD以及分页的支持,帮我们节省了造轮子的烦恼。这里我们简单的学习一下分页插件的实现原理。
首先我们要搞清楚,如果一个分页的场景我们需要考虑那些情况,分页的业务SQL怎么写?
场景:
-
分页定义,要考虑哪些?
分页的定义包含:当前页(page)、分页大小(pageSize)、分页结果(records)、总条数(total)除此之外,有的还包括,排序字段(例子中不考虑)等等
-
什么场景需要分页?
总条数 > 0
分页的偏移位置在数据的体量范围内((page-1)pageSize < total*)
分页SQL:mysql的写法
---总条数的统计
select count(*) from your_table;
--- 取分页sql
select * from your_table limit offset_num,limit_size
根据上面的思路,我们一步一步的实现分页插件。
-
定义分页
定义分页定义接口
package org.apache.ibatis.custom.plugin.page;
import java.util.List;
/**
* @author wuxuan chai
* @since 2024/1/12 15:22
*/
public interface IPage<T> {
void setTotal(Long total);
void setRecords(List<T> records);
Long pageSize();
Long page();
}
定义分页的数据层实现
package org.apache.ibatis.custom.plugin.page;
import java.util.List;
/**
* @author wuxuan chai
* @since 2024/1/12 15:25
*/
public class Page<T> implements IPage<T> {
private Long total;
private Long pageSize;
private Long page;
private List<T> records;
public Page(Long pageSize, Long page) {
this.pageSize = pageSize;
this.page = page;
}
...IGNORE GETTER/SETTER/TOSTRING FUCTION
}
- 定义分页插件拦截器
-
定义插件的拦截位置
因为分页只涉及到查询,所以我们只需要拦截执行器的查询接口即可
-
@Intercepts(
{
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})}
)
- 定义查询总条数的逻辑
这里面我们要定义一个新的MappedStatement,用以创建一个count查询,重新定义查询的BoundSQL
private Long count(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, Object parameterObject, BoundSql boundSql) throws SQLException {
MappedStatement.Builder builder = new MappedStatement.Builder(mappedStatement.getConfiguration(), mappedStatement.getId() + ".count", mappedStatement.getSqlSource(), SqlCommandType.SELECT);
builder.resultMaps(List.of(new ResultMap.Builder(mappedStatement.getConfiguration(), "count_res", Long.class, Collections.emptyList()).build()));
builder.fetchSize(mappedStatement.getFetchSize());
builder.timeout(mappedStatement.getTimeout());
builder.cache(mappedStatement.getCache());
builder.flushCacheRequired(mappedStatement.isFlushCacheRequired());
builder.resource(mappedStatement.getResource());
builder.statementType(mappedStatement.getStatementType());
builder.resultSetType(mappedStatement.getResultSetType());
builder.useCache(mappedStatement.isUseCache());
MappedStatement countMappedStatement = builder.build();
String countSql = "select count(*) as total from (" + boundSql.getSql() + ") tmp_count";
BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, boundSql.getParameterMappings(), parameterObject);
CacheKey cacheKey = executor.createCacheKey(countMappedStatement, parameterObject, rowBounds, countBoundSql);
List<Object> res = executor.query(countMappedStatement, parameterObject, rowBounds, null, cacheKey, countBoundSql);
return res.isEmpty() ? 0L : (Long) res.get(0);
}
-
增强分页sql的逻辑
根据Page分页中的参数,结合原始查询的BoundSql构建新的BoundSql,构建分页SQL以及设置分页的查询参数。然后执行SQL查询
//重新定义查询的SQL定义西悉尼
List<ParameterMapping> parameterMappings = new ArrayList<>(boundSql.getParameterMappings());
//增加sql预执行参数类型及占位名称
parameterMappings.add(new ParameterMapping.Builder(mappedStatement.getConfiguration(), "offset", Long.class).build());
parameterMappings.add(new ParameterMapping.Builder(mappedStatement.getConfiguration(), "limit", Long.class).build());
//不是mysql的语法,derby内存数据库
BoundSql newBoundSQL = new BoundSql(mappedStatement.getConfiguration(), boundSql.getSql() + " offset ? rows fetch next ? rows only ", parameterMappings, parameterObject);
//设置分页参数值
newBoundSQL.setAdditionalParameter("offset", (iPage.page() - 1) * iPage.pageSize());
newBoundSQL.setAdditionalParameter("limit", iPage.pageSize());
CacheKey cacheKey = executor.createCacheKey(mappedStatement, parameterObject, rowBounds, newBoundSQL);
//执行分页sql
return executor.query(mappedStatement, parameterObject, rowBounds, resultHandler, cacheKey, newBoundSQL);
完整Mapper代码
package org.apache.ibatis.custom.mapper;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.custom.User;
import org.apache.ibatis.custom.plugin.page.Page;
import java.util.List;
/**
* @author wuxuan chai
* @since 2024/1/12 09:47
*/
public interface UserMapper {
/**
* 分页结果封装
*/
default Page<User> pageQuery(Page<User> page) {
List<User> users = this.selectPageData(page);
page.setRecords(users);
return page;
}
/**
*分页查询的逻辑
*/
List<User> selectPageData(Page<User> page);
}
完整分页插件代码
package org.apache.ibatis.custom.plugin;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.custom.plugin.page.IPage;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.lang.reflect.InvocationTargetException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* @author wuxuan chai
* @since 2024/1/12 15:18
*/
@Intercepts(
{
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})}
)
public class PageHelperInterceptor implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws InvocationTargetException, IllegalAccessException, SQLException {
Executor executor = (Executor) invocation.getTarget();
Object[] args = invocation.getArgs();
MappedStatement mappedStatement = (MappedStatement) args[0];
Object parameterObject = args[1];
RowBounds rowBounds = (RowBounds) args[2];
ResultHandler resultHandler = (ResultHandler) args[3];
BoundSql boundSql = mappedStatement.getBoundSql(parameterObject);
if (args.length == 5) {
CacheKey cacheKey = (CacheKey) args[4];
boundSql = (BoundSql) args[5];
}
IPage<?> iPage = hasPageArgs(parameterObject);
if (iPage != null) {
if (iPage.pageSize() <= 0 || iPage.page() < 0) {
return Collections.emptyList();
} else {
Long count = count(executor, mappedStatement, rowBounds, parameterObject, boundSql);
if ((iPage.page()-1) * iPage.pageSize() <= count) {
iPage.setTotal(count);
List<ParameterMapping> parameterMappings = new ArrayList<>(boundSql.getParameterMappings());
parameterMappings.add(new ParameterMapping.Builder(mappedStatement.getConfiguration(), "offset", Long.class).build());
parameterMappings.add(new ParameterMapping.Builder(mappedStatement.getConfiguration(), "limit", Long.class).build());
BoundSql newBoundSQL = new BoundSql(mappedStatement.getConfiguration(), boundSql.getSql() + " offset ? rows fetch next ? rows only ", parameterMappings, parameterObject);
newBoundSQL.setAdditionalParameter("offset", (iPage.page() - 1) * iPage.pageSize());
newBoundSQL.setAdditionalParameter("limit", iPage.pageSize());
CacheKey cacheKey = executor.createCacheKey(mappedStatement, parameterObject, rowBounds, newBoundSQL);
return executor.query(mappedStatement, parameterObject, rowBounds, resultHandler, cacheKey, newBoundSQL);
} else {
return Collections.emptyList();
}
}
} else {
return invocation.proceed();
}
}
//构建count查询的MappedStatement
private Long count(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, Object parameterObject, BoundSql boundSql) throws SQLException {
MappedStatement.Builder builder = new MappedStatement.Builder(mappedStatement.getConfiguration(), mappedStatement.getId() + ".count", mappedStatement.getSqlSource(), SqlCommandType.SELECT);
builder.resultMaps(List.of(new ResultMap.Builder(mappedStatement.getConfiguration(), "count_res", Long.class, Collections.emptyList()).build()));
builder.fetchSize(mappedStatement.getFetchSize());
builder.timeout(mappedStatement.getTimeout());
builder.cache(mappedStatement.getCache());
builder.flushCacheRequired(mappedStatement.isFlushCacheRequired());
builder.resource(mappedStatement.getResource());
builder.statementType(mappedStatement.getStatementType());
builder.resultSetType(mappedStatement.getResultSetType());
builder.useCache(mappedStatement.isUseCache());
MappedStatement countMappedStatement = builder.build();
String countSql = "select count(*) as total from (" + boundSql.getSql() + ") tmp_count";
BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, boundSql.getParameterMappings(), parameterObject);
CacheKey cacheKey = executor.createCacheKey(countMappedStatement, parameterObject, rowBounds, countBoundSql);
List<Object> res = executor.query(countMappedStatement, parameterObject, rowBounds, null, cacheKey, countBoundSql);
return res.isEmpty() ? 0L : (Long) res.get(0);
}
private IPage<?> hasPageArgs(Object parameterObject) {
if (parameterObject instanceof Map map) {
return (IPage<?>) map.keySet().stream().filter(key -> key instanceof IPage<?>).findFirst().orElse(null);
} else if (parameterObject instanceof IPage<?>) {
return (IPage<?>) parameterObject;
} else {
return null;
}
}
}
总结
至此,Mybatis的插件模式的实现原理学习目标完成。在Mybatis的源码中,还有很多有意思的设计,接下来会继续学习,并梳理出来。