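Next, we look at the Druid packages outside the pool package, starting this time with WallFilter. We begin by writing a WallFilter example. First, WallFilter has to be enabled in the configuration file, with the following settings: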
spring:
  datasource:
    druid:
      filter:
        wall:
          enabled: true
          config:
            select-where-alway-true-check: true
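We first enable WallFilter and then set its config. Here we set select-where-alway-true-check: true, which checks for WHERE clauses whose condition is always true (the spelling "alway" matches Druid's property name). Besides this setting, WallConfig offers a number of other properties that can be configured in the same way.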
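To make "always-true condition" concrete, here is a toy sketch. It is not Druid's implementation. Druid performs this check on the parsed syntax tree through WallVisitor. The hypothetical AlwaysTrueCheck class below only pattern-matches a single trailing WHERE condition that compares a literal with itself.

```java
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Toy illustration only (hypothetical class): Druid performs this check on the
// parsed AST via WallVisitor; here we just pattern-match a trivial WHERE clause.
class AlwaysTrueCheck {
    // Matches a trailing "where <token> = <token>", capturing both sides.
    private static final Pattern WHERE_EQ =
            Pattern.compile("\\bwhere\\s+([\\w']+)\\s*=\\s*([\\w']+)\\s*$");

    /** Returns true if the trailing WHERE condition compares a literal with itself, e.g. 1 = 1. */
    static boolean isAlwaysTrue(String sql) {
        Matcher m = WHERE_EQ.matcher(sql.trim().toLowerCase(Locale.ROOT));
        if (!m.find()) {
            return false;
        }
        String left = m.group(1);
        String right = m.group(2);
        // Only number or quoted-string literals count as constants here;
        // a column comparison such as "id = id" is not flagged by this toy version.
        boolean literal = left.matches("\\d+") || left.matches("'\\w*'");
        return literal && left.equals(right);
    }

    public static void main(String[] args) {
        System.out.println(isAlwaysTrue("select * from t where 1 = 1"));  // true
        System.out.println(isAlwaysTrue("select * from t where id = 1")); // false
    }
}
```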
Let's first test the select-where-alway-true-check: true property. Our MyBatis Mapper file contains the condition where 1 = 1. Running the test produces the following error:
java.sql.SQLException: sql injection violation, dbType mysql, druid-version 1.2.8, not terminal sql, token WHEN : select
......
from TABLES
when 1 = 1
at com.alibaba.druid.wall.WallFilter.checkInternal(WallFilter.java:859) ~[druid-1.2.8.jar:1.2.8]
at com.alibaba.druid.wall.WallFilter.connection_prepareStatement(WallFilter.java:295) ~[druid-1.2.8.jar:1.2.8]
at com.alibaba.druid.filter.FilterChainImpl.connection_prepareStatement(FilterChainImpl.java:568) ~[druid-1.2.8.jar:1.2.8]
at com.alibaba.druid.filter.FilterAdapter.connection_prepareStatement(FilterAdapter.java:930) ~[druid-1.2.8.jar:1.2.8]
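The SQL is rejected immediately with an SQL injection exception. The detail "not terminal sql, token WHEN" comes from the strict syntax check, not from the always-true check. The SQL in the log contains "when 1 = 1" rather than "where 1 = 1", so the parser stops at WHEN. Following the stack trace, we can locate where the call enters WallFilter: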
@Override
public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql)
        throws SQLException {
    return chain.connection_prepareStatement(connection, sql);
}
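The method above only forwards the call to the next link in the chain. Below is a minimal sketch of an index-based chain under a simplified API. MiniFilter, MiniChain and the string-returning prepare method are hypothetical names, not Druid's Filter/FilterChain interfaces. The Reject filter plays the role of WallFilter: it checks the SQL before passing it on.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified model of an index-based filter chain
// (not Druid's actual Filter/FilterChain API).
class MiniChainDemo {

    interface MiniFilter {
        // A filter may inspect the SQL, then must call chain.prepare to continue.
        String prepare(MiniChain chain, String sql);
    }

    static class MiniChain {
        private final List<MiniFilter> filters;
        private int pos = 0; // index of the next filter to invoke
        final List<String> trace = new ArrayList<>();

        MiniChain(List<MiniFilter> filters) {
            this.filters = filters;
        }

        String prepare(String sql) {
            if (pos < filters.size()) {
                // Hand the call to the next filter; it calls back into prepare() recursively.
                return filters.get(pos++).prepare(this, sql);
            }
            // End of the chain: this is where the real driver call would happen.
            trace.add("driver:" + sql);
            return "stmt(" + sql + ")";
        }
    }

    // A pass-through filter: it only forwards the call, like the method shown above.
    static class PassThrough implements MiniFilter {
        public String prepare(MiniChain chain, String sql) {
            chain.trace.add("passThrough");
            return chain.prepare(sql);
        }
    }

    // A checking filter in the spirit of WallFilter: it may reject before forwarding.
    static class Reject implements MiniFilter {
        public String prepare(MiniChain chain, String sql) {
            chain.trace.add("check");
            if (sql.contains("1 = 1")) {
                throw new IllegalArgumentException("violation: " + sql);
            }
            return chain.prepare(sql);
        }
    }

    public static void main(String[] args) {
        MiniChain chain = new MiniChain(List.of(new PassThrough(), new Reject()));
        System.out.println(chain.prepare("select 1")); // stmt(select 1)
        System.out.println(chain.trace);               // [passThrough, check, driver:select 1]
    }
}
```

The chain keeps a cursor, so in this sketch each call needs a fresh chain instance.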
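As covered earlier, this is the Chain of Responsibility pattern. All Filters are loaded first, and then each Filter is invoked in turn through the chain, recursively. Next, let's look at WallFilter's implementation of this method: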
@Override
public PreparedStatementProxy connection_prepareStatement(FilterChain chain, ConnectionProxy connection, String sql)
        throws SQLException {
    String dbType = connection.getDirectDataSource().getDbType();
    // bind a WallContext carrying the dbType
    WallContext context = WallContext.create(dbType);
    try {
        // run the wall check; this is where the SQLException in the log above was thrown
        WallCheckResult result = checkInternal(sql);
        context.setWallUpdateCheckItems(result.getUpdateCheckItems());
        // the check may have rewritten the SQL
        sql = result.getSql();
        // continue down the filter chain with the checked SQL
        PreparedStatementProxy stmt = chain.connection_prepareStatement(connection, sql);
        setSqlStatAttribute(stmt);
        return stmt;
    } finally {
        // always clear the context, even if the check throws
        WallContext.clearContext();
    }
}
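WallFilter creates the WallContext before the check and clears it in a finally block. The static create/current/clearContext calls suggest the context is held per thread. Assuming that, the pattern looks roughly like the sketch below. MiniWallContext is a hypothetical stand-in, not Druid's class.

```java
// Hypothetical per-thread context, modelled on the create/current/clearContext
// calls seen in WallFilter; not Druid's actual WallContext class.
class MiniWallContext {
    private static final ThreadLocal<MiniWallContext> LOCAL = new ThreadLocal<>();

    private final String dbType;

    private MiniWallContext(String dbType) {
        this.dbType = dbType;
    }

    String getDbType() {
        return dbType;
    }

    // Create a context and make it current for this thread.
    static MiniWallContext create(String dbType) {
        MiniWallContext ctx = new MiniWallContext(dbType);
        LOCAL.set(ctx);
        return ctx;
    }

    static MiniWallContext current() {
        return LOCAL.get();
    }

    // Remove the context so a reused (pooled) thread does not carry stale state.
    static void clearContext() {
        LOCAL.remove();
    }

    // Usage mirroring connection_prepareStatement: set up, do the work, always clean up.
    static String check(String dbType, String sql) {
        create(dbType);
        try {
            // Code deeper in the call stack can read the context without it being passed in.
            return current().getDbType() + ":" + sql;
        } finally {
            clearContext();
        }
    }

    public static void main(String[] args) {
        System.out.println(check("mysql", "select 1")); // mysql:select 1
        System.out.println(current());                  // null
    }
}
```

The finally block matters because pool threads are reused. Without it, a context left over from one statement could be seen by the next.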
The method first creates a WallContext from dbType. Nothing complicated happens here: it mainly stores dbType in the WallContext. It then calls checkInternal:
private WallCheckResult checkInternal(String sql) throws SQLException {
    WallCheckResult checkResult = provider.check(sql);
    List<Violation> violations = checkResult.getViolations();
    if (violations.size() > 0) {
        ......
    }
    return checkResult;
}
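The body of the if branch is elided above. In the stack trace, WallFilter.checkInternal raises an SQLException reporting "sql injection violation", which suggests that this branch turns the violations into an exception. The hypothetical ViolationGate below sketches that general pattern only. The throwException switch and the message layout are assumptions, not Druid's code.

```java
import java.sql.SQLException;
import java.util.List;

// Hypothetical sketch, not Druid's code: turn collected violation messages into an
// SQLException, or merely report them when the (assumed) throwException switch is off.
class ViolationGate {

    static void enforce(String dbType, String sql, List<String> violations,
                        boolean throwException) throws SQLException {
        if (violations.isEmpty()) {
            return; // nothing found: the SQL continues down the chain
        }
        // Use the first violation as the detail message.
        String detail = violations.get(0);
        if (throwException) {
            throw new SQLException("sql injection violation, dbType " + dbType + ", " + detail + " : " + sql);
        }
        // Report-only mode; a real filter would use a logger here.
        System.err.println("violation: " + detail + " : " + sql);
    }

    public static void main(String[] args) {
        try {
            enforce("mysql", "select 1 from t when 1 = 1",
                    List.of("not terminal sql, token WHEN"), true);
        } catch (SQLException e) {
            System.out.println(e.getMessage());
        }
    }
}
```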
This mainly delegates the check to provider, which is initialized in WallFilter's init method. The relevant part of init is:
case mysql:
case oceanbase:
case drds:
case mariadb:
case h2:
case presto:
case trino:
    if (config == null) {
        config = new WallConfig(MySqlWallProvider.DEFAULT_CONFIG_DIR);
    }
    provider = new MySqlWallProvider(config);
    break;
...
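The switch maps several database types onto the same MySqlWallProvider. The same dispatch can be written as a lookup table. In the hypothetical ProviderSelector below, strings stand in for the provider classes and only the MySQL-family entries from the snippet are included. Rejecting unknown types is an assumption, since the other branches of the switch are elided.

```java
import java.util.Locale;
import java.util.Map;

// Hypothetical table-driven version of the dbType switch in WallFilter.init.
// Strings stand in for the real provider classes; only the MySQL family is listed.
class ProviderSelector {
    private static final String MYSQL = "MySqlWallProvider";

    // Several dbType values share one provider, like the fall-through cases of the switch.
    private static final Map<String, String> PROVIDERS = Map.of(
            "mysql", MYSQL,
            "oceanbase", MYSQL,
            "drds", MYSQL,
            "mariadb", MYSQL,
            "h2", MYSQL,
            "presto", MYSQL,
            "trino", MYSQL);

    static String select(String dbType) {
        String provider = PROVIDERS.get(dbType.toLowerCase(Locale.ROOT));
        if (provider == null) {
            // Assumption for this sketch: unknown types fail fast.
            throw new IllegalArgumentException("unsupported dbType: " + dbType);
        }
        return provider;
    }

    public static void main(String[] args) {
        System.out.println(select("mariadb")); // MySqlWallProvider
    }
}
```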
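The config passed to the MySqlWallProvider constructor is the WallFilter config we set up earlier. If none was set, a default WallConfig is loaded from MySqlWallProvider.DEFAULT_CONFIG_DIR. Now let's look at the concrete check logic in the provider: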
private WallCheckResult checkInternal(String sql) {
    checkCount.incrementAndGet();
    WallContext context = WallContext.current();
    // privileged calls skip the check when doPrivilegedAllow is enabled
    if (config.isDoPrivilegedAllow() && ispPrivileged()) {
        WallCheckResult checkResult = new WallCheckResult();
        checkResult.setSql(sql);
        return checkResult;
    }
    // first step, check whiteList
    boolean mulltiTenant = config.getTenantTablePattern() != null && config.getTenantTablePattern().length() > 0;
    if (!mulltiTenant) {
        WallCheckResult checkResult = checkWhiteAndBlackList(sql);
        if (checkResult != null) {
            checkResult.setSql(sql);
            return checkResult;
        }
    }
    // not found in either list: do the full ("hard") check
    hardCheckCount.incrementAndGet();
    final List<Violation> violations = new ArrayList<Violation>();
    List<SQLStatement> statementList = new ArrayList<SQLStatement>();
    boolean syntaxError = false;
    boolean endOfComment = false;
    try {
        // second step, parse the SQL (it may contain several statements)
        SQLStatementParser parser = createParser(sql);
        parser.getLexer().setCommentHandler(WallCommentHandler.instance);
        if (!config.isCommentAllow()) {
            parser.getLexer().setAllowComment(false); // deny comment
        }
        if (!config.isCompleteInsertValuesCheck()) {
            parser.setParseCompleteValues(false);
            parser.setParseValuesSize(config.getInsertValuesCheckSize());
        }
        parser.parseStatementList(statementList);
        // a token other than EOF left over means the SQL did not end where the parser expected
        final Token lastToken = parser.getLexer().token();
        if (lastToken != Token.EOF && config.isStrictSyntaxCheck()) {
            violations.add(new IllegalSQLObjectViolation(ErrorCode.SYNTAX_ERROR, "not terminal sql, token "
                    + lastToken, sql));
        }
        endOfComment = parser.getLexer().isEndOfComment();
    } catch (NotAllowCommentException e) {
        violations.add(new IllegalSQLObjectViolation(ErrorCode.COMMENT_STATEMENT_NOT_ALLOW, "comment not allow", sql));
        incrementCommentDeniedCount();
    } catch (ParserException e) {
        syntaxErrorCount.incrementAndGet();
        syntaxError = true;
        if (config.isStrictSyntaxCheck()) {
            violations.add(new SyntaxErrorViolation(e, sql));
        }
    } catch (Exception e) {
        if (config.isStrictSyntaxCheck()) {
            violations.add(new SyntaxErrorViolation(e, sql));
        }
    }
    // multiple statements are rejected unless multiStatementAllow is set
    if (statementList.size() > 1 && !config.isMultiStatementAllow()) {
        violations.add(new IllegalSQLObjectViolation(ErrorCode.MULTI_STATEMENT, "multi-statement not allow", sql));
    }
    // third step, walk every statement with the WallVisitor (leading MySQL hint statements are skipped)
    WallVisitor visitor = createWallVisitor();
    visitor.setSqlEndOfComment(endOfComment);
    if (statementList.size() > 0) {
        boolean lastIsHint = false;
        for (int i = 0; i < statementList.size(); i++) {
            SQLStatement stmt = statementList.get(i);
            if ((i == 0 || lastIsHint) && stmt instanceof MySqlHintStatement) {
                lastIsHint = true;
                continue;
            }
            try {
                stmt.accept(visitor);
            } catch (ParserException e) {
                violations.add(new SyntaxErrorViolation(e, sql));
            }
        }
    }
    if (visitor.getViolations().size() > 0) {
        violations.addAll(visitor.getViolations());
    }
    Map<String, WallSqlTableStat> tableStat = context.getTableStats();
    // is an update-check handler configured for a table this SQL updates?
    boolean updateCheckHandlerEnable = false;
    {
        WallUpdateCheckHandler updateCheckHandler = config.getUpdateCheckHandler();
        if (updateCheckHandler != null) {
            for (SQLStatement stmt : statementList) {
                if (stmt instanceof SQLUpdateStatement) {
                    SQLUpdateStatement updateStmt = (SQLUpdateStatement) stmt;
                    SQLName table = updateStmt.getTableName();
                    if (table != null) {
                        String tableName = table.getSimpleName();
                        Set<String> updateCheckColumns = config.getUpdateCheckTable(tableName);
                        if (updateCheckColumns != null && updateCheckColumns.size() > 0) {
                            updateCheckHandlerEnable = true;
                            break;
                        }
                    }
                }
            }
        }
    }
    // cache the verdict: black list on violations, otherwise white list
    WallSqlStat sqlStat = null;
    if (violations.size() > 0) {
        violationCount.incrementAndGet();
        if ((!updateCheckHandlerEnable) && sql.length() < MAX_SQL_LENGTH) {
            sqlStat = addBlackSql(sql, tableStat, context.getFunctionStats(), violations, syntaxError);
        }
    } else {
        if ((!updateCheckHandlerEnable) && sql.length() < MAX_SQL_LENGTH) {
            boolean selectLimit = false;
            if (config.getSelectLimit() > 0) {
                for (SQLStatement stmt : statementList) {
                    if (stmt instanceof SQLSelectStatement) {
                        selectLimit = true;
                        break;
                    }
                }
            }
            if (!selectLimit) {
                sqlStat = addWhiteSql(sql, tableStat, context.getFunctionStats(), syntaxError);
            }
        }
    }
    if (sqlStat == null && updateCheckHandlerEnable) {
        sqlStat = new WallSqlStat(tableStat, context.getFunctionStats(), violations, syntaxError);
    }
    Map<String, WallSqlTableStat> tableStats = null;
    Map<String, WallSqlFunctionStat> functionStats = null;
    if (context != null) {
        tableStats = context.getTableStats();
        functionStats = context.getFunctionStats();
        recordStats(tableStats, functionStats);
    }
    WallCheckResult result;
    if (sqlStat != null) {
        context.setSqlStat(sqlStat);
        result = new WallCheckResult(sqlStat, statementList);
    } else {
        result = new WallCheckResult(null, violations, tableStats, functionStats, statementList, syntaxError);
    }
    // if the visitor rewrote the statements, regenerate the SQL text from them
    String resultSql;
    if (visitor.isSqlModified()) {
        resultSql = SQLUtils.toSQLString(statementList, dbType);
    } else {
        resultSql = sql;
    }
    result.setSql(resultSql);
    result.setUpdateCheckItems(visitor.getUpdateCheckItems());
    return result;
}
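It mainly does the following:
1. Unless multi-tenancy is configured, it looks the SQL up in the white and black lists, which cache the results of earlier checks. If the SQL is found there, the cached result is returned directly.
2. It parses the SQL into a list of SQLStatement objects. The result is a list because the input may contain several statements. Parse errors are recorded as violations, and so are leftover tokens (such as the WHEN in our example) when strict syntax checking is enabled. More than one statement is also a violation unless multiStatementAllow is set.
3. It calls accept on each SQLStatement, passing in the WallVisitor that the provider creates from its config. The visitor collects rule violations, such as an always-true WHERE condition. A ParserException thrown while visiting is recorded in the result as a syntax-error violation.
4. It caches the SQL in the black list if there are violations, or in the white list otherwise (subject to conditions such as the SQL length), and then builds the WallCheckResult. If the visitor modified a statement, the SQL is regenerated from the statement list.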
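The steps above can be traced in a heavily simplified toy pipeline. Everything in it is hypothetical: a set and a map stand in for the white and black lists, splitting on ';' stands in for the SQL parser, and a list of rule functions stands in for WallVisitor.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

// Toy model of the check flow; every name here is hypothetical. Real Druid parses the
// SQL into an AST and walks it with a WallVisitor instead of using string rules.
class ToyWallProvider {
    private final Set<String> whiteList = new HashSet<>();               // SQL already judged safe
    private final Map<String, List<String>> blackList = new HashMap<>(); // SQL -> cached violations
    private final boolean multiStatementAllow;
    // Each rule returns a violation message, or null when the statement passes.
    private final List<Function<String, String>> rules = new ArrayList<>();

    ToyWallProvider(boolean multiStatementAllow) {
        this.multiStatementAllow = multiStatementAllow;
        rules.add(stmt -> stmt.toLowerCase(Locale.ROOT).matches("(?s).*\\bwhere\\s+1\\s*=\\s*1\\b.*")
                ? "always-true where condition" : null);
    }

    List<String> check(String sql) {
        // Step 1: reuse the verdict of an earlier check if there is one.
        if (whiteList.contains(sql)) {
            return List.of();
        }
        List<String> cached = blackList.get(sql);
        if (cached != null) {
            return cached;
        }
        // Step 2: split into statements (stand-in for the real parser).
        List<String> statements = new ArrayList<>();
        for (String part : sql.split(";")) {
            if (!part.isBlank()) {
                statements.add(part.trim());
            }
        }
        List<String> violations = new ArrayList<>();
        if (statements.size() > 1 && !multiStatementAllow) {
            violations.add("multi-statement not allow");
        }
        // Step 3: apply every rule to every statement (stand-in for stmt.accept(visitor)).
        for (String stmt : statements) {
            for (Function<String, String> rule : rules) {
                String violation = rule.apply(stmt);
                if (violation != null) {
                    violations.add(violation);
                }
            }
        }
        // Step 4: cache the verdict so an identical SQL string skips the full check next time.
        if (violations.isEmpty()) {
            whiteList.add(sql);
        } else {
            blackList.put(sql, violations);
        }
        return violations;
    }

    public static void main(String[] args) {
        ToyWallProvider provider = new ToyWallProvider(false);
        System.out.println(provider.check("select * from t where id = 1")); // []
        System.out.println(provider.check("select * from t where 1 = 1"));  // [always-true where condition]
        System.out.println(provider.check("select 1; delete from t"));      // [multi-statement not allow]
    }
}
```

Caching by the exact SQL string is what makes step 1 cheap. It is also why, in this sketch, the same query with different literal values is checked again from scratch.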