package com.csw.mybatisSpringboot.config;
import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;
import java.util.List;
/**
 * Registers a {@link DataPermissionInterceptor} with every {@link SqlSessionFactory} bean once the
 * application has started.
 */
@Component
public class InterceptRunner implements ApplicationRunner {

    /** Every SqlSessionFactory bean in the application context. */
    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    /**
     * Adds one shared interceptor instance to each factory's configuration, skipping factories that
     * already contain a DataPermissionInterceptor.
     *
     * @param args application arguments (unused)
     */
    @Override
    public void run(ApplicationArguments args) throws Exception {
        DataPermissionInterceptor mybatisInterceptor = new DataPermissionInterceptor();
        for (SqlSessionFactory sqlSessionFactory : sqlSessionFactoryList) {
            org.apache.ibatis.session.Configuration configuration = sqlSessionFactory.getConfiguration();
            // DataPermissionInterceptor is also a @Component, so MyBatis auto-configuration may already
            // have installed it as a plugin. Registering it twice would rewrite each query twice
            // (duplicated conditions, and two LIMIT clauses when isLimit = true).
            boolean alreadyRegistered = configuration.getInterceptors().stream()
                    .anyMatch(DataPermissionInterceptor.class::isInstance);
            if (!alreadyRegistered) {
                configuration.addInterceptor(mybatisInterceptor);
            }
        }
    }
}
package com.csw.mybatisSpringboot.config.quanXian;
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks a mapper method whose SELECT statements must be filtered by the current user's data
 * permissions.
 *
 * <p>Only method-level use is supported. The annotation is looked up at runtime on the mapper
 * interface method whose name matches the MyBatis statement id.
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface SqlLimit {
    /**
     * Name of the table the permission check applies to.
     *
     * <p>When it matches the query's main FROM table and that table has an alias, the permission
     * columns are qualified with the alias; otherwise they are qualified with this value (or left
     * unqualified when it is empty).
     */
    String tableName() default "";

    /**
     * Whether to page immediately by appending a {@code limit} clause ({@code true}); with
     * {@code false} paging is left to a paging plugin or a custom paging utility. Defaults to
     * {@code false}.
     */
    boolean isLimit() default false;
}
package com.csw.mybatisSpringboot.config.quanXian;
import cn.hutool.core.bean.BeanUtil;
import com.csw.mybatisSpringboot.config.PageDto;
import com.csw.mybatisSpringboot.config.exception.BusinessException;
import com.csw.mybatisSpringboot.entity.User;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
/**
 * MyBatis interceptor that adds data-permission conditions to queries issued through mapper methods
 * annotated with {@link SqlLimit}.
 *
 * <p>The current user is resolved from the HTTP request bound to the calling thread (via
 * {@code UserUtils.getUserInfo}); when no request is bound the query runs unfiltered. Only the
 * 4-argument {@code Executor.query} overload is intercepted.
 */
@Component
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
        // The 6-argument overload (with CacheKey and BoundSql) is not intercepted.
})
@Slf4j
public class DataPermissionInterceptor implements Interceptor {

    /** User type that sees every row (super administrator). */
    private static final String USER_TYPE_SUPER_ADMIN = "1";
    /** User type that sees every row of its own sub-system (sub-system administrator). */
    private static final String USER_TYPE_SUB_SYSTEM_ADMIN = "2";
    /** User type restricted according to its data-authority code ({@code dataAuthId}). */
    private static final String USER_TYPE_DATA_AUTH = "3";
    /** Data authority: restricted to the user's organisation ({@code org_id}). */
    private static final String DATA_AUTH_OWN_ORG = "1002";
    /** Data authority: restricted to the user's organisation and its sub-organisations. */
    private static final String DATA_AUTH_ORG_TREE = "1003";
    /** Data authority: restricted by {@code create_id}. */
    private static final String DATA_AUTH_CREATOR = "1004";

    /**
     * Rewrites the statement's SQL with the current user's permission conditions (and an optional
     * LIMIT clause) before letting the query proceed.
     *
     * @param invocation the intercepted {@code Executor.query} call
     * @return the query result
     * @throws Throwable anything thrown by the query, SQL parsing, or a {@code BusinessException}
     *                   when the user cannot be resolved
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement statement = (MappedStatement) args[0];
        // Check the annotation first so un-annotated queries skip all SQL work.
        SqlLimit sqlLimit = isLimit(statement);
        if (sqlLimit == null) {
            return invocation.proceed();
        }
        RequestAttributes req = RequestContextHolder.getRequestAttributes();
        if (req == null) {
            // No request bound to this thread: there is no user to filter for.
            return invocation.proceed();
        }
        HttpServletRequest request = ((ServletRequestAttributes) req).getRequest();
        User userVo;
        try {
            userVo = UserUtils.getUserInfo(request);
        } catch (Exception e) {
            // The token could not be parsed into a user.
            log.debug("Failed to resolve the current user from the request", e);
            throw new BusinessException("用户未登录");
        }
        if (userVo == null) {
            throw new BusinessException("用户未登录");
        }

        Object parameter = args[1];
        BoundSql boundSql = statement.getBoundSql(parameter);
        String originalSql = boundSql.getSql();
        String tableAlias = getTableAliasString(originalSql, sqlLimit.tableName());
        String sql = addTenantCondition(originalSql, userVo, tableAlias, parameter, sqlLimit.isLimit());
        if (sql.equals(originalSql)) {
            return invocation.proceed();
        }
        log.info("原SQL:{}, 数据权限替换后的SQL:{}", originalSql, sql);
        BoundSql newBoundSql = new BoundSql(statement.getConfiguration(), sql,
                boundSql.getParameterMappings(), boundSql.getParameterObject());
        // Dynamic SQL (e.g. <foreach>) binds values as additional parameters; without copying them
        // the new BoundSql cannot resolve those placeholders.
        copyAdditionalParameters(boundSql, newBoundSql);
        args[0] = copyFromMappedStatement(statement, new BoundSqlSqlSource(newBoundSql));
        return invocation.proceed();
    }

    /**
     * Determines the qualifier for the permission columns.
     *
     * @param originalSql the SQL to inspect
     * @param tableAlias  the table name configured in {@link SqlLimit#tableName()}
     * @return the alias of the main FROM table when its name matches {@code tableAlias}
     *         (case-insensitively, ignoring backtick/double-quote quoting); otherwise {@code tableAlias}
     * @throws JSQLParserException if the SQL cannot be parsed
     */
    private static String getTableAliasString(String originalSql, String tableAlias) throws JSQLParserException {
        Statement statement = CCJSqlParserUtil.parse(originalSql);
        if (statement instanceof Select) {
            SelectBody selectBody = ((Select) statement).getSelectBody();
            if (selectBody instanceof PlainSelect) {
                FromItem fromItem = ((PlainSelect) selectBody).getFromItem();
                if (fromItem instanceof Table) {
                    Table table = (Table) fromItem;
                    String tableName = StringUtils.strip(table.getName(), "`\"");
                    Alias alias = table.getAlias();
                    log.debug("Table Name: {}", tableName);
                    if (alias != null && tableName != null && tableName.equalsIgnoreCase(tableAlias)) {
                        return alias.getName();
                    }
                }
            }
        }
        return tableAlias;
    }

    /**
     * Builds the final SQL: the permission conditions for {@code user}, plus a LIMIT clause when
     * {@code isLimit} is set (applied for every user type, including the unrestricted one).
     */
    private String addTenantCondition(String originalSql, User user, String alias, Object parameter, boolean isLimit)
            throws JSQLParserException {
        List<String> conditions = buildPermissionConditions(user, alias);
        String sql = conditions.isEmpty() ? originalSql : injectConditions(originalSql, conditions);
        if (isLimit) {
            sql = sql + buildLimitClause(parameter);
        }
        return sql;
    }

    /**
     * Returns the SQL predicates (to be AND-ed together) that restrict {@code user}'s rows; empty
     * means unrestricted.
     */
    private static List<String> buildPermissionConditions(User user, String alias) {
        List<String> conditions = new ArrayList<>();
        Object userType = user.getUserType();
        if (USER_TYPE_SUPER_ADMIN.equals(userType)) {
            return conditions;
        }
        // NOTE(review): values are concatenated into the SQL as-is; this assumes they are numeric ids
        // from trusted, server-side user data — confirm, or quote/escape them.
        String subSystemCondition = getField(alias, "sub_system_id") + " = " + user.getSubSystemId();
        if (USER_TYPE_SUB_SYSTEM_ADMIN.equals(userType)) {
            conditions.add(subSystemCondition);
            return conditions;
        }
        if (!USER_TYPE_DATA_AUTH.equals(userType)) {
            // NOTE(review): unknown user types are not restricted (legacy behaviour); consider failing closed.
            return conditions;
        }
        // Every data-authority level is scoped to the user's sub-system; null / "1001" / unknown
        // codes get only this restriction.
        conditions.add(subSystemCondition);
        Object dataAuthId = user.getDataAuthId();
        if (DATA_AUTH_OWN_ORG.equals(dataAuthId)) {
            conditions.add(getField(alias, "org_id") + " = " + user.getOrgId());
        } else if (DATA_AUTH_ORG_TREE.equals(dataAuthId)) {
            List<Integer> orgIds = loadOrgTree(user);
            // An empty IN list is invalid SQL; match nothing instead.
            conditions.add(orgIds.isEmpty()
                    ? "1 = 0"
                    : getField(alias, "org_id") + " in (" + StringUtils.join(orgIds, ", ") + ")");
        } else if (DATA_AUTH_CREATOR.equals(dataAuthId)) {
            // NOTE(review): filters by user.getCreateId(); confirm this is the intended id.
            conditions.add(getField(alias, "create_id") + " = " + user.getCreateId());
        }
        return conditions;
    }

    /** Returns the ids of the user's organisation and its sub-organisations. */
    private static List<Integer> loadOrgTree(User user) {
        // TODO: mock data carried over from the original code — replace with a real lookup for user.
        List<Integer> orgIds = new ArrayList<>();
        orgIds.add(1);
        orgIds.add(2);
        orgIds.add(3);
        return orgIds;
    }

    /**
     * AND-s {@code conditions} into the WHERE clause of {@code sql}. The original WHERE is wrapped in
     * parentheses so an {@code OR} inside it cannot bypass the permission filter.
     */
    private static String injectConditions(String sql, List<String> conditions) throws JSQLParserException {
        String joined = String.join(" and ", conditions);
        Statement parsed = CCJSqlParserUtil.parse(sql);
        if (parsed instanceof Select) {
            SelectBody body = ((Select) parsed).getSelectBody();
            if (body instanceof PlainSelect) {
                PlainSelect plainSelect = (PlainSelect) body;
                Expression permission = CCJSqlParserUtil.parseCondExpression(joined);
                Expression where = plainSelect.getWhere();
                plainSelect.setWhere(where == null ? permission : new AndExpression(permission, new Parenthesis(where)));
                return parsed.toString();
            }
        }
        // Fallback for statements the parser-based path does not handle (e.g. UNION): plain text
        // insertion after the first "where". NOTE(review): here an OR in the existing WHERE still
        // binds looser than the inserted conditions.
        log.warn("Data-permission SQL is not a plain SELECT; falling back to text insertion: {}", sql);
        int whereIndex = sql.toLowerCase(Locale.ROOT).indexOf("where");
        if (whereIndex < 0) {
            return sql + " where " + joined;
        }
        return new StringBuilder(sql).insert(whereIndex + "where".length(), " " + joined + " and ").toString();
    }

    /**
     * Builds {@code " limit offset,size"} from the paging parameters. {@code pageNo} is 1-based, so
     * the offset is {@code (pageNo - 1) * pageSize}. Missing or non-positive values fall back to the
     * {@link PageDto} defaults.
     */
    private static String buildLimitClause(Object parameter) {
        PageDto defaults = new PageDto();
        PageDto page = resolvePageDto(parameter);
        int pageNo = positiveOrDefault(page == null ? null : page.getPageNo(), defaults.getPageNo());
        int pageSize = positiveOrDefault(page == null ? null : page.getPageSize(), defaults.getPageSize());
        long offset = (long) (pageNo - 1) * pageSize;
        return " limit " + offset + "," + pageSize;
    }

    /**
     * Finds the paging DTO: the parameter itself, or the value bound as {@code @Param("dto")};
     * {@code null} when neither is present.
     */
    private static PageDto resolvePageDto(Object parameter) {
        if (parameter == null) {
            return null;
        }
        if (parameter instanceof PageDto) {
            return (PageDto) parameter;
        }
        Map<?, ?> params = BeanUtil.beanToMap(parameter);
        Object dto = params == null ? null : params.get("dto");
        if (dto == null) {
            return null;
        }
        return dto instanceof PageDto ? (PageDto) dto : BeanUtil.copyProperties(dto, PageDto.class);
    }

    /** Returns {@code value} when it is a positive number, otherwise {@code fallback}. */
    private static int positiveOrDefault(Integer value, int fallback) {
        return value == null || value < 1 ? fallback : value;
    }

    /** Prefixes {@code field} with {@code alias + "."} when an alias is present. */
    private static String getField(String alias, String field) {
        if (StringUtils.isNoneBlank(alias)) {
            field = alias + "." + field;
        }
        return field;
    }

    /**
     * Copies the additional parameters referenced by {@code source}'s parameter mappings into
     * {@code target}.
     */
    private static void copyAdditionalParameters(BoundSql source, BoundSql target) {
        for (ParameterMapping mapping : source.getParameterMappings()) {
            String name = baseName(mapping.getProperty());
            if (source.hasAdditionalParameter(name)) {
                target.setAdditionalParameter(name, source.getAdditionalParameter(name));
            }
        }
    }

    /** Returns the leading property name of a path such as {@code a.b} or {@code a[0]}. */
    private static String baseName(String property) {
        int end = property.length();
        int dot = property.indexOf('.');
        int bracket = property.indexOf('[');
        if (dot >= 0) {
            end = dot;
        }
        if (bracket >= 0 && bracket < end) {
            end = bracket;
        }
        return property.substring(0, end);
    }

    /** Copies {@code ms} with {@code newSqlSource} as its SQL source, preserving its settings. */
    private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        builder.keyProperty(joinOrNull(ms.getKeyProperties()));
        builder.keyColumn(joinOrNull(ms.getKeyColumns()));
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        builder.resultOrdered(ms.isResultOrdered());
        builder.resultSets(joinOrNull(ms.getResultSets()));
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        return builder.build();
    }

    /** Joins {@code values} with commas; {@code null} when there is nothing to join. */
    private static String joinOrNull(String[] values) {
        return values == null || values.length == 0 ? null : String.join(",", values);
    }

    /**
     * Looks up {@link SqlLimit} on the mapper method behind the statement id
     * ({@code fully.qualified.Mapper.method}).
     *
     * @return the annotation, or {@code null} when the method is not annotated or the namespace is
     *         not a loadable class
     */
    private SqlLimit isLimit(MappedStatement mappedStatement) {
        String id = mappedStatement.getId();
        int lastDot = id.lastIndexOf('.');
        if (lastDot <= 0) {
            return null;
        }
        String className = id.substring(0, lastDot);
        String methodName = id.substring(lastDot + 1);
        try {
            Class<?> mapperClass = Class.forName(className);
            for (Method method : mapperClass.getMethods()) {
                if (method.getName().equals(methodName) && method.isAnnotationPresent(SqlLimit.class)) {
                    return method.getAnnotation(SqlLimit.class);
                }
            }
        } catch (ClassNotFoundException e) {
            // The namespace is not a class (e.g. an XML-only mapper): nothing can carry @SqlLimit.
            log.debug("No mapper class for statement {}", id);
        } catch (RuntimeException e) {
            log.warn("Failed to inspect @SqlLimit for statement {}", id, e);
        }
        return null;
    }

    /** SqlSource that always returns a pre-built BoundSql. */
    public static class BoundSqlSqlSource implements SqlSource {
        private final BoundSql boundSql;

        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }
}
package com.csw.mybatisSpringboot.config;
import lombok.Data;
/**
 * Paging request parameters (Swagger label "分页实体", i.e. "paging entity").
 *
 * <p>Lombok {@code @Data} generates the accessors, equals/hashCode and toString; the field order
 * below is the order used by the generated toString.
 */
@Data
public class PageDto {
    /** Current page number ("当前页码"); the first page is 1, and it defaults to 1. */
    private Integer pageNo = 1;
    /** Number of rows per page ("每页显示条数"); defaults to 10. */
    private Integer pageSize = 10;
}
package com.csw.mybatisSpringboot.dto;
import com.csw.mybatisSpringboot.config.PageDto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
@Data
@ApiModel("列表查询入参")
public class UserListDto extends PageDto {
@ApiModelProperty("名字")
private String name;
}
package com.csw.mybatisSpringboot.config;
import lombok.Data;
import java.util.List;
/**
 * Paged query result returned to callers ("分页查询信息传递类").
 *
 * <p>Lombok {@code @Data} generates the accessors, equals/hashCode and toString.
 *
 * @param <T> type of the rows in {@code items}
 */
@Data
public class PageResult<T> {
    /** Rows of the current page. */
    private List<T> items;
    /** Total number of matching rows. */
    private Long total;
    /** Total number of pages. */
    private Long totalPage;
    /** Current page number. */
    private Integer pageNo;
    /** Number of rows per page. */
    private Integer pageSize;
}
package com.csw.mybatisSpringboot.config.page;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
public class PageResultUtil {

    /**
     * Converts a MyBatis-Plus {@link Page} into a {@link PageResult}.
     *
     * @param param the page returned by the mapper; may be {@code null}
     * @param <T>   row type
     * @return a populated result, or an empty one (all fields {@code null}) when {@code param} is null
     */
    public static <T> PageResult<T> getPageResult(Page<T> param) {
        PageResult<T> pageResult = new PageResult<>();
        if (param == null) {
            return pageResult;
        }
        pageResult.setItems(param.getRecords());
        pageResult.setTotal(param.getTotal());
        pageResult.setTotalPage(param.getPages());
        // Page reports current/size as long; the result DTO uses Integer.
        pageResult.setPageNo((int) param.getCurrent());
        pageResult.setPageSize((int) param.getSize());
        return pageResult;
    }
}
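如果要对 mapper 的内置方法（例如 BaseMapper 提供的方法）做数据权限控制，只需在自己的 Mapper 接口中重写（覆盖）该方法，并在重写的方法上加上 @SqlLimit 注解即可。
【int index = sb.toString().toLowerCase().indexOf("where");】：定位 where 之前先把 SQL 转成小写，这样大写的 WHERE 也能被找到（这是对上面 where 定位逻辑的修改）。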
service
/**
 * Pages through the users matching {@code dto}.
 *
 * @param dto paging fields plus the name filter
 * @return the requested page wrapped in a {@link PageResult}
 */
@Override
public PageResult<User> selectAllByName(UserListDto dto) {
    // Typed page instead of a raw Page: no unchecked conversion when calling the mapper.
    Page<User> page = new Page<>(dto.getPageNo(), dto.getPageSize());
    Page<User> userList = baseMapper.selectAllByName(page, dto);
    return PageResultUtil.getPageResult(userList);
}
mapper
/**
 * Selects users matching {@code dto}, with data-permission conditions for table {@code user} added
 * by the data-permission interceptor.
 *
 * <p>The parameter must stay named {@code "dto"}: when {@code isLimit} is enabled, the interceptor
 * reads the paging values from the {@code "dto"} entry of the parameter map.
 *
 * @param page paging request
 * @param dto  paging fields plus the name filter
 * @return the requested page of users
 */
@SqlLimit(tableName = "user")
Page<User> selectAllByName(Page<User> page, @Param("dto") UserListDto dto);
【以上为来自大佬的总结和进一步优化】
mybatis拦截器实现数据权限_mybatis数据权限-CSDN博客 https://blog.csdn.net/m0_71777195/article/details/131139654
java(springboot) mybatis 数据权限详细实现(图文) - 知乎 https://zhuanlan.zhihu.com/p/516113586?utm_id=0