模仿mybatis写一个sql解析工具,支持for if where标签

1.主要功能介绍:

sql语法基本和mybatis一致(兼容mybatis语法)

目前已经支持if where choose和for等

支持list参数: in #list#会转换成 in (?, ?, ?, ?)

关键:支持多条sql,可以使用$i作为第i条sql的返回结果用于其他sql中

where和if示例:

select * from tb_test
<where>
    <if test="aaa != null">
    and aaa = #aaa#
    </if>
    <if test="bbb == 'Test'">
    and bbb = #bbb#
    </if>
</where>

for语句示例:

insert into tb_user (name, age, address)
values 
<for list="list" item="user" join=",">
    (#user.name#, #user.age#, #user.address#)
</for>;

select * from tb_test
where id in
<for list="list" item="id" join="," start="(" end=")">
    #id#
</for>;

多条sql使用示例:

select name from tb_test
where id in
<for list="list" item="id" join="," start="(" end=")">
    #id#
</for>;

delete from tb_user
where name in #$1#;

##这里的$1是第一条sql的返回值

2.解析sql工具类:

package com.fly.tool.api.util;

import com.fly.tool.api.common.BaseException;
import com.fly.tool.api.db.SqlAndParams;
import org.dom4j.*;
import org.dom4j.tree.DefaultCDATA;
import org.dom4j.tree.DefaultText;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.fly.tool.api.common.Constant.INTERNAL_ERROR_CODE;
import static com.fly.tool.api.common.ErrorMessage.SQL_ERROR;
import static com.fly.tool.api.common.ErrorMessage.SQL_XML_ERROR;
import static com.fly.tool.api.common.SqlConstant.*;
import static java.lang.Boolean.FALSE;
import static java.lang.Boolean.TRUE;
import static java.util.Objects.isNull;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;

/**
 *
 * @author guoxiang
 * @version 1.0.0
 * @since 2021/6/22
 */
public class DynamicSqlParser {

    private static final Logger log = LoggerFactory.getLogger(com.fly.tool.api.util.DynamicSqlParser.class);

    private DynamicSqlParser() {
        //这是一个工具类
    }

    /**
     * 将数据库中的sql解析为可执行的sql语句
     *
     * @param sql    数据库中的sql
     * @param params 参数
     * @return sql list
     */
    public static List<String> generateSql(String sql, Map<String, Object> params) {
        //如果不是xml格式,先封装成xml,用于适配以往版本
        if (!XML_PATTERN.matcher(sql).matches()) {
            sql = DEFAULT_XML_TAG + sql + DEFAULT_XML_TAG_END;
        }

        Document document;
        try {
            document = DocumentHelper.parseText(sql);
        } catch (DocumentException e) {
            throw new BaseException(e, INTERNAL_ERROR_CODE, SQL_XML_ERROR);
        }

        String sqlString = parseXml(document.getRootElement(), params);
        if (ObjectUtils.isEmpty(sqlString)) {
            return Collections.emptyList();
        }

        //解析生成的sql可能包含多个sql
        String[] sqlArray = sqlString.split(";");
        return Stream.of(sqlArray).filter(ObjectUtils::notBlank).collect(toList());
    }


    /**
     * 处理节点
     *
     * @param node    node
     * @param context 参数上下文
     * @return sql
     */
    private static String parseXml(Node node, Map<String, Object> context) {
        //如果是普通sql || 如果是CDATA
        if (node instanceof DefaultText || node instanceof DefaultCDATA) {
            return node.getText();
        }

        //如果不认识,则返回空串
        if (!(node instanceof Element)) {
            return EMPTY_STRING;
        }

        Element element = (Element) node;
        String name = element.getName();

        //处理各种标签
        switch (name) {
            case WHERE:
                return handleWhere(element, context);
            case IF:
            case WHEN:
                return handleIf(element, context);
            case FOR:
                return handleFor(element, context);
            case SQL:
            case DYNAMIC_SQL:
            case OTHERWISE:
                return handleRoot(element, context);
            case CHOOSE:
                return handleChoose(element, context);
            case ONE_SQL:
            case MORE_SQL:
                return handleRoot(element, context) + SQL_DELIMITER;
            default:
                return EMPTY_STRING;
        }
    }

    /**
     * 处理根节点
     *
     * @param element 元素
     * @param context 上下文
     * @return sql
     */
    private static String handleRoot(Element element, Map<String, Object> context) {
        return element.content()
                .stream()
                .map(n -> parseXml(n, context))
                .collect(joining());
    }


    /**
     * 处理<if>标签
     *
     * @param element 元素
     * @param context 参数
     * @return sql
     */
    private static String handleIf(Element element, Map<String, Object> context) {
        Attribute test = element.attribute(TEST);
        Object value = OgnlUtils.parse(test.getValue(), context);

        if (FALSE.equals(value)) {
            return EMPTY_STRING;
        }

        return handleRoot(element, context);
    }

    /**
     * 处理where标签
     *
     * @param element 元素
     * @param context 参数
     * @return sql
     */
    private static String handleWhere(Element element, Map<String, Object> context) {

        String sql = handleRoot(element, context);

        if (ObjectUtils.isEmpty(sql)) {
            return EMPTY_STRING;
        }

        sql = sql.replaceAll(LINE, SPACE).trim();

        if (ObjectUtils.isEmpty(sql)) {
            return EMPTY_STRING;
        }

        //替换掉开头的and或者or
        return WHERE + SPACE + sql.replaceFirst(AND_OR, EMPTY_STRING);
    }

    /**
     * 处理for标签 语法兼容mybatis
     * 原理:将for标签解析成多个sql片段,将其中的#xxx#标签转换为对应的#list[i].xxx#
     * 在执行sql时,使用ognl表达式取值
     *
     * @param element 元素
     * @param context 上下文
     * @return sql
     */
    private static String handleFor(Element element, Map<String, Object> context) {
        String listName = getAttribute(element, LIST, COLLECTION);
        String itemName = getAttribute(element, ITEM);
        String start = getAttribute(element, START, OPEN);
        String end = getAttribute(element, END, CLOSE);
        String join = getAttribute(element, JOIN, SEPARATOR);

        Assert.notEmpty(listName, INTERNAL_ERROR_CODE, SQL_ERROR);

        itemName = ObjectUtils.isEmpty(itemName) ? ITEM : itemName;
        String sqlFragment = element.getText().replace(LINE, SPACE);

        Collection<Object> list = OgnlUtils.parse(listName, context);

        if (ObjectUtils.isEmpty(list)) {
            log.warn("- list: {} is empty!", listName);
            return start + end;
        }

        //生成sql循环
        String regex = "#" + itemName + "(\\.?\\S*)#";
        return IntStream.range(0, list.size())
                .boxed()
                .map(i -> sqlFragment.replaceAll(regex, "#" + listName + "[" + i + "]$1#"))
                .collect(joining(join, start, end));
    }

    private static String getAttribute(Element element, String... names) {
        return Stream.of(names)
                .map(element::attribute)
                .filter(Objects::nonNull)
                .findAny()
                .map(Attribute::getValue)
                .orElse(EMPTY_STRING);
    }

    /**
     * 处理choose标签
     *
     * @param element   element
     * @param context   context
     * @return          sql
     */
    private static String handleChoose(Element element, Map<String, Object> context) {
        List<Node> content = element.content();

        Optional<Node> whenNode = content.stream()
                .filter(node -> WHEN.equals(node.getName()))
                .filter(node -> testWhen(node, context))
                .findFirst();

        //如果存在满足条件地when,则删除其他的when和otherwise
        if (whenNode.isPresent()) {
            Node when = whenNode.get();
            content.removeIf(node -> WHEN.equals(node.getName()) && !node.equals(when));
            content.removeIf(node -> OTHERWISE.equals(node.getName()));
        }
        //否则删除所有的when,保留otherwise
        else {
            content.removeIf(node -> WHEN.equals(node.getName()));
        }

        //校验最多只有一个when或者other
        long count = content.stream()
                .filter(node -> WHEN.equals(node.getName()) || OTHERWISE.equals(node.getName()))
                .count();

        if (count > 1) {
            log.error("when tag or otherwise tag is more than one: {}", element);
            throw new BaseException(INTERNAL_ERROR_CODE, SQL_XML_ERROR);
        }

        return content.stream()
                .map(n -> parseXml(n, context))
                .collect(joining());
    }

    private static boolean testWhen(Node node, Map<String, Object> context) {
        Element element = (Element) node;
        Attribute test = element.attribute(TEST);
        Boolean result = OgnlUtils.parse(test.getValue(), context);

        return TRUE.equals(result);
    }


    /**
     * 解析sql语句为可执行的sql和相应的参数
     * 这部分为了避免使用正则表达式因而手动解析
     * 逻辑极其复杂普通人别tm乱动
     *
     * @param sql sql
     * @return 可执行sql: select * from tt where id = ?
     */
    public static SqlAndParams parseSql(String sql, Map<String, Object> context) {
        List<Object> params = new ArrayList<>();
        Assert.notEmpty(sql, INTERNAL_ERROR_CODE, SQL_ERROR);
        // search open token
        int start = sql.indexOf(OPEN_TOKEN);
        if (start == -1) {
            return new SqlAndParams(sql);
        }
        char[] src = sql.toCharArray();
        int offset = 0;
        final StringBuilder builder = new StringBuilder();
        StringBuilder expression = null;
        do {
            if (start > 0 && src[start - 1] == '\\') {
                builder.append(src, offset, start - offset - 1).append(OPEN_TOKEN);
                offset = start + OPEN_TOKEN.length();
            } else {
                if (expression == null) {
                    expression = new StringBuilder();
                } else {
                    expression.setLength(0);
                }
                builder.append(src, offset, start - offset);
                offset = start + OPEN_TOKEN.length();
                int end = sql.indexOf(CLOSE_TOKEN, offset);
                while (end > -1) {
                    if (end > offset && src[end - 1] == '\\') {
                        expression.append(src, offset, end - offset - 1).append(CLOSE_TOKEN);
                        offset = end + CLOSE_TOKEN.length();
                        end = sql.indexOf(CLOSE_TOKEN, offset);
                    } else {
                        expression.append(src, offset, end - offset);
                        break;
                    }
                }
                if (end == -1) {
                    builder.append(src, start, src.length - start);
                    offset = src.length;
                } else {
                    Collection<?> objects = handleExpression(expression, builder, context);
                    params.addAll(objects);
                    offset = end + CLOSE_TOKEN.length();
                }
            }
            start = sql.indexOf(OPEN_TOKEN, offset);
        } while (start > -1);
        if (offset < src.length) {
            builder.append(src, offset, src.length - offset);
        }

        return new SqlAndParams(builder.toString(), params);
    }

    private static Collection<?> handleExpression(StringBuilder expression, StringBuilder builder, Map<String, Object> context) {
        String key = expression.toString();
        Object value = context.get(key);
        if (isNull(value)) {
            value = OgnlUtils.parse(key, context);
        }
        //处理数组和集合
        if (value instanceof Collection) {
            Collection<?> collection = (Collection<?>) value;
            String replace = collection.stream().map(s -> "?")
                    .collect(Collectors.joining(",", "(", ")"));
            builder.append(replace);
            return collection;
        }
        //普通数据类型
        else {
            builder.append('?');
            return Collections.singletonList(value);
        }
    }

}

3.常量类:

    public static final String SQL = "sql";
    public static final String DYNAMIC_SQL = "DynamicSql";
    public static final String ONE_SQL = "OneSql";
    public static final String MORE_SQL = "MoreSql";
    public static final String EMPTY_STRING = "";
    public static final Pattern XML_PATTERN = Pattern.compile("^<[^/]\\S+?>.*</\\S+>$");
    public static final String DEFAULT_XML_TAG = "<sql>";
    public static final String DEFAULT_XML_TAG_END = "</sql>";
    public static final String LINE = "\n";
    public static final String SPACE = " ";
    public static final String SQL_DELIMITER = "; ";

    public static final String OPEN_TOKEN = "#";
    public static final String CLOSE_TOKEN = "#";


    public static final String WHERE = "where";
    public static final String WHERE_DEFAULT = "where 1=1";
    public static final String AND = "and";
    public static final String OR = "or";
    public static final String AND_OR = "^(and |AND |or |OR )";

    public static final String IF = "if";
    public static final String TEST = "test";

    public static final String FOR = "for";
    public static final String LIST = "list";
    public static final String COLLECTION = "collection";
    public static final String ITEM = "item";
    public static final String START = "start";
    public static final String OPEN = "open";
    public static final String END = "end";
    public static final String CLOSE = "close";
    public static final String JOIN = "join";
    public static final String SEPARATOR = "separator";

4. 使用教程

public class Test {
    public static void main(String[] args) throws Exception {

        String xml = "insert into tb_user (name, age, address)\n" +
                "values\n" +
                "<for collection=\"list\" item=\"user\" join=\",\">\n" +
                "    (#user.name#, #user.age#, #user.address#)\n" +
                "</for>";

        HashMap<String, Object> map = new HashMap<>();
        map.put("list", Arrays.asList(1, 2, 3, 4));

        List<String> sql = DynamicSqlParser.generateSql(xml, map);
        System.out.println(sql);
    }
}

运行后解析出来的:
insert into tb_user (name, age, address)
values
(#list[0].name#, #list[0].age#, #list[0].address#) ,
(#list[1].name#, #list[1].age#, #list[1].address#) ,
(#list[2].name#, #list[2].age#, #list[2].address#) ,
(#list[3].name#, #list[3].age#, #list[3].address#)

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,417评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,921评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,850评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,945评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,069评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,188评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,239评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,994评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,409评论 1 304
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,735评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,898评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,578评论 4 336
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,205评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,916评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,156评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,722评论 2 363
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,781评论 2 351

推荐阅读更多精彩内容