【解析SQL模板-2】calcite解析SQL

背景

获取到完整SQL后,需要解析SQL,判断SQL里面涉及哪些表(tables)、是否存在select *。

实现

引入依赖:

<dependency>
    <groupId>org.apache.calcite</groupId>
    <artifactId>calcite-core</artifactId>
    <version>1.30.0</version>
</dependency>

SQL经过calcite解析之后,得到一棵抽象语法树(AST)。这棵语法树由不同的节点组成,节点称为SqlNode;不同类型的DML、DDL语句会得到不同类型的SqlNode,例如select语句转换为SqlSelect,delete语句转换为SqlDelete,join子句转换为SqlJoin。

一个select语句包含from部分、where部分、select部分等,每一部分都表示一个SqlNode。SqlKind是一个枚举类型,用来标识SqlNode的种类,例如SELECT、IDENTIFIER、LITERAL等,分别对应SqlSelect、SqlIdentifier、SqlLiteral等节点类。SqlIdentifier表示标识符,例如表名称、字段名;SqlLiteral表示字面常量,例如具体的数字、字符串。


SqlNode.png

SQL构建成SqlNode节点:

/**
 * Datasource engines whose SQL text can be handed to the parser.
 */
public enum ApiDatasourceType {
    /** ClickHouse datasource. */
    CLICKHOUSE,

    /** MySQL datasource. */
    MYSQL
}
/**
 * SQL parsing helpers built on the Calcite parser.
 */
@Slf4j
public class SqlUtils {
    /**
     * Parses a SQL query string into a {@code SqlNode} tree.
     *
     * @param type datasource type; {@code MYSQL} uses {@code SqlParserConfig.mysqlConfig()},
     *             every other value (including {@code CLICKHOUSE}) uses
     *             {@code SqlParserConfig.defaultConfig()}
     * @param sql  the SQL text to parse
     * @return the root node returned by {@code SqlParser.parseQuery()}
     * @throws IllegalStateException if the parser rejects the SQL; the original
     *                               {@code SqlParseException} is kept as the cause
     */
    public static SqlNode parseSql(ApiDatasourceType type, String sql) {
        SqlParser parser;
        switch (type) {
            case MYSQL:
                parser = SqlParser.create(sql, SqlParserConfig.mysqlConfig());
                break;
            case CLICKHOUSE:
            default:
                parser = SqlParser.create(sql, SqlParserConfig.defaultConfig());
                break;
        }
        try {
            return parser.parseQuery();
        } catch (SqlParseException e) {
            // SqlParseException is checked and this method declares no throws clause,
            // so rethrowing it as-is does not compile. Surface it as an unchecked
            // exception and keep the parser error as the cause.
            throw new IllegalStateException(String.format("Failed to parse SQL: %s", sql), e);
        }
    }
}

用法一:获取表名、判断是否存在select *

/**
 * Static helpers that inspect a parsed {@code SqlNode} tree: collecting the identifiers
 * found in FROM positions and detecting {@code select *}.
 */
public class SqlUtils {

    /**
     * Reports whether the statement rooted at {@code node} contains a star projection.
     *
     * <p>Handled node kinds: IDENTIFIER, AS, SELECT, ORDER_BY, JOIN and UNION. Any other
     * kind (for example a function call such as {@code count(*)}) yields {@code false}.
     *
     * @param node the node to inspect, must not be null
     * @return {@code true} if a {@code *} (or qualified {@code t.*}) identifier is found in a
     *         select list, in a FROM sub-query, or in a sub-query of a WHERE clause
     * @throws IllegalStateException if {@code node} is null
     */
    public static boolean hasSelectAll(SqlNode node) {
        if (Objects.isNull(node)) {
            throw new IllegalStateException("SqlNode is null");
        }
        switch (node.getKind()) {
            case IDENTIFIER: {
                // A qualified star prints as "t.*", so matching only "*" would miss it.
                String text = node.toString();
                return Objects.equals(text, "*") || text.endsWith(".*");
            }
            case AS:
                // Operand 0 is the aliased expression; operand 1 is the alias itself.
                return hasSelectAll(((SqlBasicCall) node).getOperandList().get(0));
            case SELECT:
                return selectHasSelectAll((SqlSelect) node);
            case ORDER_BY:
                // Recurse into the wrapped query instead of casting it to SqlSelect:
                // the query may be a UNION, and its WHERE clause must be inspected too.
                return hasSelectAll(((SqlOrderBy) node).query);
            case JOIN: {
                SqlJoin joinNode = (SqlJoin) node;
                return hasSelectAll(joinNode.getLeft()) || hasSelectAll(joinNode.getRight());
            }
            case UNION:
                return ((SqlBasicCall) node).getOperandList()
                        .stream()
                        .anyMatch(SqlUtils::hasSelectAll);
            default:
                return false;
        }
    }

    /**
     * Collects the identifiers that appear in FROM positions of the statement.
     *
     * @param node root of the parsed statement, must not be null
     * @return a new mutable set with the string form of every identifier found
     * @throws IllegalStateException if {@code node} is null or a node kind is not supported
     */
    public static Set<String> findTables(SqlNode node) {
        Set<String> tables = new HashSet<>();
        findTables(tables, node);
        return tables;
    }

    /** Checks the select list, the FROM source (if any) and the WHERE clause (if any). */
    private static boolean selectHasSelectAll(SqlSelect select) {
        if (select.getSelectList().stream().anyMatch(SqlUtils::hasSelectAll)) {
            return true;
        }
        // A SELECT without FROM (e.g. "select 1") has a null source; skip it.
        SqlNode from = select.getFrom();
        if (Objects.nonNull(from) && hasSelectAll(from)) {
            return true;
        }
        SqlNode where = select.getWhere();
        return Objects.nonNull(where) && hasSelectAllInWhere(where);
    }

    /** Recursive worker for {@link #findTables(SqlNode)}; adds results to {@code foundTables}. */
    private static void findTables(Set<String> foundTables, SqlNode node) {
        if (Objects.isNull(node)) {
            throw new IllegalStateException("SqlNode is null");
        }
        switch (node.getKind()) {
            case IDENTIFIER:
                // An identifier reached from a FROM position is recorded as a table name.
                foundTables.add(node.toString());
                break;
            case AS:
                findTables(foundTables, ((SqlBasicCall) node).getOperandList().get(0));
                break;
            case SELECT: {
                SqlSelect select = (SqlSelect) node;
                // A SELECT without FROM has no tables of its own.
                if (Objects.nonNull(select.getFrom())) {
                    findTables(foundTables, select.getFrom());
                }
                if (Objects.nonNull(select.getWhere())) {
                    findTablesInWhere(foundTables, select.getWhere());
                }
                break;
            }
            case ORDER_BY:
                // Recurse into the wrapped query so UNION queries and WHERE sub-queries
                // are handled the same way as without ORDER BY / LIMIT.
                findTables(foundTables, ((SqlOrderBy) node).query);
                break;
            case JOIN: {
                SqlJoin joinNode = (SqlJoin) node;
                findTables(foundTables, joinNode.getLeft());
                findTables(foundTables, joinNode.getRight());
                break;
            }
            case UNION:
                ((SqlBasicCall) node).getOperandList().forEach(n -> findTables(foundTables, n));
                break;
            default:
                throw new IllegalStateException(String.format("Un support node type %s", node.getKind().toString()));
        }
    }

    /**
     * Looks for star projections in sub-queries nested inside a WHERE expression.
     * Only {@code SqlBasicCall} operands are descended into.
     */
    private static boolean hasSelectAllInWhere(SqlNode node) {
        if (node.getKind() == SqlKind.SELECT || node.getKind() == SqlKind.ORDER_BY) {
            return hasSelectAll(node);
        }
        if (node instanceof SqlBasicCall) {
            SqlBasicCall call = (SqlBasicCall) node;
            return call.getOperandList()
                    .stream().anyMatch(SqlUtils::hasSelectAllInWhere);
        }
        return false;
    }

    /**
     * Collects tables of sub-queries nested inside a WHERE expression.
     * Only {@code SqlBasicCall} operands are descended into.
     */
    private static void findTablesInWhere(Set<String> foundTables, SqlNode node) {
        if (node.getKind() == SqlKind.SELECT || node.getKind() == SqlKind.ORDER_BY) {
            findTables(foundTables, node);
        } else if (node instanceof SqlBasicCall) {
            SqlBasicCall call = (SqlBasicCall) node;
            call.getOperandList().forEach(n -> findTablesInWhere(foundTables, n));
        }
    }
}

测试方法:

    /**
     * Demo entry point: parses a sample MySQL query, then prints the result of
     * {@code findTables} followed by the result of {@code hasSelectAll}.
     */
    public static void main(String[] args) {
        SqlNode parsed = parseSql(
                ApiDatasourceType.MYSQL,
                "select count(*) cnt from ( select * from tableA where event_time >= 1713369600 and event_time <= 1713369600 limit 1 ) a");
        System.out.println(SqlUtils.findTables(parsed));
        boolean selectAll = SqlUtils.hasSelectAll(parsed);
        System.out.println(selectAll);
    }

用法二:解析获取select count(*) 查询总数的SQL

/**
 * Demo: derives a {@code count(*)} "total rows" query from a paged query.
 */
@Slf4j
public class MybatisGenerator {
    /** Select item {@code count(*) AS count}, built with the project's SqlSelectBuilder. */
    private static final SqlNode COUNT_SQL =
            SqlSelectBuilder.as(SqlSelectBuilder.function("count", SqlSelectBuilder.star()), SqlSelectBuilder.identifier("count"));

    /**
     * Parses a sample query, wraps it in {@code select count(*) ... from (<query>)} and
     * prints both the original SQL and the generated count SQL.
     *
     * @throws IllegalStateException if the parsed root is neither ORDER_BY nor SELECT
     */
    public static void main(String[] args) {
        String dsl =
                "SELECT * FROM tableA as ta LEFT JOIN tableB as tb ON ta.column_name=tb.column_name where ta.event_time >= 1713369600 and ta.event_time <= 1713369600 order by ta.id desc limit 100 ";
        System.out.println(dsl);
        // Parse the SQL text into a SqlNode tree.
        SqlNode sqlNode = SqlUtils.parseSql(ApiDatasourceType.CLICKHOUSE, dsl);
        SqlSelect totalSqlNode;
        if (sqlNode.getKind() == SqlKind.ORDER_BY) {
            // Operand 0 of the ORDER_BY node is used as the sub-query, which leaves the
            // ORDER BY / LIMIT wrapper out of the count query.
            totalSqlNode = SqlSelectBuilder.builder().select(COUNT_SQL).from(((SqlOrderBy) sqlNode).getOperandList().get(0)).buildSelect();
        } else if (sqlNode.getKind() == SqlKind.SELECT) {
            // Strip paging and ordering from a copy of the select before wrapping it.
            totalSqlNode = SqlSelectBuilder.builder().select(COUNT_SQL)
                    .from(SqlSelectBuilder.builder((SqlSelect) sqlNode)
                            .clearFetch()
                            .clearOffset()
                            .clearOrderBy()
                            .buildSelect())
                    .buildSelect();
        } else {
            throw new IllegalStateException("错误的SQL类型");
        }
        // The generated query was previously computed and then discarded; print it on one line.
        System.out.println(totalSqlNode.toString().replace('\n', ' '));
    }
}

测试运行:

SELECT COUNT(*) AS `count` FROM (SELECT * FROM `tableA` AS `ta` LEFT JOIN `tableB` AS `tb` ON `ta`.`column_name` = `tb`.`column_name` WHERE `ta`.`event_time` >= 1713369600 AND `ta`.`event_time` <= 1713369600)

用法三:替换备用表名

/**
 * Demo: rewrites the table names in the FROM part of a parsed query. A table is replaced
 * by its current shard table (looked up in {@link ProcessContext}) or, failing that, by
 * its backup table from {@link #BAK_TABLES}; the original name or alias is kept as alias.
 */
@Slf4j
public class MybatisGenerator {

    /** Dual-table backup mapping: original table name -> backup table name. */
    public static final Map<String, String> BAK_TABLES = new HashMap<>();

    static {
        BAK_TABLES.put("tableA", "tableA_bak");
        BAK_TABLES.put("tableB", "tableB_bak");
    }


    /**
     * Entry point of the rewrite.
     *
     * @param node    root of the parsed statement, may be null
     * @param context lookup data for shard tables
     * @return the rewritten tree, or {@code null} when {@code node} is null
     */
    private static SqlNode rewriteTableName(SqlNode node, ProcessContext context) {
        if (Objects.isNull(node)) {
            return null;
        }
        return analysisSqlNode(node, context);
    }

    /**
     * Recursively rewrites SELECT, ORDER_BY, JOIN, AS and IDENTIFIER nodes; any other
     * kind is returned unchanged. JOIN nodes are updated in place via setLeft/setRight.
     */
    private static SqlNode analysisSqlNode(SqlNode node, ProcessContext context) {
        SqlNode resNode = node;
        if (SqlKind.SELECT == node.getKind()) {
            SqlSelect select = (SqlSelect) node;
            // A SELECT without FROM (e.g. "select 1") has nothing to rewrite.
            if (Objects.nonNull(select.getFrom())) {
                // JOIN, AS and IDENTIFIER sources are all handled by the recursive call.
                resNode = SqlSelectBuilder.builder(select)
                        .from(analysisSqlNode(select.getFrom(), context))
                        .buildSelect();
            }
        } else if (SqlKind.ORDER_BY == node.getKind()) {
            SqlOrderBy orderBy = (SqlOrderBy) node;
            if (SqlKind.SELECT == orderBy.query.getKind()) {
                // Reuse the SELECT branch for the wrapped query, then rebuild the ORDER BY around it.
                SqlSelect sqlSelect = (SqlSelect) analysisSqlNode(orderBy.query, context);
                resNode = SqlOrderByBuilder.builder(orderBy).query(sqlSelect).build();
            }
        } else if (SqlKind.JOIN == node.getKind()) {
            SqlJoin join = (SqlJoin) node;
            SqlNode sqlNodeLeft = analysisSqlNode(join.getLeft(), context);
            SqlNode sqlNodeRight = analysisSqlNode(join.getRight(), context);
            join.setLeft(sqlNodeLeft);
            join.setRight(sqlNodeRight);
            resNode = join;
        } else if (SqlKind.AS == node.getKind()) {
            SqlBasicCall as = (SqlBasicCall) node;
            SqlNode sqlNode = as.getOperandList().get(0);
            if (SqlKind.IDENTIFIER != sqlNode.getKind()) {
                // Aliased sub-query: rewrite the sub-query and re-attach the original alias.
                SqlNode sqlNode1 = analysisSqlNode(sqlNode, context);
                resNode = new SqlBasicCall(as.getOperator(), Lists.newArrayList(sqlNode1, as.getOperandList().get(1)),
                        SqlParserPos.ZERO);
            } else {
                resNode = replaceTableNameConvert(node, context);
            }
        } else if (SqlKind.IDENTIFIER == node.getKind()) {
            resNode = replaceTableNameConvert(node, context);
        }
        return resNode;
    }

    /**
     * Replaces a table reference (a bare IDENTIFIER or {@code table AS alias}) with
     * {@code <newTable> AS <alias>} when a replacement is known.
     *
     * @param sourceNode the table reference, may be null
     * @param context    lookup data for shard tables
     * @return the replacement node, or {@code sourceNode} itself when it is null, of another
     *         kind, or has no shard/backup mapping
     */
    private static SqlNode replaceTableNameConvert(SqlNode sourceNode, ProcessContext context) {
        if (Objects.isNull(sourceNode)) {
            return sourceNode;
        }
        String qualifiedName;
        String alias;
        if (SqlKind.IDENTIFIER == sourceNode.getKind()) {
            qualifiedName = sourceNode.toString();
            // No explicit alias: the original name becomes the alias.
            // NOTE(review): for a qualified name this is the full dotted string;
            // confirm SqlSelectBuilder.identifier treats it as intended.
            alias = qualifiedName;
        } else if (SqlKind.AS == sourceNode.getKind()) {
            SqlBasicCall as = (SqlBasicCall) sourceNode;
            qualifiedName = as.getOperandList().get(0).toString();
            alias = as.getOperandList().get(1).toString();
        } else {
            return sourceNode;
        }

        // The name may or may not carry a database prefix. The table is the last segment;
        // everything before it (including the trailing dot) is kept as the prefix.
        int lastDot = qualifiedName.lastIndexOf('.');
        String database = qualifiedName.substring(0, lastDot + 1);
        String oldTable = qualifiedName.substring(lastDot + 1);

        String newTable = null;
        PhysicalTable physicalTable = context.getPhysicalTableMap().get(oldTable);
        if (physicalTable != null && StringUtils.isNotBlank(physicalTable.getCurrentShardTableName())) {
            // A shard table takes precedence over the backup mapping.
            newTable = physicalTable.getCurrentShardTableName();
        } else if (BAK_TABLES.containsKey(oldTable)) {
            newTable = BAK_TABLES.get(oldTable);
        }
        if (newTable == null) {
            return sourceNode;
        }
        return SqlSelectBuilder.as(SqlSelectBuilder.identifier(database + newTable),
                SqlSelectBuilder.identifier(alias));
    }


    /** Carries the table-name to physical-table lookup used during the rewrite. */
    @Data
    public static class ProcessContext {
        private Map<String, PhysicalTable> physicalTableMap = new HashMap<>();
    }

    /**
     * Demo entry point: parses a sample query, rewrites its table names with an empty
     * {@link ProcessContext} and prints the SQL before and after.
     */
    public static void main(String[] args) {
        String dsl =
                "SELECT * FROM tableA as ta LEFT JOIN tableB as tb ON ta.column_name=tb.column_name where ta.event_time >= 1713369600 and ta.event_time <= 1713369600 order by ta.id desc limit 100 ";
        System.out.println(dsl);
        // Parse the SQL text into a SqlNode tree.
        SqlNode sqlNode = SqlUtils.parseSql(ApiDatasourceType.CLICKHOUSE, dsl);
        ProcessContext processContext = new ProcessContext();
        SqlNode sqlNode1 = rewriteTableName(sqlNode, processContext);
        String newSql1 = sqlNode1.toSqlString(ClickhouseSqlDialect.DEFAULT)
                .getSql()
                .replace('\n', ' ');
        System.out.println(newSql1);
    }
}

用法四:解析SQL

解析结果:

/**
 * Accumulates the facts gathered while walking a parsed SQL statement: paging limits,
 * GROUP BY / count usage and the tables referenced.
 */
@Getter
@Setter
@ToString
public class AnalysisResult {

    /**
     * Largest value of offset + fetch seen so far.
     */
    private long maxSum;

    /**
     * Largest offset seen so far.
     */
    private long maxOffset;

    /**
     * Largest fetch seen so far.
     */
    private long maxFetch;

    private boolean hasGroupBy;

    /**
     * Whether the query contains count(*).
     */
    private boolean hasCount;

    private Set<String> queryTables = new HashSet<>();

    private SqlNode sqlNode;

    // Physical tables actually accessed by the query.
    private List<PhysicalTable> physicalTableLists = new ArrayList<>();

    /**
     * Records one offset/fetch pair, keeping the maximum of each and of their sum.
     *
     * @param offset rows skipped
     * @param fetch  rows fetched
     */
    public void updateLimit(long offset, long fetch) {
        long sum = fetch + offset;
        maxSum = Math.max(sum, maxSum);
        maxFetch = Math.max(fetch, maxFetch);
        maxOffset = Math.max(offset, maxOffset);
    }

    /**
     * Records one limit given as a pair.
     *
     * @param limit pair whose left value is the offset and right value is the fetch
     */
    public void updateLimit(Pair<Long, Long> limit) {
        long offset = limit.getLeft();
        long fetch = limit.getRight();
        updateLimit(offset, fetch);
    }

    public void updateGroupby(boolean hasGroupBy) {
        this.hasGroupBy = hasGroupBy;
    }

    public void setSqlNode(SqlNode sqlNode) {
        this.sqlNode = sqlNode;
    }

    public void addTable(String table) {
        queryTables.add(table);
    }

    /**
     * @return a copy of the recorded table names; changes to it do not affect this result
     */
    public Set<String> queryTables() {
        return new HashSet<>(queryTables);
    }

    /**
     * Finds a physical table by name.
     *
     * @param tableName the name to look up
     * @return the first physical table whose name equals {@code tableName}, or {@code null}
     *         when the name is empty, no physical tables are registered, or none matches
     */
    public PhysicalTable queryTableByName(String tableName) {
        if (StringUtils.isEmpty(tableName) || CollectionUtils.isEmpty(physicalTableLists)) {
            return null;
        }
        // Scan directly rather than building a map on every call: Collectors.toMap
        // throws IllegalStateException when two entries share the same name.
        return physicalTableLists.stream()
                .filter(table -> table != null && tableName.equals(table.getName()))
                .findFirst()
                .orElse(null);
    }
}

处理流程:

/**
 * Walks a parsed {@code SqlNode} tree and records paging limits, GROUP BY / count usage
 * and referenced tables into an {@code AnalysisResult}.
 */
@Slf4j
public class SqlAnalysis {

    // Not referenced anywhere in this class as shown.
    private static final String JDBC_URL_PATTERN =
            "jdbc:(?<type>[a-z]+)://(?<host>[a-zA-Z0-9-//.]+):(?<port>[0-9]+)/(?<database>[a-zA-Z0-9_]+)?";

    /**
     * @param node the node to test, may be null
     * @return {@code true} when {@code node} is an ORDER_BY node with a non-null fetch
     */
    public static boolean hasLimit(SqlNode node) {
        return node != null && node.getKind() == SqlKind.ORDER_BY && getFetch((SqlOrderBy) node) != null;
    }

    /**
     * Extracts the paging values of an ORDER_BY node.
     *
     * @param node the ORDER_BY node, may be null
     * @return a pair (offset, fetch); a value that is missing or not a numeric literal
     *         counts as 0, and a null {@code node} yields (0, 0)
     */
    public static Pair<Long, Long> getLimit(SqlOrderBy node) {
        return Optional.ofNullable(node)
                .map(n -> Pair.of(getOffsetValue(n), getFetchValue(n)))
                .orElseGet(() -> Pair.of(0L, 0L));
    }

    /**
     * Recursively analyses {@code node} and stores the findings in {@code result}.
     * Handles AS, SELECT, ORDER_BY, JOIN, UNION and OTHER_FUNCTION nodes; other kinds and
     * null nodes are ignored.
     *
     * @param result the accumulator to update
     * @param node   the node to analyse, may be null
     */
    public static void process(AnalysisResult result, SqlNode node) {
        if (Objects.isNull(node)) {
            return;
        }
        if (node.getKind() == SqlKind.AS) {
            // Only the aliased expression is analysed here. Table names are recorded by
            // processFromNode, so an aliased select-list column ("col AS c") is not
            // mistaken for a table.
            process(result, ((SqlBasicCall) node).getOperandList().get(0));
        } else if (node.getKind() == SqlKind.SELECT) {
            SqlSelect select = (SqlSelect) node;
            select.getSelectList().forEach(n -> process(result, n));
            if (CollectionUtils.isNotEmpty(select.getGroup())) {
                result.updateGroupby(true);
            }
            processFromNode(result, select.getFrom());
            process(result, select.getWhere());
        } else if (node.getKind() == SqlKind.ORDER_BY) {
            SqlOrderBy orderByNode = (SqlOrderBy) node;
            result.updateLimit(SqlAnalysis.getLimit(orderByNode));
            // The wrapped query is not necessarily a SqlSelect (it can be a UNION), so recurse without casting.
            process(result, orderByNode.query);
        } else if (node.getKind() == SqlKind.JOIN) {
            SqlJoin joinNode = (SqlJoin) node;
            // Join operands are FROM sources: route them through processFromNode so
            // tables are recorded whether or not they carry an alias.
            processFromNode(result, joinNode.getLeft());
            processFromNode(result, joinNode.getRight());
        } else if (node.getKind() == SqlKind.UNION) {
            SqlBasicCall unionNode = (SqlBasicCall) node;
            unionNode.getOperandList().forEach(n -> process(result, n));
        } else if (node.getKind() == SqlKind.OTHER_FUNCTION) {
            if (((SqlBasicCall) node).getOperator().isName("count", true)) {
                result.setHasCount(true);
            }
        }
    }

    /**
     * Analyses one FROM source. A table identifier (optionally aliased) is recorded as a
     * query table; anything else (sub-query, join, ...) is handed to {@link #process}.
     *
     * @param result the accumulator to update
     * @param node   the FROM source, may be null
     */
    public static void processFromNode(AnalysisResult result, SqlNode node) {
        if (Objects.isNull(node)) {
            return;
        }
        SqlNode target = node;
        if (node.getKind() == SqlKind.AS) {
            target = ((SqlBasicCall) node).getOperandList().get(0);
        }
        if (target.getKind() == SqlKind.IDENTIFIER) {
            PhysicalTable physicalTable = result.queryTableByName(target.toString());
            if (physicalTable != null && physicalTable.getShardNum() > 0) {
                // NOTE(review): rebuilds the identifier from its own string form; this
                // looks like a no-op unless SqlSelectBuilder.identifier normalises the name.
                target = SqlSelectBuilder.identifier(target.toString());
            }
            result.addTable(target.toString());
        } else {
            process(result, node);
        }
    }

    private static SqlNode getFetch(SqlOrderBy orderBy) {
        return orderBy.fetch;
    }

    private static SqlNode getOffset(SqlOrderBy orderBy) {
        return orderBy.offset;
    }

    private static SqlNodeList getOrderList(SqlOrderBy orderBy) {
        return orderBy.orderList;
    }

    private static long getFetchValue(SqlOrderBy orderBy) {
        return numericLiteralValue(getFetch(orderBy));
    }

    private static long getOffsetValue(SqlOrderBy orderBy) {
        return numericLiteralValue(getOffset(orderBy));
    }

    /** Returns the long value of a numeric literal backed by a BigDecimal, or 0 for anything else (including null). */
    private static long numericLiteralValue(SqlNode node) {
        return Optional.ofNullable(node)
                .filter(SqlNumericLiteral.class::isInstance)
                .map(SqlNumericLiteral.class::cast)
                .map(SqlLiteral::getValue)
                .filter(BigDecimal.class::isInstance)
                .map(BigDecimal.class::cast)
                .map(BigDecimal::longValue)
                .orElse(0L);
    }

    /**
     * Converts a parser failure into an unchecked exception.
     *
     * @param ex  the parser exception
     * @param sql the SQL that failed to parse
     * @return an IllegalStateException whose message contains the SQL and the first line
     *         of the parser message, with {@code ex} as the cause
     */
    public static IllegalStateException parseSqlParserException(SqlParseException ex, String sql) {
        String errorMessage = Optional.ofNullable(ex.getMessage())
                .map(m -> m.split("\n"))
                .map(m -> m[0])
                .orElse("");

        return new IllegalStateException(String.format("解析SQL「%s」失败,失败原因是「%s」", sql, errorMessage), ex);
    }
}

测试方法:

    /**
     * Demo entry point: prints a sample query, analyses it with {@code SqlAnalysis.process}
     * and prints the resulting {@code AnalysisResult}.
     */
    public static void main(String[] args) {
        final String dsl =
                "SELECT * FROM tableA as ta LEFT JOIN tableB as tb ON ta.column_name=tb.column_name where ta.event_time >= 1713369600 and ta.event_time <= 1713369600 order by ta.id desc limit 10,100 ";
        System.out.println(dsl);
        // Parse first, then feed the tree into a fresh accumulator.
        final SqlNode parsed = SqlUtils.parseSql(ApiDatasourceType.CLICKHOUSE, dsl);
        final AnalysisResult analysis = new AnalysisResult();
        SqlAnalysis.process(analysis, parsed);
        System.out.println(analysis);
    }

运行结果:

AnalysisResult(maxSum=110, maxOffset=10, maxFetch=100, hasGroupBy=false, hasCount=false, queryTables=[tableB, tableA], sqlNode=null, physicalTableLists=[])

参考文档

【Calcite源码学习】SqlNode方言转换

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容