背景
获取到完整SQL后,需要解析SQL,提取其中涉及的表(tables),并判断是否存在 select *。
实现
引入依赖:
<dependency>
<groupId>org.apache.calcite</groupId>
<artifactId>calcite-core</artifactId>
<version>1.30.0</version>
</dependency>
SQL经过calcite解析之后,得到一棵抽象语法树,也就是我们说的AST,这棵语法树是由不同的节点组成,节点称之为SqlNode,根据不同类型的dml、ddl得到不同的类型的SqlNode,例如select语句转换为SqlSelect,delete语句转换为SqlDelete,join语句转换为SqlJoin。
一个select语句包含from部分、where部分、select部分等,每一部分都表示一个SqlNode。SqlKind是一个枚举类型,用于标识SqlNode的种类,例如SELECT、IDENTIFIER、LITERAL等,分别对应SqlSelect、SqlIdentifier、SqlLiteral等节点类。SqlIdentifier表示标识符,例如表名称、字段名;SqlLiteral表示字面常量,例如具体的数字、字符串。
SqlNode.png
SQL构建成SqlNode节点:
/**
 * Kinds of data source whose SQL this tutorial parses.
 *
 * <p>{@code SqlUtils.parseSql} uses the value to choose a parser configuration.
 */
public enum ApiDatasourceType {
    CLICKHOUSE,
    MYSQL
}
@Slf4j
public class SqlUtils {

    /**
     * Parses a SQL query string into a Calcite syntax tree.
     *
     * <p>{@link ApiDatasourceType#MYSQL} uses {@code SqlParserConfig.mysqlConfig()}; every other
     * type (including {@link ApiDatasourceType#CLICKHOUSE}) uses
     * {@code SqlParserConfig.defaultConfig()}.
     *
     * @param type data source type that selects the parser configuration
     * @param sql  SQL text to parse
     * @return root node of the parsed query
     * @throws IllegalStateException if the SQL cannot be parsed; the original
     *                               {@code SqlParseException} is kept as the cause
     */
    public static SqlNode parseSql(ApiDatasourceType type, String sql) {
        SqlParser parser;
        switch (type) {
            case MYSQL:
                parser = SqlParser.create(sql, SqlParserConfig.mysqlConfig());
                break;
            case CLICKHOUSE:
            default:
                parser = SqlParser.create(sql, SqlParserConfig.defaultConfig());
                break;
        }
        try {
            return parser.parseQuery();
        } catch (SqlParseException e) {
            // SqlParseException is checked and this method declares no throws clause, so it
            // must be wrapped. Only the first line of the parser message goes into the text;
            // the full exception stays available as the cause.
            String reason = e.getMessage() == null ? "" : e.getMessage().split("\n")[0];
            throw new IllegalStateException(String.format("解析SQL「%s」失败,失败原因是「%s」", sql, reason), e);
        }
    }
}
用法一:获取表名、判断是否存在select *
public class SqlUtils {

    /**
     * Reports whether a parsed query contains a star projection ({@code *} or {@code t.*}).
     *
     * <p>The search covers the select list, the FROM clause (aliases, joins, derived tables),
     * sub-queries reachable through the WHERE clause, queries wrapped by ORDER BY / LIMIT, and
     * the branches of UNION / INTERSECT / EXCEPT. A star used as a function argument, such as
     * {@code count(*)}, is not reported because function calls are not descended into.
     *
     * @param node root of the (sub)tree to inspect; must not be null
     * @return {@code true} if a star projection is found
     * @throws IllegalStateException if {@code node} is null
     */
    public static boolean hasSelectAll(SqlNode node) {
        if (Objects.isNull(node)) {
            throw new IllegalStateException("SqlNode is null");
        }
        SqlKind kind = node.getKind();
        if (kind == SqlKind.IDENTIFIER) {
            String name = node.toString();
            // A qualified star ("t.*") selects every column of t, so it counts as well.
            return Objects.equals(name, "*") || name.endsWith(".*");
        }
        if (kind == SqlKind.AS) {
            return hasSelectAll(((SqlBasicCall) node).getOperandList().get(0));
        }
        if (kind == SqlKind.SELECT) {
            SqlSelect select = (SqlSelect) node;
            return select.getSelectList().stream().anyMatch(SqlUtils::hasSelectAll)
                // FROM is absent for statements such as "select 1".
                || (Objects.nonNull(select.getFrom()) && hasSelectAll(select.getFrom()))
                || (Objects.nonNull(select.getWhere()) && hasSelectAllInWhere(select.getWhere()));
        }
        if (kind == SqlKind.ORDER_BY) {
            // Operand 0 is the wrapped query. Recursing (rather than casting to SqlSelect)
            // also covers "... UNION ... ORDER BY" and the wrapped query's WHERE clause.
            return hasSelectAll(((SqlOrderBy) node).getOperandList().get(0));
        }
        if (kind == SqlKind.JOIN) {
            SqlJoin joinNode = (SqlJoin) node;
            return hasSelectAll(joinNode.getLeft()) || hasSelectAll(joinNode.getRight());
        }
        if (isSetOperation(kind)) {
            return ((SqlBasicCall) node).getOperandList()
                .stream()
                .anyMatch(SqlUtils::hasSelectAll);
        }
        return false;
    }

    /**
     * Collects the names of the tables referenced by a parsed query.
     *
     * @param node root of the parsed query; must not be null
     * @return a new mutable set holding the table names as written in the SQL
     * @throws IllegalStateException if {@code node} is null or the FROM tree contains a node
     *                               kind this method does not support
     */
    public static Set<String> findTables(SqlNode node) {
        Set<String> tables = new HashSet<>();
        findTables(tables, node);
        return tables;
    }

    /** Recursive worker for {@link #findTables(SqlNode)}; adds every table found to {@code foundTables}. */
    private static void findTables(Set<String> foundTables, SqlNode node) {
        if (Objects.isNull(node)) {
            throw new IllegalStateException("SqlNode is null");
        }
        SqlKind kind = node.getKind();
        if (kind == SqlKind.IDENTIFIER) {
            // This method only descends through FROM positions, so an identifier is a table name.
            foundTables.add(node.toString());
        } else if (kind == SqlKind.AS) {
            findTables(foundTables, ((SqlBasicCall) node).getOperandList().get(0));
        } else if (kind == SqlKind.SELECT) {
            SqlSelect select = (SqlSelect) node;
            // FROM is absent for statements such as "select 1".
            if (Objects.nonNull(select.getFrom())) {
                findTables(foundTables, select.getFrom());
            }
            if (Objects.nonNull(select.getWhere())) {
                findTablesInWhere(foundTables, select.getWhere());
            }
        } else if (kind == SqlKind.ORDER_BY) {
            // Recurse into the wrapped query so its WHERE clause and set operations are covered.
            findTables(foundTables, ((SqlOrderBy) node).getOperandList().get(0));
        } else if (kind == SqlKind.JOIN) {
            SqlJoin joinNode = (SqlJoin) node;
            findTables(foundTables, joinNode.getLeft());
            findTables(foundTables, joinNode.getRight());
        } else if (isSetOperation(kind)) {
            ((SqlBasicCall) node).getOperandList().forEach(n -> findTables(foundTables, n));
        } else {
            throw new IllegalStateException(String.format("Un support node type %s", node.getKind().toString()));
        }
    }

    /** Looks for star projections inside sub-queries nested in a WHERE expression. */
    private static boolean hasSelectAllInWhere(SqlNode node) {
        if (Objects.isNull(node)) {
            // Operand lists may contain null entries.
            return false;
        }
        if (node.getKind() == SqlKind.SELECT || node.getKind() == SqlKind.ORDER_BY) {
            return hasSelectAll(node);
        }
        if (node instanceof SqlBasicCall) {
            return ((SqlBasicCall) node).getOperandList()
                .stream()
                .anyMatch(SqlUtils::hasSelectAllInWhere);
        }
        return false;
    }

    /** Collects tables referenced by sub-queries nested in a WHERE expression. */
    private static void findTablesInWhere(Set<String> foundTables, SqlNode node) {
        if (Objects.isNull(node)) {
            // Operand lists may contain null entries.
            return;
        }
        if (node.getKind() == SqlKind.SELECT || node.getKind() == SqlKind.ORDER_BY) {
            findTables(foundTables, node);
            return;
        }
        if (node instanceof SqlBasicCall) {
            ((SqlBasicCall) node).getOperandList().forEach(n -> findTablesInWhere(foundTables, n));
        }
    }

    /** Returns whether {@code kind} is UNION, INTERSECT or EXCEPT. */
    private static boolean isSetOperation(SqlKind kind) {
        return kind == SqlKind.UNION || kind == SqlKind.INTERSECT || kind == SqlKind.EXCEPT;
    }
}
测试方法:
/**
 * Demo: parses a query that has a derived table, then prints the tables it references and
 * whether it contains a star projection.
 */
public static void main(String[] args) {
    String sql = "select count(*) cnt from ( select * from tableA where event_time >= 1713369600 and event_time <= 1713369600 limit 1 ) a";
    // Qualified like the two calls below, so the snippet also compiles outside SqlUtils.
    SqlNode sqlNode = SqlUtils.parseSql(ApiDatasourceType.MYSQL, sql);
    System.out.println(SqlUtils.findTables(sqlNode));
    System.out.println(SqlUtils.hasSelectAll(sqlNode));
}
用法二:解析获取select count(*) 查询总数的SQL
@Slf4j
public class MybatisGenerator {

    /**
     * Select-list item for the total-count query, built from the project's
     * {@code SqlSelectBuilder} helpers (a {@code count} function over a star, aliased {@code count}).
     */
    private static final SqlNode COUNT_SQL =
        SqlSelectBuilder.as(SqlSelectBuilder.function("count", SqlSelectBuilder.star()), SqlSelectBuilder.identifier("count"));

    /**
     * Demo: derives a total-count query from a paged query by wrapping it as the FROM of a
     * new select whose only column is {@link #COUNT_SQL}, then prints the result.
     *
     * @throws IllegalStateException if the parsed statement is neither a plain SELECT nor a
     *                               query wrapped by ORDER BY / LIMIT
     */
    public static void main(String[] args) {
        String dsl =
            "SELECT * FROM tableA as ta LEFT JOIN tableB as tb ON ta.column_name=tb.column_name where ta.event_time >= 1713369600 and ta.event_time <= 1713369600 order by ta.id desc limit 100 ";
        System.out.println(dsl);
        // Parse the statement into a syntax tree.
        SqlNode node = SqlUtils.parseSql(ApiDatasourceType.CLICKHOUSE, dsl);
        SqlSelect totalSqlNode;
        if (node.getKind() == SqlKind.ORDER_BY) {
            // ORDER BY / LIMIT live on the wrapper node; operand 0 is the query without them,
            // so using it as the FROM drops ordering and paging.
            totalSqlNode = SqlSelectBuilder.builder().select(COUNT_SQL).from(((SqlOrderBy) node).getOperandList().get(0)).buildSelect();
        } else if (node.getKind() == SqlKind.SELECT) {
            totalSqlNode = SqlSelectBuilder.builder().select(COUNT_SQL)
                .from(SqlSelectBuilder.builder((SqlSelect) node)
                    .clearFetch()
                    .clearOffset()
                    .clearOrderBy()
                    .buildSelect())
                .buildSelect();
        } else {
            throw new IllegalStateException("错误的SQL类型");
        }
        // Print the generated count SQL on one line; previously the result was built but never shown.
        System.out.println(totalSqlNode.toString().replace('\n', ' '));
    }
}
测试运行:
SELECT COUNT(*) AS `count` FROM (SELECT * FROM `tableA` AS `ta` LEFT JOIN `tableB` AS `tb` ON `ta`.`column_name` = `tb`.`column_name` WHERE `ta`.`event_time` >= 1713369600 AND `ta`.`event_time` <= 1713369600)
用法三:替换备用表名
@Slf4j
public class MybatisGenerator {

    /** Maps a table name to the backup table that should be queried instead. */
    public static Map<String, String> BAK_TABLES = new HashMap<>();

    // Dual-table backup: register the replacement table names.
    static {
        BAK_TABLES.put("tableA", "tableA_bak");
        BAK_TABLES.put("tableB", "tableB_bak");
    }

    /**
     * Rewrites the table names in the FROM tree of a parsed query.
     *
     * @param node    root of the parsed query; may be null
     * @param context supplies the physical (shard) table mapping
     * @return the rewritten tree, or null when {@code node} is null
     */
    private static SqlNode rewriteTableName(SqlNode node, ProcessContext context) {
        if (Objects.isNull(node)) {
            return null;
        }
        return analysisSqlNode(node, context);
    }

    /**
     * Walks SELECT / ORDER BY / JOIN / AS / identifier nodes and replaces table identifiers.
     * Nodes of any other kind are returned unchanged. JOIN nodes are updated in place.
     */
    private static SqlNode analysisSqlNode(SqlNode node, ProcessContext context) {
        if (Objects.isNull(node)) {
            // Reached for a SELECT without FROM (e.g. "select 1"); nothing to rewrite.
            return null;
        }
        SqlNode resNode = node;
        if (SqlKind.SELECT == node.getKind()) {
            resNode = rewriteFrom((SqlSelect) node, context);
        } else if (SqlKind.ORDER_BY == node.getKind()) {
            SqlOrderBy orderBy = (SqlOrderBy) node;
            if (SqlKind.SELECT == orderBy.query.getKind()) {
                SqlSelect sqlSelect = rewriteFrom((SqlSelect) orderBy.query, context);
                resNode = SqlOrderByBuilder.builder(orderBy).query(sqlSelect).build();
            }
        } else if (SqlKind.JOIN == node.getKind()) {
            SqlJoin join = (SqlJoin) node;
            SqlNode sqlNodeLeft = analysisSqlNode(join.getLeft(), context);
            SqlNode sqlNodeRight = analysisSqlNode(join.getRight(), context);
            join.setLeft(sqlNodeLeft);
            join.setRight(sqlNodeRight);
            resNode = join;
        } else if (SqlKind.AS == node.getKind()) {
            SqlBasicCall as = (SqlBasicCall) node;
            SqlNode sqlNode = as.getOperandList().get(0);
            if (SqlKind.IDENTIFIER != sqlNode.getKind()) {
                // Aliased sub-query or join: rewrite the inner node and keep the alias.
                SqlNode sqlNode1 = analysisSqlNode(sqlNode, context);
                resNode = new SqlBasicCall(as.getOperator(), Lists.newArrayList(sqlNode1, as.getOperandList().get(1)),
                    SqlParserPos.ZERO);
            } else {
                resNode = replaceTableNameConvert(node, context);
            }
        } else if (SqlKind.IDENTIFIER == node.getKind()) {
            resNode = replaceTableNameConvert(node, context);
        }
        return resNode;
    }

    /**
     * Returns a copy of {@code select} whose FROM clause has been rewritten. A select without
     * FROM is returned as is. JOINs are handled by the JOIN branch of {@link #analysisSqlNode}.
     */
    private static SqlSelect rewriteFrom(SqlSelect select, ProcessContext context) {
        if (Objects.isNull(select.getFrom())) {
            return select;
        }
        return SqlSelectBuilder.builder(select)
            .from(analysisSqlNode(select.getFrom(), context))
            .buildSelect();
    }

    /**
     * Replaces a table reference with its current shard table or its backup table.
     *
     * <p>For a bare identifier the original name becomes the alias of the replacement; for an
     * {@code AS} node the existing alias is kept. The node is returned unchanged when no
     * replacement is registered.
     */
    private static SqlNode replaceTableNameConvert(SqlNode sourceNode, ProcessContext context) {
        if (Objects.isNull(sourceNode)) {
            return sourceNode;
        }
        if (SqlKind.IDENTIFIER == sourceNode.getKind()) {
            String oldTableName = sourceNode.toString();
            // NOTE(review): for a qualified name the alias is the full "db.table" string,
            // as in the original code — confirm that is the intended alias.
            return buildReplacement(sourceNode, oldTableName, oldTableName, context);
        }
        if (SqlKind.AS == sourceNode.getKind()) {
            SqlBasicCall as = (SqlBasicCall) sourceNode;
            String oldAsTableName = as.getOperandList().get(1).toString();
            return buildReplacement(sourceNode, as.getOperandList().get(0).toString(), oldAsTableName, context);
        }
        return sourceNode;
    }

    /**
     * Builds {@code <replacement> AS <alias>} for {@code qualifiedName}, or returns
     * {@code original} when neither a shard table nor a backup table applies.
     */
    private static SqlNode buildReplacement(SqlNode original, String qualifiedName, String alias,
                                            ProcessContext context) {
        // Two cases: the name is either bare ("table") or qualified ("db.table"). The last
        // component is the table; everything before it, dot included, is kept as the prefix.
        int lastDot = qualifiedName.lastIndexOf('.');
        String database = lastDot < 0 ? "" : qualifiedName.substring(0, lastDot + 1);
        String oldTable = lastDot < 0 ? qualifiedName : qualifiedName.substring(lastDot + 1);
        PhysicalTable physicalTable = context.getPhysicalTableMap().get(oldTable);
        if (physicalTable != null && StringUtils.isNotBlank(physicalTable.getCurrentShardTableName())) {
            return SqlSelectBuilder.as(SqlSelectBuilder.identifier(database + physicalTable.getCurrentShardTableName()),
                SqlSelectBuilder.identifier(alias));
        }
        if (BAK_TABLES.containsKey(oldTable)) {
            return SqlSelectBuilder.as(SqlSelectBuilder.identifier(database + BAK_TABLES.get(oldTable)),
                SqlSelectBuilder.identifier(alias));
        }
        return original;
    }

    /** Carries the table-name to physical-table mapping used during the rewrite. */
    @Data
    public static class ProcessContext {
        private Map<String, PhysicalTable> physicalTableMap = new HashMap<>();
    }

    /** Demo: rewrites the table names of a join query and prints the resulting SQL on one line. */
    public static void main(String[] args) {
        String dsl =
            "SELECT * FROM tableA as ta LEFT JOIN tableB as tb ON ta.column_name=tb.column_name where ta.event_time >= 1713369600 and ta.event_time <= 1713369600 order by ta.id desc limit 100 ";
        System.out.println(dsl);
        // Parse the statement into a syntax tree.
        SqlNode sqlNode = SqlUtils.parseSql(ApiDatasourceType.CLICKHOUSE, dsl);
        ProcessContext processContext = new ProcessContext();
        SqlNode sqlNode1 = rewriteTableName(sqlNode, processContext);
        String newSql1 = sqlNode1.toSqlString(ClickhouseSqlDialect.DEFAULT)
            .getSql()
            .replace('\n', ' ');
        System.out.println(newSql1);
    }
}
用法四:解析SQL
解析结果:
/**
 * Accumulates what {@code SqlAnalysis.process} learns about a query: paging limits,
 * GROUP BY / count usage and the referenced tables.
 */
@Getter
@Setter
@ToString
public class AnalysisResult {
    /**
     * Largest offset + fetch seen so far.
     */
    private long maxSum;
    /**
     * Largest offset seen so far.
     */
    private long maxOffset;
    /**
     * Largest fetch seen so far.
     */
    private long maxFetch;
    private boolean hasGroupBy;
    /**
     * Whether the query contains a count(*) call.
     */
    private boolean hasCount;
    private Set<String> queryTables = new HashSet<>();
    private SqlNode sqlNode;
    // Physical tables that are actually accessed.
    private List<PhysicalTable> physicalTableLists = new ArrayList<>();

    /**
     * Folds one offset/fetch pair into the running maxima.
     *
     * @param offset rows skipped
     * @param fetch  rows fetched
     */
    public void updateLimit(long offset, long fetch) {
        long sum;
        try {
            sum = Math.addExact(fetch, offset);
        } catch (ArithmeticException overflow) {
            // A huge fetch/offset would wrap to a negative sum and escape the max check.
            sum = Long.MAX_VALUE;
        }
        maxSum = Math.max(sum, maxSum);
        maxFetch = Math.max(fetch, maxFetch);
        maxOffset = Math.max(offset, maxOffset);
    }

    /**
     * Folds a limit pair into the running maxima.
     *
     * @param limit pair of (offset, fetch)
     */
    public void updateLimit(Pair<Long, Long> limit) {
        long offset = limit.getLeft();
        long fetch = limit.getRight();
        updateLimit(offset, fetch);
    }

    public void updateGroupby(boolean hasGroupBy) {
        this.hasGroupBy = hasGroupBy;
    }

    public void setSqlNode(SqlNode sqlNode) {
        this.sqlNode = sqlNode;
    }

    public void addTable(String table) {
        queryTables.add(table);
    }

    /**
     * @return a copy of the collected table names; changes to it do not affect this result
     */
    public Set<String> queryTables() {
        return new HashSet<>(queryTables);
    }

    /**
     * Finds the physical table whose name equals {@code tableName}.
     *
     * @param tableName table name to look up
     * @return the first matching physical table, or null when the name is empty or nothing matches
     */
    public PhysicalTable queryTableByName(String tableName) {
        if (StringUtils.isEmpty(tableName) || CollectionUtils.isEmpty(physicalTableLists)) {
            return null;
        }
        // A direct scan avoids building a map on every call, and duplicate names no longer
        // throw the way Collectors.toMap does.
        return physicalTableLists.stream()
            .filter(table -> table != null && tableName.equals(table.getName()))
            .findFirst()
            .orElse(null);
    }
}
处理流程:
/**
 * Walks a parsed query and records paging limits, GROUP BY / count usage and referenced
 * tables in an {@link AnalysisResult}.
 */
@Slf4j
public class SqlAnalysis {
    private static final String JDBC_URL_PATTERN =
        "jdbc:(?<type>[a-z]+)://(?<host>[a-zA-Z0-9-//.]+):(?<port>[0-9]+)/(?<database>[a-zA-Z0-9_]+)?";

    /**
     * @param node parsed statement; may be null
     * @return {@code true} if {@code node} is an ORDER BY wrapper that carries a fetch (LIMIT) operand
     */
    public static boolean hasLimit(SqlNode node) {
        return node != null && node.getKind() == SqlKind.ORDER_BY && getFetch((SqlOrderBy) node) != null;
    }

    /**
     * @param node ORDER BY wrapper; may be null
     * @return pair of (offset, fetch); a missing or non-numeric part, or a null node, yields 0
     */
    public static Pair<Long, Long> getLimit(SqlOrderBy node) {
        return Optional.ofNullable(node)
            .map(n -> Pair.of(getOffsetValue(n), getFetchValue(n)))
            .orElseGet(() -> Pair.of(0L, 0L));
    }

    /**
     * Recursively analyses {@code node} and stores the findings in {@code result}.
     *
     * @param result accumulator for the findings
     * @param node   (sub)tree to analyse; null is ignored
     */
    public static void process(AnalysisResult result, SqlNode node) {
        if (Objects.isNull(node)) {
            return;
        }
        if (node.getKind() == SqlKind.AS) {
            SqlNode sqlNode = ((SqlBasicCall) node).getOperandList().get(0);
            process(result, sqlNode);
            // Record the name in front of AS.
            if (sqlNode.getKind() == SqlKind.IDENTIFIER) {
                result.addTable(sqlNode.toString());
            }
        } else if (node.getKind() == SqlKind.SELECT) {
            SqlSelect select = (SqlSelect) node;
            // Strip column aliases first; otherwise "col AS c" would reach the AS branch
            // above and the column name would be recorded as a table.
            select.getSelectList().forEach(n -> process(result, stripAlias(n)));
            if (CollectionUtils.isNotEmpty(select.getGroup())) {
                result.updateGroupby(true);
            }
            processFromNode(result, select.getFrom());
            process(result, select.getWhere());
        } else if (node.getKind() == SqlKind.ORDER_BY) {
            SqlOrderBy orderByNode = (SqlOrderBy) node;
            result.updateLimit(SqlAnalysis.getLimit(orderByNode));
            // The wrapped query is not always a SqlSelect (e.g. "... UNION ... ORDER BY"),
            // so recurse instead of casting.
            process(result, orderByNode.getOperandList().get(0));
        } else if (node.getKind() == SqlKind.JOIN) {
            SqlJoin joinNode = (SqlJoin) node;
            // Both sides are FROM items; going through processFromNode also records tables
            // that are joined without an alias.
            processFromNode(result, joinNode.getLeft());
            processFromNode(result, joinNode.getRight());
        } else if (node.getKind() == SqlKind.UNION) {
            SqlBasicCall unionNode = (SqlBasicCall) node;
            unionNode.getOperandList().forEach(n -> process(result, n));
        } else if (isCountCall(node)) {
            result.setHasCount(true);
        }
    }

    /**
     * Analyses one FROM item: a plain or aliased table name is recorded directly; anything
     * else (join, derived table, ...) is handed to {@link #process}.
     *
     * @param result accumulator for the findings
     * @param node   FROM item; null is ignored
     */
    public static void processFromNode(AnalysisResult result, SqlNode node) {
        if (Objects.isNull(node)) {
            return;
        }
        SqlNode target = node;
        if (node.getKind() == SqlKind.AS) {
            target = ((SqlBasicCall) node).getOperandList().get(0);
        }
        if (target.getKind() == SqlKind.IDENTIFIER) {
            PhysicalTable physicalTable = result.queryTableByName(target.toString());
            if (physicalTable != null && physicalTable.getShardNum() > 0) {
                // NOTE(review): rebuilds the identifier from its own name; looks like a
                // placeholder for shard-specific renaming — confirm intent.
                target = SqlSelectBuilder.identifier(target.toString());
            }
            result.addTable(target.toString());
        } else {
            process(result, node);
        }
    }

    /** Returns the expression behind a column alias, or the node itself when it is not an AS call. */
    private static SqlNode stripAlias(SqlNode node) {
        if (node != null && node.getKind() == SqlKind.AS) {
            return ((SqlBasicCall) node).getOperandList().get(0);
        }
        return node;
    }

    /**
     * Returns whether {@code node} is a call to {@code count}. The name is compared without
     * regard to case because the parser configuration decides the casing of unquoted names.
     */
    private static boolean isCountCall(SqlNode node) {
        if (node.getKind() == SqlKind.COUNT) {
            return true;
        }
        return node.getKind() == SqlKind.OTHER_FUNCTION
            && node instanceof SqlBasicCall
            && ((SqlBasicCall) node).getOperator().isName("count", false);
    }

    // Operand layout used below: 0 = query, 1 = order list, 2 = offset, 3 = fetch.
    private static SqlNode getFetch(SqlOrderBy orderBy) {
        return orderBy.getOperandList().get(3);
    }

    private static SqlNode getOffset(SqlOrderBy orderBy) {
        return orderBy.getOperandList().get(2);
    }

    private static SqlNodeList getOrderList(SqlOrderBy orderBy) {
        return (SqlNodeList) orderBy.getOperandList().get(1);
    }

    private static long getFetchValue(SqlOrderBy orderBy) {
        return numericLiteralValue(orderBy.getOperandList().get(3));
    }

    private static long getOffsetValue(SqlOrderBy orderBy) {
        return numericLiteralValue(orderBy.getOperandList().get(2));
    }

    /** Returns the long value of a numeric literal node, or 0 when the node is null or not a numeric literal. */
    private static long numericLiteralValue(SqlNode node) {
        return Optional.ofNullable(node)
            .filter(SqlNumericLiteral.class::isInstance)
            .map(SqlNumericLiteral.class::cast)
            .map(SqlLiteral::getValue)
            .filter(BigDecimal.class::isInstance)
            .map(BigDecimal.class::cast)
            .map(BigDecimal::longValue)
            .orElse(0L);
    }

    /**
     * Converts a parser failure into an unchecked exception whose message carries the SQL and
     * the first line of the parser message.
     *
     * @param ex  the parser failure; kept as the cause
     * @param sql the SQL text that failed to parse
     * @return the exception for the caller to throw
     */
    public static IllegalStateException parseSqlParserException(SqlParseException ex, String sql) {
        String errorMessage = Optional.ofNullable(ex.getMessage())
            .map(m -> m.split("\n"))
            .map(m -> m[0])
            .orElse("");
        return new IllegalStateException(String.format("解析SQL「%s」失败,失败原因是「%s」", sql, errorMessage), ex);
    }
}
测试方法:
/**
 * Demo: parses a paged join query, runs the analysis and prints the collected result.
 */
public static void main(String[] args) {
    final String query =
        "SELECT * FROM tableA as ta LEFT JOIN tableB as tb ON ta.column_name=tb.column_name where ta.event_time >= 1713369600 and ta.event_time <= 1713369600 order by ta.id desc limit 10,100 ";
    System.out.println(query);

    // Build the syntax tree, then let the analyser fill a fresh result object.
    final SqlNode parsed = SqlUtils.parseSql(ApiDatasourceType.CLICKHOUSE, query);
    final AnalysisResult analysis = new AnalysisResult();
    SqlAnalysis.process(analysis, parsed);

    System.out.println(analysis);
}
运行结果:
AnalysisResult(maxSum=110, maxOffset=10, maxFetch=100, hasGroupBy=false, hasCount=false, queryTables=[tableB, tableA], sqlNode=null, physicalTableLists=[])