模仿mybatis写一个sql解析工具,支持for if where标签

1.主要功能介绍:

sql语法基本和mybatis一致(兼容mybatis语法)

目前已经支持if where choose和for等

支持list参数: in #list# 会按集合的元素个数展开成占位符,例如集合有4个元素时会转换成 in (?,?,?,?)

关键:支持多条sql,可以使用$i作为第i条sql的返回结果用于其他sql中

where和if示例:

select * from tb_test
<where>
    <if test="aaa != null">
    and aaa = #aaa#
    </if>
    <if test="bbb == 'Test'">
    and bbb = #bbb#
    </if>
</where>

for语句示例:

insert into tb_user (name, age, address)
values 
<for list="list" item="user" join=",">
    (#user.name#, #user.age#, #user.address#)
</for>;

select * from tb_test
where id in
<for list="list" item="id" join="," start="(" end=")">
    #id#
</for>;

多条sql使用示例:

select name from tb_test
where id in
<for list="list" item="id" join="," start="(" end=")">
    #id#
</for>;

delete from tb_user
where name in #$1#;

##这里的$1是第一条sql的返回值

2.解析sql工具类:

package com.fly.tool.api.util;

import com.fly.tool.api.common.BaseException;
import com.fly.tool.api.db.SqlAndParams;
import org.dom4j.*;
import org.dom4j.tree.DefaultCDATA;
import org.dom4j.tree.DefaultText;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.fly.tool.api.common.Constant.INTERNAL_ERROR_CODE;
import static com.fly.tool.api.common.ErrorMessage.SQL_ERROR;
import static com.fly.tool.api.common.ErrorMessage.SQL_XML_ERROR;
import static com.fly.tool.api.common.SqlConstant.*;
import static java.lang.Boolean.FALSE;
import static java.lang.Boolean.TRUE;
import static java.util.Objects.isNull;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;

/**
 * MyBatis-style dynamic SQL helper.
 *
 * <p>Works in two steps:
 * <ol>
 *     <li>{@link #generateSql(String, Map)} expands an XML-flavoured template
 *     ({@code where}, {@code if}, {@code choose}/{@code when}/{@code otherwise}, {@code for})
 *     into one or more SQL statements that still contain {@code #expression#} placeholders;</li>
 *     <li>{@link #parseSql(String, Map)} turns one such statement into JDBC-style SQL with
 *     {@code ?} markers plus the list of values to bind.</li>
 * </ol>
 *
 * @author guoxiang
 * @version 1.0.0
 * @since 2021/6/22
 */
public class DynamicSqlParser {

    private static final Logger log = LoggerFactory.getLogger(DynamicSqlParser.class);

    private DynamicSqlParser() {
        // utility class, not meant to be instantiated
    }

    /**
     * Expands a stored SQL template into executable statements (placeholders are kept).
     *
     * <p>Text that does not already look like a single XML element is wrapped in the default
     * root tag first, so plain SQL from older versions keeps working. The rendered text is then
     * split on {@code ;} and blank pieces are dropped.
     *
     * <p>NOTE(review): the split is purely textual, so a {@code ;} inside a string literal would
     * also split the statement - confirm templates never contain one.
     * NOTE(review): the template is parsed as XML; if templates can come from untrusted sources,
     * confirm the dom4j version in use disables DTDs/external entities for {@code parseText}.
     *
     * @param sql    the SQL template as stored in the database; {@code null} yields an empty list
     * @param params parameters the template's test expressions and lists are evaluated against
     * @return the rendered statements, possibly empty, never {@code null}
     * @throws BaseException if the template is not well-formed XML
     */
    public static List<String> generateSql(String sql, Map<String, Object> params) {
        // a null template used to fail with a NullPointerException; treat it like an empty one
        if (isNull(sql)) {
            return Collections.emptyList();
        }

        // wrap plain SQL in the default root tag to stay compatible with older templates
        if (!XML_PATTERN.matcher(sql).matches()) {
            sql = DEFAULT_XML_TAG + sql + DEFAULT_XML_TAG_END;
        }

        Document document;
        try {
            document = DocumentHelper.parseText(sql);
        } catch (DocumentException e) {
            throw new BaseException(e, INTERNAL_ERROR_CODE, SQL_XML_ERROR);
        }

        String sqlString = parseXml(document.getRootElement(), params);
        if (ObjectUtils.isEmpty(sqlString)) {
            return Collections.emptyList();
        }

        // the rendered text may hold several statements
        String[] sqlArray = sqlString.split(";");
        return Stream.of(sqlArray).filter(ObjectUtils::notBlank).collect(toList());
    }


    /**
     * Renders a single node of the template tree.
     *
     * @param node    the node to render
     * @param context parameter context
     * @return the SQL text produced by the node; empty for node types or tags that are not recognised
     */
    private static String parseXml(Node node, Map<String, Object> context) {
        // plain SQL text and CDATA sections are emitted verbatim
        if (node instanceof DefaultText || node instanceof DefaultCDATA) {
            return node.getText();
        }

        // anything that is neither text nor an element contributes nothing
        if (!(node instanceof Element)) {
            return EMPTY_STRING;
        }

        Element element = (Element) node;
        String name = element.getName();

        // dispatch on the tag name
        switch (name) {
            case WHERE:
                return handleWhere(element, context);
            case IF:
            case WHEN:
                return handleIf(element, context);
            case FOR:
                return handleFor(element, context);
            case SQL:
            case DYNAMIC_SQL:
            case OTHERWISE:
                return handleRoot(element, context);
            case CHOOSE:
                return handleChoose(element, context);
            case ONE_SQL:
            case MORE_SQL:
                // these containers mark a complete statement, so a delimiter is appended
                return handleRoot(element, context) + SQL_DELIMITER;
            default:
                return EMPTY_STRING;
        }
    }

    /**
     * Renders every child of a container element and concatenates the results in order.
     *
     * @param element the container element
     * @param context parameter context
     * @return the concatenated SQL of all children
     */
    private static String handleRoot(Element element, Map<String, Object> context) {
        return element.content()
                .stream()
                .map(n -> parseXml(n, context))
                .collect(joining());
    }


    /**
     * Renders an {@code <if>} (or a {@code <when>} reached through the generic dispatch).
     *
     * <p>The body is dropped only when the test evaluates to exactly {@code Boolean.FALSE};
     * any other result (including {@code null}) keeps it. This is deliberately left as it was;
     * note that {@link #testWhen(Node, Map)} is stricter and requires {@code Boolean.TRUE}.
     *
     * @param element the conditional element
     * @param context parameter context
     * @return the rendered body, or an empty string when the test is false
     * @throws BaseException if the element has no test attribute
     */
    private static String handleIf(Element element, Map<String, Object> context) {
        Object value = evaluateTest(element, context);

        if (FALSE.equals(value)) {
            return EMPTY_STRING;
        }

        return handleRoot(element, context);
    }

    /**
     * Evaluates the test attribute of a conditional element.
     *
     * @param element the conditional element
     * @param context parameter context
     * @return whatever the expression evaluates to (not necessarily a Boolean)
     * @throws BaseException if the test attribute is missing (previously a NullPointerException)
     */
    private static Object evaluateTest(Element element, Map<String, Object> context) {
        Attribute test = element.attribute(TEST);
        if (isNull(test)) {
            log.error("tag <{}> has no '{}' attribute", element.getName(), TEST);
            throw new BaseException(INTERNAL_ERROR_CODE, SQL_XML_ERROR);
        }
        return OgnlUtils.parse(test.getValue(), context);
    }

    /**
     * Renders a {@code <where>} element.
     *
     * <p>If the body renders to nothing (or only whitespace) the whole clause is omitted.
     * Otherwise line breaks are replaced by spaces, the text is trimmed, a leading connective
     * matching {@code AND_OR} is removed and the {@code where} keyword is put in front.
     *
     * @param element the where element
     * @param context parameter context
     * @return the where clause, or an empty string when no condition applies
     */
    private static String handleWhere(Element element, Map<String, Object> context) {

        String sql = handleRoot(element, context);

        if (ObjectUtils.isEmpty(sql)) {
            return EMPTY_STRING;
        }

        sql = sql.replace(LINE, SPACE).trim();

        if (ObjectUtils.isEmpty(sql)) {
            return EMPTY_STRING;
        }

        // strip the leading and/or left over from the first condition that applied
        return WHERE + SPACE + sql.replaceFirst(AND_OR, EMPTY_STRING);
    }

    /**
     * Renders a {@code <for>} element (attribute names are compatible with MyBatis' foreach).
     *
     * <p>The element's own text is repeated once per list entry. In copy {@code i} every
     * placeholder that refers to the item - {@code #item#}, {@code #item.prop#} or
     * {@code #item[k]#} - is rewritten to the indexed list expression
     * {@code #list[i]#}, {@code #list[i].prop#}, ... so the value can be resolved later when the
     * statement is turned into executable SQL. Only {@code element.getText()} is used, so child
     * tags inside the loop body are not rendered.
     *
     * <p>An empty or {@code null} list renders as {@code start + end} only.
     *
     * @param element the for element
     * @param context parameter context
     * @return the expanded SQL fragment
     * @throws BaseException if the list expression does not evaluate to a collection
     */
    private static String handleFor(Element element, Map<String, Object> context) {
        String listName = getAttribute(element, LIST, COLLECTION);
        String itemAttr = getAttribute(element, ITEM);
        String start = getAttribute(element, START, OPEN);
        String end = getAttribute(element, END, CLOSE);
        String join = getAttribute(element, JOIN, SEPARATOR);

        Assert.notEmpty(listName, INTERNAL_ERROR_CODE, SQL_ERROR);

        // fall back to the default item name when the attribute is absent
        String itemName = ObjectUtils.isEmpty(itemAttr) ? ITEM : itemAttr;
        String sqlFragment = element.getText().replace(LINE, SPACE);

        // evaluate as Object first: an unchecked assignment to Collection would fail with a
        // ClassCastException when the expression yields something else
        Object listValue = OgnlUtils.parse(listName, context);
        if (!isNull(listValue) && !(listValue instanceof Collection)) {
            log.error("- list: {} is not a collection but {}", listName, listValue.getClass());
            throw new BaseException(INTERNAL_ERROR_CODE, SQL_ERROR);
        }

        Collection<?> list = (Collection<?>) listValue;
        if (isNull(list) || list.isEmpty()) {
            log.warn("- list: {} is empty!", listName);
            return start + end;
        }

        // Placeholder for the item:
        //  - the item name is quoted so regex metacharacters in it are taken literally;
        //  - the optional suffix must start with '.' or '[' so "#idx#" is not mistaken for item "id";
        //  - the suffix may not contain '#', so adjacent placeholders written without whitespace
        //    ("#u.a#,#u.b#") are matched one by one instead of being swallowed as a single match.
        Pattern itemPattern = Pattern.compile("#" + Pattern.quote(itemName) + "((?:[.\\[][^#\\s]*)?)#");
        // the list name is escaped because names such as "$1" (result of an earlier statement)
        // would otherwise be interpreted as a group reference in the replacement
        String listPrefix = "#" + Matcher.quoteReplacement(listName) + "[";

        return IntStream.range(0, list.size())
                .mapToObj(i -> itemPattern.matcher(sqlFragment).replaceAll(listPrefix + i + "]$1#"))
                .collect(joining(join, start, end));
    }

    /**
     * Reads the first attribute that is present among the given names.
     *
     * @param element the element to read from
     * @param names   attribute names in order of precedence (primary name first, aliases after)
     * @return the attribute value, or an empty string when none of the names is present
     */
    private static String getAttribute(Element element, String... names) {
        return Stream.of(names)
                .map(element::attribute)
                .filter(Objects::nonNull)
                // findFirst, not findAny: the order of the names defines which alias wins
                .findFirst()
                .map(Attribute::getValue)
                .orElse(EMPTY_STRING);
    }

    /**
     * Renders a {@code <choose>} element.
     *
     * <p>The first {@code <when>} whose test is {@code true} is rendered; if none matches, the
     * {@code <otherwise>} branch is rendered instead. Nodes that are neither {@code when} nor
     * {@code otherwise} (for example text between the branches) are always rendered. The parsed
     * document is not modified.
     *
     * @param element the choose element
     * @param context parameter context
     * @return the SQL of the selected branch together with the surrounding nodes
     * @throws BaseException if more than one branch would remain (several otherwise tags)
     */
    private static String handleChoose(Element element, Map<String, Object> context) {
        List<Node> content = element.content();

        // first when-branch whose condition holds; later branches are not evaluated
        Node matchedWhen = content.stream()
                .filter(node -> isTag(node, WHEN))
                .filter(node -> testWhen(node, context))
                .findFirst()
                .orElse(null);

        // with a match: keep it and drop every other when/otherwise;
        // without a match: drop all when-branches and keep otherwise
        List<Node> remaining = content.stream()
                .filter(node -> isNull(matchedWhen)
                        ? !isTag(node, WHEN)
                        : node == matchedWhen || !(isTag(node, WHEN) || isTag(node, OTHERWISE)))
                .collect(toList());

        // at most one branch may remain
        long branchCount = remaining.stream()
                .filter(node -> isTag(node, WHEN) || isTag(node, OTHERWISE))
                .count();

        if (branchCount > 1) {
            log.error("when tag or otherwise tag is more than one: {}", element);
            throw new BaseException(INTERNAL_ERROR_CODE, SQL_XML_ERROR);
        }

        return remaining.stream()
                // the matched branch was already tested, so render its body directly
                .map(node -> node == matchedWhen
                        ? handleRoot((Element) node, context)
                        : parseXml(node, context))
                .collect(joining());
    }

    /**
     * Tells whether a node is an element with the given tag name.
     *
     * @param node the node to inspect
     * @param name the expected tag name
     * @return {@code true} only for elements carrying that name
     */
    private static boolean isTag(Node node, String name) {
        return node instanceof Element && name.equals(node.getName());
    }

    /**
     * Evaluates the test of a {@code <when>} element.
     *
     * @param node    the when element
     * @param context parameter context
     * @return {@code true} only if the expression evaluates to {@code Boolean.TRUE}; a
     * non-boolean result counts as {@code false} instead of raising a ClassCastException
     * @throws BaseException if the element has no test attribute
     */
    private static boolean testWhen(Node node, Map<String, Object> context) {
        Object result = evaluateTest((Element) node, context);

        return TRUE.equals(result);
    }


    /**
     * Converts a statement with {@code #expression#} placeholders into executable SQL plus the
     * values to bind, e.g. {@code select * from tt where id = ?}.
     *
     * <p>The statement is scanned by hand instead of with a regular expression. A token preceded
     * by a backslash is treated as a literal token character (the backslash is removed). An
     * opening token without a closing one is copied to the output unchanged. The control flow
     * is intricate and order-dependent - change it with great care.
     *
     * @param sql     the statement containing placeholders; must not be empty
     * @param context values the placeholder expressions are resolved against
     * @return the executable SQL and its parameters in placeholder order
     */
    public static SqlAndParams parseSql(String sql, Map<String, Object> context) {
        List<Object> params = new ArrayList<>();
        Assert.notEmpty(sql, INTERNAL_ERROR_CODE, SQL_ERROR);
        // search open token
        int start = sql.indexOf(OPEN_TOKEN);
        if (start == -1) {
            // no placeholder at all: the statement is already executable
            return new SqlAndParams(sql);
        }
        char[] src = sql.toCharArray();
        int offset = 0;
        final StringBuilder builder = new StringBuilder();
        StringBuilder expression = null;
        do {
            if (start > 0 && src[start - 1] == '\\') {
                // escaped open token: drop the backslash and emit the token literally
                builder.append(src, offset, start - offset - 1).append(OPEN_TOKEN);
                offset = start + OPEN_TOKEN.length();
            } else {
                // found an open token: collect the expression up to the close token
                if (expression == null) {
                    expression = new StringBuilder();
                } else {
                    expression.setLength(0);
                }
                builder.append(src, offset, start - offset);
                offset = start + OPEN_TOKEN.length();
                int end = sql.indexOf(CLOSE_TOKEN, offset);
                while (end > -1) {
                    if (end > offset && src[end - 1] == '\\') {
                        // escaped close token belongs to the expression; keep looking
                        expression.append(src, offset, end - offset - 1).append(CLOSE_TOKEN);
                        offset = end + CLOSE_TOKEN.length();
                        end = sql.indexOf(CLOSE_TOKEN, offset);
                    } else {
                        expression.append(src, offset, end - offset);
                        break;
                    }
                }
                if (end == -1) {
                    // no close token: copy the rest of the statement as is
                    builder.append(src, start, src.length - start);
                    offset = src.length;
                } else {
                    // complete placeholder: emit the marker(s) and remember the value(s)
                    Collection<?> objects = handleExpression(expression, builder, context);
                    params.addAll(objects);
                    offset = end + CLOSE_TOKEN.length();
                }
            }
            start = sql.indexOf(OPEN_TOKEN, offset);
        } while (start > -1);
        if (offset < src.length) {
            // trailing text after the last placeholder
            builder.append(src, offset, src.length - offset);
        }

        return new SqlAndParams(builder.toString(), params);
    }

    /**
     * Resolves one placeholder expression and appends the matching parameter marker(s).
     *
     * <p>The expression is first looked up as a plain key in the context; only when that yields
     * {@code null} is it evaluated as an expression. A collection value is expanded to
     * {@code (?,?,...)} with one marker per entry (an empty collection gives {@code ()});
     * every other value - arrays included - produces a single {@code ?}.
     *
     * @param expression the text between the tokens
     * @param builder    the SQL being assembled; the marker(s) are appended to it
     * @param context    values the expression is resolved against
     * @return the values to bind, in marker order
     */
    private static Collection<?> handleExpression(StringBuilder expression, StringBuilder builder, Map<String, Object> context) {
        String key = expression.toString();
        Object value = context.get(key);
        if (isNull(value)) {
            value = OgnlUtils.parse(key, context);
        }
        // collection: one marker per entry
        if (value instanceof Collection) {
            Collection<?> collection = (Collection<?>) value;
            String replace = collection.stream().map(s -> "?")
                    .collect(Collectors.joining(",", "(", ")"));
            builder.append(replace);
            return collection;
        }
        // any other value: a single marker
        else {
            builder.append('?');
            return Collections.singletonList(value);
        }
    }

}

3.常量类(节选,解析类中用到的 WHEN、CHOOSE、OTHERWISE 等常量未在此列出):

    // Container tag names: SQL and DYNAMIC_SQL only group content, while ONE_SQL and MORE_SQL
    // additionally get SQL_DELIMITER appended by the parser.
    public static final String SQL = "sql";
    public static final String DYNAMIC_SQL = "DynamicSql";
    public static final String ONE_SQL = "OneSql";
    public static final String MORE_SQL = "MoreSql";
    public static final String EMPTY_STRING = "";
    // Text that already looks like one XML element: an opening tag, any content, a closing tag.
    // No DOTALL flag is set, so '.' does not cross line breaks.
    public static final Pattern XML_PATTERN = Pattern.compile("^<[^/]\\S+?>.*</\\S+>$");
    // Default root tag wrapped around templates that are not XML yet.
    public static final String DEFAULT_XML_TAG = "<sql>";
    public static final String DEFAULT_XML_TAG_END = "</sql>";
    public static final String LINE = "\n";
    public static final String SPACE = " ";
    public static final String SQL_DELIMITER = "; ";

    // Placeholder delimiters: a parameter is written as #expression#.
    public static final String OPEN_TOKEN = "#";
    public static final String CLOSE_TOKEN = "#";


    // "where" is both the tag name and the keyword emitted in front of the rendered conditions.
    public static final String WHERE = "where";
    public static final String WHERE_DEFAULT = "where 1=1";
    public static final String AND = "and";
    public static final String OR = "or";
    // Regex for a leading "and"/"or" connective. It is case-insensitive and anchored on a word
    // boundary, so "And x", "and\tx" and "and(x)" are stripped as well, while identifiers that
    // merely start with these letters ("order", "android", "and_flag") are left alone.
    // Trailing whitespace after the connective is removed together with it.
    public static final String AND_OR = "(?i)^(?:and|or)\\b\\s*";

    // Conditional tag and the attribute holding its expression.
    public static final String IF = "if";
    public static final String TEST = "test";

    // Loop tag and its attributes; each attribute has a primary name and a MyBatis-style alias
    // (list/collection, start/open, end/close, join/separator).
    public static final String FOR = "for";
    public static final String LIST = "list";
    public static final String COLLECTION = "collection";
    public static final String ITEM = "item";
    public static final String START = "start";
    public static final String OPEN = "open";
    public static final String END = "end";
    public static final String CLOSE = "close";
    public static final String JOIN = "join";
    public static final String SEPARATOR = "separator";

4. 使用教程

/**
 * Minimal usage demo: expands a {@code <for>} template and prints the generated statements.
 */
public class Test {

    /** Insert template whose value tuple is repeated once per entry of "list". */
    private static final String INSERT_TEMPLATE = "insert into tb_user (name, age, address)\n" +
            "values\n" +
            "<for collection=\"list\" item=\"user\" join=\",\">\n" +
            "    (#user.name#, #user.age#, #user.address#)\n" +
            "</for>";

    public static void main(String[] args) throws Exception {
        // only the size of the list matters for the expansion; the values are not read here
        HashMap<String, Object> parameters = new HashMap<>();
        parameters.put("list", Arrays.asList(1, 2, 3, 4));

        List<String> statements = DynamicSqlParser.generateSql(INSERT_TEMPLATE, parameters);
        System.out.println(statements);
    }
}

运行后解析出来的(为便于阅读,下面的结果经过了手工换行整理;程序实际打印的是只包含一条sql的List,并且for片段中的换行会被替换成空格):
insert into tb_user (name, age, address)
values
(#list[0].name#, #list[0].age#, #list[0].address#) ,
(#list[1].name#, #list[1].age#, #list[1].address#) ,
(#list[2].name#, #list[2].age#, #list[2].address#) ,
(#list[3].name#, #list[3].age#, #list[3].address#)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容