package com.wujie.growth.awardcenter.repository.plugin;
import com.cxyx.common.db.config.DbConfig;
import com.cxyx.common.rpc.Context;
import com.cxyx.common.util.ContextUtil;
import com.cxyx.common.util.JsonUtil;
import com.cxyx.common.util.PropertyUtil;
import com.xiaoju.apollo.message.StringUtil;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.session.defaults.DefaultSqlSession;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.springframework.stereotype.Component;

import java.io.StringReader;
import java.lang.reflect.Field;
import java.text.SimpleDateFormat;
import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.wujie.growth.awardcenter.common.constant.AwardCenterConstant.PRINT_MYSQL_LOG_THREAD_LOCAL;
/**
@author xuan
-
@create 2020-12-20 16:37
**/
@Setter
@Slf4j
@Intercepts({
@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
})
@Component
public class MybatisLogPlugin implements Interceptor {private static final CCJSqlParserManager PARSER_MANAGER = new CCJSqlParserManager();
private DbConfig dbConfig = PropertyUtil.newInstance(DbConfig.class).orElse(new DbConfig());
private static String FOREACH_PREFIX = "_frch";
/**
- 积攒过多就悄悄丢弃
*/
private static ThreadPoolExecutor pool = new ThreadPoolExecutor(2, 4, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(1000), new ThreadPoolExecutor.DiscardPolicy());
@Override
public Object intercept(Invocation invocation) throws Throwable {
long startTime = System.currentTimeMillis();
boolean errorFlag = false;
Object result = null;
try {
result = invocation.proceed();
} catch (Throwable e) {
errorFlag = true;
result = e;
throw e;
} finally {
try {
//优先读取打印日志的后门标记(不区分增删改查-都打印):不要清理printMysqlLogFlag,否则一次request请求只会打印第一条SQL
Boolean printMysqlLogFlag = PRINT_MYSQL_LOG_THREAD_LOCAL.get();
if (printMysqlLogFlag != null && printMysqlLogFlag) {
printSql(invocation, startTime, errorFlag, result);
} else if ("update".equals(invocation.getMethod().getName())) {
printSql(invocation, startTime, errorFlag, result);
}
} catch (Exception e) {
log.error("mybatis-log-plugin error", e);
}
}
return result;
} - 积攒过多就悄悄丢弃
private void printSql(Invocation invocation, long startTime, boolean errorFlag, Object result) {
Map<String, Object> contextParams = ContextUtil.getAll();
pool.execute(() -> {
ContextUtil.putAll(contextParams);
Object parameter = invocation.getArgs()[1];
BoundSql boundSql = ((MappedStatement) invocation.getArgs()[0]).getBoundSql(parameter);
String sql = boundSql.getSql();
// 格式化Sql语句,去除换行符,替换参数
sql = formatSql(sql, boundSql);
//分表时表名替换
sql = shardingTableReplaceName(sql);
if (errorFlag) {
log.error("sql=[{}]||result={}||proc_time={}", sql, JsonUtil.toString(result), System.currentTimeMillis() - startTime);
} else {
log.info("sql=[{}]||result={}||proc_time={}", sql, JsonUtil.toString(result), System.currentTimeMillis() - startTime);
}
});
}
private String shardingTableReplaceName(String sql) {
Set<String> tableNames = getTableNames(sql);
if (tableNames.size() == 1) {
final String tableName = tableNames.iterator().next();
String newTableName = tableName;
if (tableNeedSharding(tableName)) {
newTableName = setSubfixSharding(tableName);
}
//压测数据暂时不做处理
if (!org.apache.commons.lang3.StringUtils.equals(tableName, newTableName)) {
sql = replace(sql, tableName, newTableName);
}
}
return sql;
}
@Override
public Object plugin(Object target) {
return Plugin.wrap(target, this);
}
@Override
public void setProperties(Properties properties) {
}
@SuppressWarnings("unchecked")
private String formatSql(String sql, BoundSql boundSql) {
if (StringUtil.isBlank(sql)) {
return "";
}
sql = beautifySql(sql);
List<ParameterMapping> parameterMappingList = boundSql.getParameterMappings();
Object parameterObject = boundSql.getParameterObject();
if (parameterObject == null || CollectionUtils.isEmpty(parameterMappingList)) {
return sql;
}
String sqlWithoutReplacePlaceholder = sql;
try {
Class<?> parameterObjectClass = parameterObject.getClass();
// 如果参数是StrictMap且Value类型为Collection,获取key="list"的属性,这里主要是为了处理<foreach>循环时传入List这种参数的占位符替换
// 例如select * from xxx where id in <foreach collection="list">...</foreach>
if (isStrictMap(parameterObjectClass)) {
DefaultSqlSession.StrictMap<Collection<?>> strictMap = (DefaultSqlSession.StrictMap<Collection<?>>) parameterObject;
if (isList(strictMap.get("list").getClass())) {
sql = handleListParameter(sql, strictMap.get("list"));
}
} else if (isMap(parameterObjectClass)) {
// 如果参数是Map则直接强转,通过map.get(key)方法获取真正的属性值
// 这里主要是为了处理<insert>、<delete>、<update>、<select>时传入parameterType为map的场景
Map<?, ?> paramMap = (Map<?, ?>) parameterObject;
sql = handleMapParameter(sql, paramMap, parameterMappingList);
} else if (parameterObject instanceof MapperMethod.ParamMap) {
//mybatis-plus特殊处理
MapperMethod.ParamMap map = (MapperMethod.ParamMap) parameterObject;
sql = handleAbstractWrapperParameter(sql, parameterMappingList, map, boundSql);
} else {
// 通用场景,比如传的是一个自定义的对象或者八种基本数据类型之一或者String
sql = handleCommonParameter(sql, parameterMappingList, parameterObjectClass, parameterObject);
}
} catch (Exception e) {
// 占位符替换过程中出现异常,则返回没有替换过占位符但是格式美化过的sql,这样至少保证sql语句比BoundSql中的sql更好看
return sqlWithoutReplacePlaceholder;
}
return sql;
}
/**
* 处理mybatis-plus特殊的AbstractWrapper
*/
private String handleAbstractWrapperParameter(String sql, List<ParameterMapping> parameterMappingList, MapperMethod.ParamMap map, BoundSql boundSql) throws Exception {
for (ParameterMapping mapping : parameterMappingList) {
String property = mapping.getProperty();
if (StringUtils.isNotBlank(property)) {
Object valueObj;
String value;
if (property.startsWith(FOREACH_PREFIX)) { //mybatis <foreach> 批量操作
valueObj = boundSql.getAdditionalParameter(property);
} else {
//mybatis原生update操作:map里直接装fieldName
valueObj = map.get(property);
}
//"?"替换
if (valueObj == null) {
value = "null";
} else if (valueObj instanceof String) {
value = "'" + valueObj.toString() + "'";
} else if (valueObj instanceof Date) {
value = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(valueObj);
} else {
value = valueObj.toString();
}
sql = sql.replaceFirst("\\?", value);
}
}
return sql;
}
/**
* 美化Sql
*/
private String beautifySql(String sql) {
// sql = sql.replace("\n", "").replace("\t", "").replace(" ", " ").replace("( ", "(").replace(" )", ")").replace(" ,", ",");
sql = sql.replaceAll("[\\s\n ]+", " ");
return sql;
}
/**
* 处理参数为List的场景
*/
private String handleListParameter(String sql, Collection<?> col) {
if (col != null && col.size() != 0) {
for (Object obj : col) {
String value = "null";
Class<?> objClass = obj.getClass();
// 只处理基本数据类型、基本数据类型的包装类、String这三种
// 如果是复合类型也是可以的,不过复杂点且这种场景较少,写代码的时候要判断一下要拿到的是复合类型中的哪个属性
if (isPrimitiveOrPrimitiveWrapper(objClass)) {
value = obj.toString();
} else if (objClass.isAssignableFrom(String.class)) {
value = "\"" + obj.toString() + "\"";
}
sql = sql.replaceFirst("\\?", value);
}
}
return sql;
}
/**
* 处理参数为Map的场景
*/
private String handleMapParameter(String sql, Map<?, ?> paramMap, List<ParameterMapping> parameterMappingList) {
for (ParameterMapping parameterMapping : parameterMappingList) {
Object propertyName = parameterMapping.getProperty();
Object propertyValue = paramMap.get(propertyName);
if (propertyValue != null) {
if (propertyValue.getClass().isAssignableFrom(String.class)) {
propertyValue = "\"" + propertyValue + "\"";
}
sql = sql.replaceFirst("\\?", propertyValue.toString());
}
}
return sql;
}
/**
* 处理通用的场景
*/
private String handleCommonParameter(String sql, List<ParameterMapping> parameterMappingList, Class<?> parameterObjectClass, Object parameterObject) throws Exception {
for (ParameterMapping parameterMapping : parameterMappingList) {
String propertyValue;
// 基本数据类型或者基本数据类型的包装类,直接toString即可获取其真正的参数值,其余直接取paramterMapping中的property属性即可
if (isPrimitiveOrPrimitiveWrapper(parameterObjectClass)) {
propertyValue = parameterObject.toString();
} else {
String propertyName = parameterMapping.getProperty();
Field field = parameterObjectClass.getDeclaredField(propertyName);
// 要获取Field中的属性值,这里必须将私有属性的accessible设置为true
field.setAccessible(true);
propertyValue = String.valueOf(field.get(parameterObject));
if (parameterMapping.getJavaType().isAssignableFrom(String.class)) {
propertyValue = "\"" + propertyValue + "\"";
}
}
sql = sql.replaceFirst("\\?", propertyValue);
}
return sql;
}
/**
* 是否基本数据类型或者基本数据类型的包装类
*/
private boolean isPrimitiveOrPrimitiveWrapper(Class<?> parameterObjectClass) {
return parameterObjectClass.isPrimitive() ||
(parameterObjectClass.isAssignableFrom(Byte.class) || parameterObjectClass.isAssignableFrom(Short.class) ||
parameterObjectClass.isAssignableFrom(Integer.class) || parameterObjectClass.isAssignableFrom(Long.class) ||
parameterObjectClass.isAssignableFrom(Double.class) || parameterObjectClass.isAssignableFrom(Float.class) ||
parameterObjectClass.isAssignableFrom(Character.class) || parameterObjectClass.isAssignableFrom(Boolean.class));
}
/**
* 是否DefaultSqlSession的内部类StrictMap
*/
private boolean isStrictMap(Class<?> parameterObjectClass) {
return parameterObjectClass.isAssignableFrom(DefaultSqlSession.StrictMap.class);
}
/**
* 是否List的实现类
*/
private boolean isList(Class<?> clazz) {
Class<?>[] interfaceClasses = clazz.getInterfaces();
for (Class<?> interfaceClass : interfaceClasses) {
if (interfaceClass.isAssignableFrom(List.class)) {
return true;
}
}
return false;
}
/**
* 是否Map的实现类
*/
private boolean isMap(Class<?> parameterObjectClass) {
Class<?>[] interfaceClasses = parameterObjectClass.getInterfaces();
for (Class<?> interfaceClass : interfaceClasses) {
if (interfaceClass.isAssignableFrom(Map.class)) {
return true;
}
}
return false;
}
private Set<String> getTableNames(String sql) {
Set<String> set = new HashSet<>();
try {
Statement statement = PARSER_MANAGER.parse(new StringReader(sql));
TablesNamesFinder finder = new TablesNamesFinder();
List<String> list = finder.getTableList(statement);
if (list != null) {
set.addAll(list);
}
} catch (JSQLParserException e) {
//do nothing
}
return set;
}
private boolean tableNeedSharding(String tableName) {
// if (CollectionUtils.isEmpty(dbConfig.getShardingTables())) {
// return false;
// }
// return dbConfig.getShardingTables().contains(tableName);
return false;
}
private static String replace(String sql, String oldTable, String newTable) {
return sql.replaceAll("(\\b|,)" + oldTable + "(\\b|,|\\.)", "$1" + newTable + "$2");
}
private String setSubfixSharding(String tableName) {
Object contextObj = ContextUtil.get("context");
if (contextObj instanceof Context) {
Map map = ((Context) contextObj).getContext();
if (map != null) {
Object cityIdObj = map.get("cityid");
if (cityIdObj instanceof Integer) {
Integer cityId = (Integer) cityIdObj;
if (cityId != -1) {
return tableName + "_" + cityId;
}
}
}
}
return tableName;
}
}