- 拦截器简介
MyBatis提供了一种插件(plugin)机制,其本质就是拦截器:MyBatis允许拦截Executor、ParameterHandler、ResultSetHandler和StatementHandler这四类对象的方法调用。基于这个拦截器,我们可以在被拦截的方法执行前后加入自己的逻辑,甚至用自己的逻辑替换这些方法的执行。
这一点与spring的拦截器基本一致。它的设计初衷是让用户在需要时实现自己的逻辑,而不必改动Mybatis固有的逻辑。在拦截器的各种用法中,分页插件应该是使用得最多的;分表的实现原理与之类似。
- 首先引入需要的包
<dependency>
<groupId>org.mybatis</groupId>
<artifactId>mybatis-spring</artifactId>
<version>1.2.3</version>
</dependency>
<dependency>
<groupId>org.mybatis</groupId>
<artifactId>mybatis</artifactId>
<version>3.3.0</version>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
<version>1.0.27</version>
</dependency>
- 其次实现org.apache.ibatis.plugin.Interceptor接口,重写intercept、plugin、setProperties这三个方法:
/**
 * MyBatis interceptor that rewrites table names for horizontal table sharding.
 *
 * <p>Intercepts {@code StatementHandler.prepare}. The mapper type is the statement id minus
 * its last segment. When that mapper is annotated with {@link TableSeg}, every table listed by
 * its {@link TableShardStrategy} is renamed to {@code <table>_<bucket>}. Here {@code bucket} is
 * the shard-key value modulo {@link TableSeg#shardNum()}. The SQL is parsed and re-printed with
 * Druid's MySQL parser, so renaming happens on the parsed statement rather than by text
 * substitution.
 *
 * <p>Per the original author, this is not intended for batch statements.
 *
 * @Title:
 * @Author: hangyu
 * @Date: 2019/4/15
 * @Description
 * @Version:1.0
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class})})
public class TableSegInterceptor implements Interceptor {
    private Log log = LogFactory.getLog(getClass());
    /** MetaObject path to the SQL text held by the delegate handler's BoundSql. */
    private final static String BOUNDSQL_SQL_NAME = "delegate.boundSql.sql";
    /** MetaObject path to the delegate handler's BoundSql. */
    private final static String BOUNDSQL_NAME = "delegate.boundSql";
    /** MetaObject path to the delegate handler's MappedStatement. */
    private final static String MAPPEDSTATEMENT_NAME = "delegate.mappedStatement";
    private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();
    private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();
    private final static ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();

    /**
     * Rewrites the bound SQL to target the shard tables (when the mapper is annotated), then
     * lets the statement proceed.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        // MetaObject gives reflective access to the handler's nested "delegate.*" fields.
        MetaObject metaObject = MetaObject.forObject(statementHandler, DEFAULT_OBJECT_FACTORY,
                DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
        String originalSql = (String) metaObject.getValue(BOUNDSQL_SQL_NAME);
        if (StringUtils.isNotEmpty(originalSql)) {
            // Both objects are needed to recover the mapper method's parameters.
            MappedStatement mappedStatement = (MappedStatement) metaObject.getValue(MAPPEDSTATEMENT_NAME);
            BoundSql boundSql = (BoundSql) metaObject.getValue(BOUNDSQL_NAME);
            TableSeg tableSeg = findTableSeg(mappedStatement.getId());
            if (tableSeg != null) {
                Map<String, Object> parameter = getParameterFromMappedStatement(mappedStatement, boundSql);
                shardTable(metaObject, parameter, tableSeg, originalSql);
            }
        }
        return invocation.proceed();
    }

    /**
     * Resolves the {@link TableSeg} declared on the mapper type that owns a statement.
     *
     * @param statementId MyBatis statement id, {@code <namespace>.<statement>}
     * @return the annotation, or {@code null} when the namespace carries none, has no dot,
     *     or is not a loadable class
     */
    private TableSeg findTableSeg(String statementId) {
        int lastDot = statementId.lastIndexOf('.');
        if (lastDot <= 0) {
            return null;
        }
        String className = statementId.substring(0, lastDot);
        try {
            return Class.forName(className).getAnnotation(TableSeg.class);
        } catch (ClassNotFoundException e) {
            // XML-only mappers may use a namespace that is not a Java type. Such statements
            // cannot carry @TableSeg, so they are left unsharded instead of failing.
            if (log.isDebugEnabled()) {
                log.debug("Namespace " + className + " is not a class; skipping table sharding");
            }
            return null;
        }
    }

    /**
     * Flattens the statement's parameter object into a name -> value map.
     *
     * @param ms       the mapped statement (used for its type-handler registry)
     * @param boundSql the bound SQL carrying the parameter object and placeholder mappings
     * @return a fresh mutable map; empty when the statement has no parameter
     */
    private Map<String, Object> getParameterFromMappedStatement(MappedStatement ms, BoundSql boundSql) {
        Map<String, Object> paramMap = new HashMap<String, Object>();
        Object parameterObject = boundSql.getParameterObject();
        if (parameterObject == null) {
            return paramMap;
        }
        if (parameterObject instanceof Map) {
            paramMap.putAll((Map) parameterObject);
            return paramMap;
        }
        boolean hasTypeHandler = ms.getConfiguration().getTypeHandlerRegistry()
                .hasTypeHandler(parameterObject.getClass());
        if (!hasTypeHandler) {
            // A bean parameter: expose each readable property under its property name.
            MetaObject metaObject = SystemMetaObject.forObject(parameterObject);
            for (String name : metaObject.getGetterNames()) {
                paramMap.put(name, metaObject.getValue(name));
            }
        }
        // A single simple-typed parameter (e.g. a String) has no property names. Bind it under
        // the name of the first placeholder that can refer to it.
        if (boundSql.getParameterMappings() != null) {
            for (ParameterMapping parameterMapping : boundSql.getParameterMappings()) {
                String name = parameterMapping.getProperty();
                if (paramMap.get(name) == null
                        && (hasTypeHandler || parameterMapping.getJavaType().equals(parameterObject.getClass()))) {
                    paramMap.put(name, parameterObject);
                    break;
                }
            }
        }
        return paramMap;
    }

    /**
     * Rewrites the bound SQL so every sharded table name points at its shard table.
     * Per the original author, this is not intended for batch statements.
     *
     * @param metaObject  meta view of the intercepted StatementHandler (the SQL is written back here)
     * @param parameter   flattened statement parameters
     * @param tableSeg    sharding configuration from the mapper
     * @param originalSql SQL before rewriting
     */
    private void shardTable(MetaObject metaObject, Map<String, Object> parameter,
                            TableSeg tableSeg, String originalSql) {
        MySqlStatementParser parser = new MySqlStatementParser(originalSql);
        SQLStatement statement = parser.parseStatement();
        StringBuilder newSql = new StringBuilder();
        SQLASTOutputVisitor visitor = SQLUtils.createOutputVisitor(newSql, JdbcConstants.MYSQL);
        for (Map.Entry<String, String> entry : getShardTableName(tableSeg, parameter).entrySet()) {
            // Register old table name -> shard table name for the visitor to apply while printing.
            visitor.addTableMapping(entry.getKey(), entry.getValue());
        }
        statement.accept(visitor);
        // Write the rewritten SQL back so the prepared statement uses the shard tables.
        metaObject.setValue(BOUNDSQL_SQL_NAME, newSql.toString());
    }

    /**
     * Builds the logical-table -> shard-table name mapping for one statement.
     *
     * @param seg       sharding configuration
     * @param parameter flattened statement parameters; must contain the strategy's shard code
     * @return map keyed by the original table name, valued by {@code <table>_<bucket>}
     * @throws IllegalStateException    if {@code seg.shardNum()} is not positive
     * @throws IllegalArgumentException if the shard-key parameter is missing or null
     */
    private Map<String, String> getShardTableName(TableSeg seg, Map<String, Object> parameter) {
        TableShardStrategy tableShardStrategy = seg.shardBy();
        String shardCode = tableShardStrategy.getShardCode();
        int shardNum = seg.shardNum();
        if (shardNum <= 0) {
            throw new IllegalStateException("@TableSeg.shardNum must be positive, got " + shardNum);
        }
        Object shardValue = parameter.get(shardCode);
        if (shardValue == null) {
            throw new IllegalArgumentException("Sharded statement is missing shard key parameter '"
                    + shardCode + "'");
        }
        // Non-negative modulo: a negative key must not yield a suffix such as "-3".
        long bucket = ((toShardKey(shardValue) % shardNum) + shardNum) % shardNum;
        String suffix = String.valueOf(bucket);
        Map<String, String> oldTableNewTableNameMap = new HashMap<>();
        for (String toShardTable : tableShardStrategy.getShardTableList()) {
            oldTableNewTableNameMap.put(toShardTable, toShardTable + "_" + suffix);
        }
        return oldTableNewTableNameMap;
    }

    /**
     * Converts a shard-key value to a long.
     *
     * <p>Numbers and numeric strings keep their numeric value, so existing numeric sharding is
     * unchanged. Other values, such as string openIds, use their {@code String.hashCode()} so they
     * still map deterministically to a bucket.
     */
    private static long toShardKey(Object shardValue) {
        if (shardValue instanceof Number) {
            return ((Number) shardValue).longValue();
        }
        String text = shardValue.toString();
        try {
            return Long.parseLong(text);
        } catch (NumberFormatException notNumeric) {
            return text.hashCode();
        }
    }

    /**
     * Wraps only StatementHandler targets. Everything else is returned as-is, which avoids
     * needless proxies.
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        }
        return target;
    }

    /** No configurable properties. */
    @Override
    public void setProperties(Properties properties) {
    }
}
- 自定义分表注解TableSeg和分表策略枚举TableShardStrategy:注解指定分表数量(shardNum)和分表策略(shardBy),枚举定义分表字段(shardCode)和需要分表的表名(shardTableList)
/**
 * Marks a MyBatis mapper type whose statements must be routed to sharded tables.
 *
 * <p>Retained at runtime so that an interceptor can read it reflectively from the mapper type.
 * Note that {@link Inherited} only propagates this annotation from a superclass to its
 * subclasses. Per its Javadoc, it has no effect for annotations placed on interfaces, so each
 * mapper interface must declare {@code @TableSeg} itself.
 *
 * @Title:
 * @Author: hangyu
 * @Date: 2019/4/15
 * @Description
 * @Version:1.0
 */
@Target({ElementType.TYPE })
@Retention(RetentionPolicy.RUNTIME)
@Inherited
@Documented
public @interface TableSeg {
    /**
     * Number of shard tables. The table suffix is the remainder of the shard key divided by
     * this value (e.g. {@code 4} means "key mod 4"), so it must be positive.
     * There is no default: every use of the annotation must set it explicitly.
     *
     * @return the shard count
     */
    int shardNum();

    /**
     * Strategy that names the shard-key parameter and the tables to rewrite.
     *
     * @return the sharding strategy
     */
    TableShardStrategy shardBy();
}
/**
 * Sharding strategies. Each constant names the statement parameter used as the shard key and
 * lists the tables whose names get rewritten to shard-specific names.
 *
 * <p>Enum constants are JVM-wide singletons, so the table list is copied on the way in and out.
 * Callers can therefore never mutate a constant's configuration through a returned array.
 *
 * @Title:
 * @Author: hangyu
 * @Date: 2019/4/15
 * @Description
 * @Version:1.0
 */
public enum TableShardStrategy {
    /** Shard key is the {@code openId} parameter; the {@code member} table is sharded. */
    OPEN_ID("openId", new String[]{"member"});

    /** Name of the statement parameter whose value selects the shard. */
    private String shardCode;
    /** Logical table names to be rewritten (private copy). */
    private String[] shardTableList;

    TableShardStrategy(String shardCode, String[] shardTableList) {
        this.shardCode = shardCode;
        this.shardTableList = copyOf(shardTableList);
    }

    /** @return the name of the shard-key parameter */
    public String getShardCode() {
        return shardCode;
    }

    /**
     * @param shardCode new shard-key parameter name
     * @deprecated mutating an enum constant changes sharding for every user in the JVM;
     *     kept only for source compatibility.
     */
    @Deprecated
    public void setShardCode(String shardCode) {
        this.shardCode = shardCode;
    }

    /** @return a copy of the table names to shard; modifying it does not affect this constant */
    public String[] getShardTableList() {
        return copyOf(shardTableList);
    }

    /**
     * @param shardTableList new table names (copied)
     * @deprecated mutating an enum constant changes sharding for every user in the JVM;
     *     kept only for source compatibility.
     */
    @Deprecated
    public void setShardTableList(String[] shardTableList) {
        this.shardTableList = copyOf(shardTableList);
    }

    /** Null-tolerant defensive copy of a table-name array. */
    private static String[] copyOf(String[] tables) {
        return tables == null ? null : tables.clone();
    }
}
- 然后在Mapper接口上添加@TableSeg注解(Member为对应的实体类)
/**
 * Member entity holding a member id and an openId; both stay {@code null} until set.
 *
 * <p>NOTE(review): no explicit serialVersionUID is declared, so the JVM derives one from the
 * class shape.
 */
public class Member implements Serializable {
    /** Member id. */
    private Long memberId;
    /** The member's openId. */
    private String openId;

    /** @return the stored member id, or {@code null} if never set */
    public Long getMemberId() {
        return this.memberId;
    }

    /** @param memberId value to store as the member id */
    public void setMemberId(Long memberId) {
        this.memberId = memberId;
    }

    /** @return the stored openId, or {@code null} if never set */
    public String getOpenId() {
        return this.openId;
    }

    /** @param openId value to store as the openId */
    public void setOpenId(String openId) {
        this.openId = openId;
    }
}
/**
 * MyBatis mapper for {@link Member} rows.
 *
 * <p>Annotated with {@code @TableSeg}, so its statements are routed to one of 100 shard tables.
 * The shard key is the parameter named by {@code TableShardStrategy.OPEN_ID} ({@code openId}).
 *
 * @Title:
 * @Author: hangyu
 * @Date: 2019/4/15
 * @Description
 * @Version:1.0
 */
@TableSeg(shardNum = 100, shardBy = TableShardStrategy.OPEN_ID)
@Repository
public interface MemberDao {
    /**
     * Inserts a member. The member's {@code openId} property supplies the shard key.
     *
     * @param member row to insert
     * @return the int result of the mapped insert statement
     */
    int insert(Member member);

    /**
     * Loads a member by openId.
     *
     * @param openId shard key and lookup value
     * @return the matching member as mapped by the statement
     */
    Member getMember(String openId);
}
- 最后介绍一种更简单的方式:直接在表名后追加当前日期(yyyyMMdd)作为后缀
/**
* @Title:
* @Auther: hangyu
* @Date: 2019/4/15
* @Description
* @Version:1.0
*/
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class})})
public class TableShareInterceptor implements Interceptor {
private Log log = LogFactory.getLog(getClass());
private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();
private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();
private final static ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();
private static final String DATE_PATTERN = "yyyyMMdd";
@Override
public Object intercept(Invocation invocation) throws Throwable {
StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
//全局操作对象
MetaObject metaObject = MetaObject.forObject(statementHandler, DEFAULT_OBJECT_FACTORY,
DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
MappedStatement mappedStatement = (MappedStatement)
metaObject.getValue("delegate.mappedStatement");
String id = mappedStatement.getId();
id = id.substring(0, id.lastIndexOf('.'));
Class clazz = Class.forName(id);
// 获取TableShard注解
TableSeg tableShard = (TableSeg)clazz.getAnnotation(TableSeg.class);
if ( tableShard != null ) {
TableShardStrategy tableShardStrategy = tableShard.shardBy();
String tableName = tableShardStrategy.getShardTableList()[0];
String newTableName = tableShard(tableName);
// 获取源sql
String sql = (String)metaObject.getValue("delegate.boundSql.sql");
// 用新sql代替旧sql, 完成所谓的sql rewrite
metaObject.setValue("delegate.boundSql.sql", sql.replaceAll(tableName, newTableName));
}
// 传递给下一个拦截器处理
return invocation.proceed();
}
public String tableShard(String tableName) {
SimpleDateFormat sdf = new SimpleDateFormat(DATE_PATTERN);
return tableName + "_" + sdf.format(new Date());
}
@Override
public Object plugin(Object o) {
return null;
}
@Override
public void setProperties(Properties properties) {
}
}