Skip to content

Latest commit

 

History

History
598 lines (510 loc) · 22.1 KB

README.md

File metadata and controls

598 lines (510 loc) · 22.1 KB

基础

JDBC

JVM

Servlet

多线程与并发

新特性

实践

网络编程

集合

package com.sky.common.utils;

import com.sky.common.annotation.IndexExpression;
import com.sky.common.exception.CustomException;
import org.apache.commons.collections.CollectionUtils;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * SQL parser for "professional search" expressions.
 *
 * <p>Field codes documented for this parser (the authoritative set is whatever the model's
 * static {@code getExpressionCodeMap()} returns): CO = maintenance unit, ST = substation,
 * GIS = whether the station is a GIS station, TP = equipment type, VL = voltage level,
 * UT = bay unit, NM = equipment name, MF = manufacturer, MD = model, OD = commissioning date,
 * PD = manufacture date, OY = commissioning year, OA = years in service,
 * EA = equipment addition method.
 *
 * <p>A condition is {@code CODE op VALUE} with {@code op} one of {@code =} (equals),
 * {@code %} (contains), {@code >}, {@code >=}, {@code <}, {@code <=}. Top-level conditions are
 * joined by the keyword {@code and} surrounded by whitespace. A parenthesised group may chain
 * further values or conditions with:
 * <ul>
 *   <li>{@code *} &rarr; and</li>
 *   <li>{@code +} &rarr; or</li>
 *   <li>{@code -} &rarr; and not</li>
 * </ul>
 *
 * <p>Examples:
 * <ol>
 *   <li>{@code CO=平湖 and TP=主变压器 and (MF%TOSHIBA + 东芝)} finds all equipment whose
 *       maintenance unit is "平湖", whose type is "主变压器" (main transformer), and whose
 *       manufacturer contains "TOSHIBA" or "东芝".</li>
 *   <li>{@code ST=瓦山变 and VL>=35} finds all equipment at substation "瓦山变" rated 35 kV or
 *       above.</li>
 *   <li>{@code MD%P00HXG - (MF=ABB) and OA>=5} finds equipment whose model contains "P00HXG",
 *       not manufactured by ABB, with at least 5 years in service.</li>
 * </ol>
 *
 * <p>Values are always bound through placeholders ({@code #{}} for MyBatis, {@code ?} for JDBC)
 * and never concatenated into the SQL, which prevents SQL injection.
 *
 * @author xuguozong
 */
public final class ExpressionSqlParser {

    /** Separates top-level conditions: the keyword "and" with whitespace on both sides. */
    private static final Pattern AND_SEPARATOR = Pattern.compile("\\s+and\\s+");

    /** Separates tokens inside a parenthesised group (any run of whitespace). */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    public ExpressionSqlParser() {}

    /**
     * Parses {@code source} into SQL that uses MyBatis {@code #{name}} placeholders.
     *
     * @param model  class declaring a static, no-argument {@code getExpressionCodeMap()} that
     *               returns a {@code Map<String, IndexExpression>} keyed by field code
     * @param source the search expression; must not be {@code null}
     * @return the parse result, or {@code null} if the code map is empty or no condition was found
     * @throws NoSuchMethodException     if {@code model} does not declare {@code getExpressionCodeMap()}
     * @throws InvocationTargetException if {@code getExpressionCodeMap()} itself throws
     * @throws IllegalAccessException    if {@code getExpressionCodeMap()} is not accessible
     */
    public ParseResult parse(Class<?> model, String source) throws NoSuchMethodException,
            InvocationTargetException, IllegalAccessException {
        return parse(model, source, ExpressionNode::parse);
    }

    /**
     * Parses {@code source} into SQL that uses JDBC {@code ?} placeholders. The bind values are
     * available, in placeholder order, from {@link ParseResult#getValues()}.
     *
     * @param model  class declaring a static, no-argument {@code getExpressionCodeMap()}
     * @param source the search expression; must not be {@code null}
     * @return the parse result, or {@code null} if the code map is empty or no condition was found
     * @throws NoSuchMethodException     if {@code model} does not declare {@code getExpressionCodeMap()}
     * @throws InvocationTargetException if {@code getExpressionCodeMap()} itself throws
     * @throws IllegalAccessException    if {@code getExpressionCodeMap()} is not accessible
     */
    public ParseResult parseJdbc(Class<?> model, String source) throws NoSuchMethodException,
            InvocationTargetException, IllegalAccessException {
        return parse(model, source, ExpressionNode::parseJdbc);
    }

    /**
     * Shared pipeline: loads the code map, splits {@code source} into top-level conditions,
     * renders each with {@code renderer}, joins them with {@code and}, then applies the
     * special-SQL replacements and numeric-field conversion.
     */
    private ParseResult parse(Class<?> model, String source, Function<ExpressionNode, ParseResult> renderer)
            throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        Objects.requireNonNull(StringUtils.trim(source), "检索表达式不能为空");
        // getDeclaredMethod throws NoSuchMethodException instead of returning null.
        Method method = model.getDeclaredMethod("getExpressionCodeMap");
        // Invoked without a receiver, so the method must be static.
        @SuppressWarnings("unchecked")
        Map<String, IndexExpression> map = (Map<String, IndexExpression>) method.invoke(null);
        if (map == null || map.isEmpty()) {
            return null;
        }

        NodeFactory factory = new NodeFactory(map);
        List<ExpressionNode> nodes = Arrays.stream(AND_SEPARATOR.split(source))
                .map(factory::create)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        if (nodes.isEmpty()) {
            return null;
        }

        ParseResult parseResult = renderer.apply(nodes.get(0));
        for (ExpressionNode node : nodes.subList(1, nodes.size())) {
            parseResult.merge(renderer.apply(node));
        }
        replaceWithSpecialSql(map, parseResult);
        checkFieldTypeNumber(map, parseResult);
        return parseResult;
    }

    /**
     * Converts the bound values of every NUMBER-typed field to {@link Double}. This covers the
     * base key and any renamed {@code dbField_<uuid>} keys. Only the kv map is converted; the
     * positional JDBC values list keeps the raw strings.
     *
     * @throws CustomException if a value of a NUMBER field is not numeric
     */
    private void checkFieldTypeNumber(Map<String, IndexExpression> map, ParseResult parseResult) {
        Map<String, Object> kvMap = parseResult.getKvMap();
        for (IndexExpression expression : map.values()) {
            if (!expression.type().equals(IndexExpression.Type.NUMBER)) {
                continue;
            }
            for (Map.Entry<String, Object> entry : kvMap.entrySet()) {
                if (ParseResult.isPlaceholderFor(entry.getKey(), expression.dbField())) {
                    entry.setValue(toDouble(expression, entry.getValue()));
                }
            }
        }
    }

    /** Parses {@code value} as a double, reporting the field code on failure. */
    private static Double toDouble(IndexExpression expression, Object value) {
        try {
            return Double.valueOf(String.valueOf(value));
        } catch (NumberFormatException e) {
            // NOTE(review): chain e as the cause if CustomException offers a (String, Throwable) constructor.
            throw new CustomException(expression.code() + "只支持数值类型的参数");
        }
    }

    /**
     * For each field that declares a non-empty {@code replaceSql()} and is referenced by the
     * result, replaces each column reference {@code "dbField "} with that SQL. A match does not
     * count when the text is the tail of a longer identifier (e.g. {@code name} inside
     * {@code type_name}).
     */
    private void replaceWithSpecialSql(Map<String, IndexExpression> map, ParseResult parseResult) {
        Map<String, Object> kvMap = parseResult.getKvMap();
        for (IndexExpression expression : map.values()) {
            String replaceSql = expression.replaceSql();
            String dbField = expression.dbField();
            if (StringUtils.isNotEmpty(replaceSql) && kvMap.containsKey(dbField)) {
                Pattern column = Pattern.compile("(?<![\\w.])" + Pattern.quote(dbField) + " ");
                String replaced = column.matcher(parseResult.getRowSql())
                        .replaceAll(Matcher.quoteReplacement(replaceSql));
                parseResult.setRowSql(replaced);
            }
        }
    }

    /**
     * Parse result of a search expression: a SQL fragment plus its bind values.
     */
    public static class ParseResult {
        /** Length of the hex suffix appended when a placeholder key must be renamed. */
        private static final int RENAME_SUFFIX_LENGTH = 32;

        /**
         * Placeholder name to bound value, used for MyBatis {@code #{}} binding. Keys are dbFields,
         * or {@code dbField_<32 hex chars>} when one field is bound to several different values.
         */
        Map<String, Object> kvMap = new LinkedHashMap<>();

        /** SQL fragment containing {@code #{}} or {@code ?} placeholders. */
        private String rowSql;

        /** Values for the {@code ?} placeholders, in the order their clauses were generated. */
        private final List<Object> values = new ArrayList<>();

        public ParseResult(String rowSql) {
            this.rowSql = rowSql;
        }

        /** Appends one positional ({@code ?}) value. */
        public ParseResult addValues(Object value) {
            values.add(value);
            return this;
        }

        public List<Object> getValues() {
            return values;
        }

        /**
         * Appends another result to this one.
         *
         * <p>If {@code other}'s SQL already starts with an {@code and}/{@code or} connector, it is
         * appended as-is; otherwise it is joined with {@code and}. When {@code other} binds a key
         * that this result already binds to a different value, every {@code #{key}} in
         * {@code other}'s SQL is renamed to a fresh {@code key_<uuid>}. Each placeholder thus
         * keeps its own value. All of {@code other}'s positional values are appended in order.
         *
         * @param other result to append; must not be {@code null}
         * @return this result
         */
        public ParseResult merge(ParseResult other) {
            Objects.requireNonNull(other);
            String otherSql = other.getRowSql();
            for (Map.Entry<String, Object> entry : other.getKvMap().entrySet()) {
                String key = entry.getKey();
                Object value = entry.getValue();
                if (kvMap.containsKey(key) && !Objects.equals(value, kvMap.get(key))) {
                    String newKey = renamedKey(key);
                    otherSql = otherSql.replace("#{" + key + "}", "#{" + newKey + "}");
                    kvMap.put(newKey, value);
                } else {
                    kvMap.putIfAbsent(key, value);
                }
            }

            // Require a trailing space so a column such as "org_name" is not mistaken for "or".
            String head = otherSql.trim();
            if (head.startsWith("and ") || head.startsWith("or ")) {
                rowSql = rowSql.endsWith(" ") ? rowSql + otherSql : rowSql + " " + otherSql;
            } else {
                rowSql = rowSql + " and " + otherSql;
            }
            values.addAll(other.getValues());
            return this;
        }

        /** Binds {@code value} to {@code dbField} unless that key is already bound. */
        public ParseResult add(String dbField, Object value) {
            kvMap.putIfAbsent(dbField, value);
            return this;
        }

        public Map<String, Object> getKvMap() {
            return kvMap;
        }

        public String getRowSql() {
            return rowSql;
        }

        public ParseResult setRowSql(String rowSql) {
            this.rowSql = rowSql;
            return this;
        }

        /** Builds a fresh placeholder key {@code key_<32 lowercase hex chars>}. */
        private static String renamedKey(String key) {
            return key + "_" + UUID.randomUUID().toString().replace("-", "");
        }

        /** True if {@code key} is {@code dbField} itself or a key produced by {@link #renamedKey}. */
        private static boolean isPlaceholderFor(String key, String dbField) {
            if (key.equals(dbField)) {
                return true;
            }
            int suffixStart = dbField.length() + 1;
            if (key.length() != suffixStart + RENAME_SUFFIX_LENGTH || !key.startsWith(dbField + "_")) {
                return false;
            }
            for (int pos = suffixStart; pos < key.length(); pos++) {
                char ch = key.charAt(pos);
                boolean hex = (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f');
                if (!hex) {
                    return false;
                }
            }
            return true;
        }
    }

    /** Builds expression nodes from the textual top-level conditions. */
    static class NodeFactory {

        final Map<String, IndexExpression> expressionMap;

        /** Two-character operators come first so that ">=" / "<=" are not matched as "=". */
        final List<String> operators = Arrays.asList("<=", ">=", "=", "%", "<", ">");

        public NodeFactory(Map<String, IndexExpression> expressionMap) {
            this.expressionMap = expressionMap;
        }

        /**
         * Creates the node for one top-level condition.
         *
         * <p>Without parentheses the whole text is one condition. With parentheses the text is
         * tokenised on whitespace as {@code condition (op operand)*}. An operand containing
         * {@code (} is a nested condition; otherwise it is another value for the leading
         * condition's field.
         *
         * @return the node, or {@code null} when the text contains no operator (e.g. blank)
         * @throws CustomException on unbalanced parentheses or a malformed group
         */
        ExpressionNode create(String part) {
            String original = StringUtils.trim(part);
            if (!original.contains("(")) {
                return node(original);
            }
            if (!original.contains(")")) {
                throw new CustomException("括号应当成对出现");
            }
            String expression = original.startsWith("(")
                    ? original.replace("(", "").replace(")", "")
                    : original;

            String[] tokens = WHITESPACE.split(expression.trim());
            if (tokens.length == 1) {
                return node(tokens[0]);
            }
            // Expect "condition op operand [op operand ...]"; an even count leaves an operator dangling.
            if (tokens.length % 2 == 0) {
                throw new CustomException("unsupported expression: " + original);
            }

            ExpressionNode root = requireNode(tokens[0]);
            for (int pos = 1; pos < tokens.length; pos += 2) {
                String operator = tokens[pos];
                String operand = tokens[pos + 1];
                if (operand.contains("(")) {
                    String nested = operand.replace("(", "").replace(")", "");
                    root.addNodeWithOp(new NodeWithOperator(requireNode(nested), operator));
                } else {
                    root.addValueWithOp(new ValueWithOperator(root.dbField, operand, root, operator));
                }
            }
            return root;
        }

        /** Like {@link #node} but rejects text that contains no operator. */
        private ExpressionNode requireNode(String expression) {
            ExpressionNode parsed = node(expression);
            if (parsed == null) {
                throw new CustomException("unsupported expression: " + expression);
            }
            return parsed;
        }

        /**
         * Parses a single {@code CODE op VALUE} condition.
         *
         * @return the node, or {@code null} if {@code expression} contains no known operator
         * @throws CustomException if the code is unknown or the operator splits the text into
         *                         anything other than exactly two parts
         */
        private ExpressionNode node(String expression) {
            for (String operator : operators) {
                if (!expression.contains(operator)) {
                    continue;
                }
                String[] split = expression.split(Pattern.quote(operator));
                if (split.length != 2) {
                    throw new CustomException("unsupported expression: " + expression);
                }
                String code = split[0];
                String value = split[1];
                IndexExpression indexExpression = expressionMap.get(code);
                if (indexExpression == null) {
                    throw new CustomException("unsupported code: " + code);
                }
                String dbField = indexExpression.dbField();
                switch (operator) {
                    case "=":
                        return new EqualsExpressionNode(code, operator, value, dbField);
                    case "%":
                        return new ContainsExpressionNode(code, operator, value, dbField);
                    case ">":
                    case ">=":
                    case "<":
                    case "<=":
                        return new RangeExpressionNode(code, operator, value, dbField);
                    default:
                        throw new CustomException("unsupported operator: " + operator);
                }
            }
            return null;
        }
    }

    /** A trailing "op operand" part of a parenthesised group that renders to its own SQL piece. */
    interface SubExpression {
        /** Renders with MyBatis {@code #{}} placeholders. */
        ParseResult parse();

        /** Renders with JDBC {@code ?} placeholders. */
        ParseResult parseJdbc();
    }

    /** One {@code CODE op VALUE} condition, optionally followed by chained sub-expressions. */
    abstract static class ExpressionNode {
        /** Search field code, e.g. "MF". */
        protected String code;
        /** Condition operator, e.g. "=" or "%". */
        protected String operator;
        /** Operand as written in the expression. */
        protected String value;
        /** Database column the code maps to. */
        protected String dbField;

        /** Chained parts such as {@code - (MF=ABB)} or {@code + 东芝}, in source order. */
        final List<SubExpression> subExpressions = new ArrayList<>();

        public ExpressionNode(String code, String operator, String value,
                              String dbField) {
            this.code = code;
            this.operator = operator;
            this.value = value;
            this.dbField = dbField;
        }

        /** Chains a nested condition, e.g. the {@code - (MF=ABB)} in {@code MD%P00HXG - (MF=ABB)}. */
        public void addNodeWithOp(NodeWithOperator nodeWithOperator) {
            subExpressions.add(nodeWithOperator);
        }

        /** Chains another value for this field, e.g. the {@code + 东芝} in {@code (MF%TOSHIBA + 东芝)}. */
        public void addValueWithOp(ValueWithOperator valueWithOperator) {
            subExpressions.add(valueWithOperator);
        }

        /** Renders this condition and its chain with MyBatis {@code #{}} placeholders. */
        public ParseResult parse() {
            return assemble(parseSql(), SubExpression::parse);
        }

        /** Renders this condition and its chain with JDBC {@code ?} placeholders. */
        public ParseResult parseJdbc() {
            return assemble(parseJdbcSql(), SubExpression::parseJdbc);
        }

        /**
         * Builds the result for this node and merges in each chained part. A compound result is
         * wrapped in parentheses so that an inner {@code or} cannot bind past the {@code and}
         * joining it to sibling conditions.
         */
        private ParseResult assemble(String rowSql, Function<SubExpression, ParseResult> subRenderer) {
            ParseResult result = new ParseResult(rowSql)
                    .add(dbField, value)
                    .addValues(value);
            if (subExpressions.isEmpty()) {
                return result;
            }
            for (SubExpression sub : subExpressions) {
                result.merge(subRenderer.apply(sub));
            }
            return result.setRowSql("(" + result.getRowSql() + ")");
        }

        /** Returns this condition with a {@code #{dbField}} placeholder. */
        public abstract String parseSql();

        /** Returns this condition with a {@code ?} placeholder. */
        public abstract String parseJdbcSql();

        /** Returns the operator-and-placeholder part for an extra value of this field. */
        public abstract String parseValueNode(ValueWithOperator valueWithOperator);
    }

    /** {@code CODE=VALUE}: exact match. */
    static class EqualsExpressionNode extends ExpressionNode {

        public EqualsExpressionNode(String code, String operator, String value, String dbField) {
            super(code, operator, value, dbField);
        }

        @Override
        public String parseSql() {
            return dbField + " = #{" + dbField + "} ";
        }

        @Override
        public String parseJdbcSql() {
            return dbField + " = ? ";
        }

        @Override
        public String parseValueNode(ValueWithOperator valueWithOperator) {
            return " = #{" + valueWithOperator.dbField + "} ";
        }
    }

    /** {@code CODE%VALUE}: substring match via {@code like concat('%', v, '%')}. */
    static class ContainsExpressionNode extends ExpressionNode {

        public ContainsExpressionNode(String code, String operator, String value, String dbField) {
            super(code, operator, value, dbField);
        }

        @Override
        public String parseSql() {
            return dbField + " like concat('%', #{" + dbField + "}, '%') ";
        }

        @Override
        public String parseJdbcSql() {
            return dbField + " like concat('%', ?, '%') ";
        }

        @Override
        public String parseValueNode(ValueWithOperator valueWithOperator) {
            return " like CONCAT('%',#{" + valueWithOperator.dbField + "},'%')";
        }
    }

    /** {@code CODE>VALUE}, {@code >=}, {@code <}, {@code <=}: range comparison. */
    static class RangeExpressionNode extends ExpressionNode {

        public RangeExpressionNode(String code, String operator, String value, String dbField) {
            super(code, operator, value, dbField);
        }

        /** The operator padded to the fixed width used in the generated SQL, e.g. " >  " or " >= ". */
        private String paddedOperator() {
            switch (operator) {
                case ">":
                case "<":
                    return " " + operator + "  ";
                case ">=":
                case "<=":
                    return " " + operator + " ";
                default:
                    throw new CustomException("unsupported operator: " + operator);
            }
        }

        @Override
        public String parseSql() {
            return dbField + paddedOperator() + "#{" + dbField + "} ";
        }

        @Override
        public String parseJdbcSql() {
            return dbField + paddedOperator() + "? ";
        }

        @Override
        public String parseValueNode(ValueWithOperator valueWithOperator) {
            return paddedOperator() + "#{" + valueWithOperator.dbField + "} ";
        }
    }

    /** A chained nested condition, e.g. the {@code - (MF=ABB)} in {@code MD%P00HXG - (MF=ABB)}. */
    static class NodeWithOperator implements SubExpression {
        private final ExpressionNode node;
        private final String operator;

        public NodeWithOperator(ExpressionNode node, String operator) {
            this.node = node;
            this.operator = operator;
        }

        @Override
        public ParseResult parse() {
            return connect(node.parse());
        }

        @Override
        public ParseResult parseJdbc() {
            return connect(node.parseJdbc());
        }

        /** Prefixes the nested condition's SQL with the connector for {@code operator}. */
        private ParseResult connect(ParseResult nested) {
            String sql = nested.getRowSql();
            switch (operator) {
                case "*":
                    return nested.setRowSql("and " + sql);
                case "+":
                    return nested.setRowSql("or " + sql);
                case "-":
                    // NOT (...) negates every operator, including like / < / >.
                    return nested.setRowSql("and NOT (" + sql + ")");
                default:
                    throw new CustomException("unsupported operator: " + operator);
            }
        }
    }

    /** A chained extra value for the parent's field, e.g. the {@code + 东芝} in {@code (MF%TOSHIBA + 东芝)}. */
    static class ValueWithOperator implements SubExpression {
        private final String dbField;
        private final String value;
        private final ExpressionNode fatherNode;
        private final String operator;

        public ValueWithOperator(String dbField, String value, ExpressionNode fatherNode, String operator) {
            this.dbField = dbField;
            this.value = value;
            this.fatherNode = fatherNode;
            this.operator = operator;
        }

        /** E.g. {@code MF%TOSHIBA + 东芝} contributes {@code or manu_facturer like CONCAT('%',#{..},'%')}. */
        @Override
        public ParseResult parse() {
            return connect(dbField + fatherNode.parseValueNode(this));
        }

        /** Reuses the parent's JDBC predicate, so the extra value is compared with the same operator. */
        @Override
        public ParseResult parseJdbc() {
            return connect(fatherNode.parseJdbcSql());
        }

        /** Prefixes {@code clause} with the connector for {@code operator} and binds this value. */
        private ParseResult connect(String clause) {
            String sql;
            switch (operator) {
                case "*":
                    sql = "and " + clause;
                    break;
                case "+":
                    sql = "or " + clause;
                    break;
                case "-":
                    sql = "and NOT (" + clause + ")";
                    break;
                default:
                    throw new CustomException("unsupported operator: " + operator);
            }
            return new ParseResult(sql).add(dbField, value).addValues(value);
        }
    }

}