java

关注公众号 jb51net

关闭
首页 > 软件编程 > java > Mybatis自定义Sql模板语法

Mybatis自定义Sql模板语法问题

作者:horgn黄小锤、

这篇文章主要介绍了Mybatis自定义Sql模板语法问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

 一、说明

mybatis 原有mapper.xml的语法标签写法过于繁琐,如下面的<if>标签:

 <select id="selectByParams" parameterType="map" resultType="user">    select * from user    <where>      <if test="id != null ">id=#{id}</if>      <if test="name != null and name.length()>0" >and name=#{name}</if>      <if test="age != null and age.length()>0">and age = #{age}</if>    </where>  </select>    <select id="selectByParams" parameterType="map" resultType="user">
    select * from user
    <where>
      <if test="id != null ">id=#{id}</if>
      <if test="name != null and name.length()>0" >and name=#{name}</if>
      <if test="age != null and age.length()>0">and age = #{age}</if>
    </where>
  </select>   

这些写法功能单一还要写很多不必要的代码,于是可用自定义语法代替这类标签,如:

 <select id="selectByParams" parameterType="map" resultType="user">
    select * from user 
        where status = 1
          {{and id = [id]}}
          {{and code = [code,notEmpty]}}
          {{and name like [name,like,notBlank]}}
          {{and id in ([ids,list])}}
 </select>   

使用自定义标签代替 Mybatis 原有的查询条件相关语法,如:if、forEach 标签,解决原有语法写法太繁琐的问题,且增加了一些参数常用的操作和功能。

二、自定义语法和标签说明

1、自定义语法,由于mybatis已经使用了 ${} 和 #{} 标签,所以自定义语法使用 { { }} [ ] 作为标签

2、格式: { {sql [name,key]}} ,如:{ {and code = [code,notEmpty]}} ,本语句的意思:条件 code = ? 在参数code不为空且不为空字符串的条件下生效,比如 code=2 时,生成条件 code = 2。参数名称必须放在第一个位置,关键字可为空。

3、参数说明:以上格式中的 sql 代表一般的sql语句,[ ] 中的 name 代表参数名称,key 代表关键字。

4、关键字类型: notNull notEmpty notBlank trim upper lower list like 、   llike rlike const hidden default 

5、其中  list 、 like 、 llike 、 rlike 、 const 、hidden 为参数类型关键字,notNull 、 notEmpty 、 notBlank 为参数条件类型关键字, trim 、 upper 、 lower 为参数特殊处理关键字。

6、语法标签中所有参数值用英文逗号,分隔,关键字严格区分大小写,除了参数类型关键字,其它关键字可组合使用。

7、参数类型关键字必须放在第二个位置,且一个参数只能是其中一种参数类型,同一个参数不能同时是两个或多个参数类型。

8、条件类型特殊处理类型关键字可组合使用,且位置不固定,可随意搭配。当参数没有参数类型关键字时,条件类型和特殊处理类型关键字可放在第二个位置。

9、参数名称最多可以为三级名称,即:code 、user.code、user.dept.code,名称层级取决于入参类型。当mapper接口的参数为单个对象参数,如:Map、User等,参数名可直接使用 id、code、name等。当参数为两个及以上,需要使用 @Param 定义参数名称,如:func(@Param("user") User user, @Param("map") Map map),则参数名称需要为:user.id, user.name, map.code 等。

三、关键字说明

本关键字用于为参数设置默认值,当参数不满足指定条件时,使用默认值生成sql语句。如:{ {and type = [type,default,1]}},当参数 type 为 null(空) 时,使用默认值生成语句:and type = 1。当参数 type 的值为 2时,生成语句:and type = 2。

注意:默认值关键字 default 后面必须跟一个长度不为0的默认值,否则该关键字不生效。而且default 关键字和默认值必须放在标签最后面。

本关键字用于声明该参数不能为空,即该参数必填,否则会报参数为空异常。如:{ {and code = [code,notNull]}},当参数code为空(null)时,报异常:参数[code]不能为空。但空字符串""不会触发本条件,如需过滤空字符串,则需要与下面两个关键字组合使用。

注意:notNull 和 default 在逻辑上是有冲突的,所以当设置了默认值default,则本关键字会失效。

本关键字用于声明该参数不能为空且长度不能为0,即:str != null && str.length() != 0。且本关键字主要用于字符串参数。如:{ {code = [code,notEmpty]}},当参数code为null或字符串""时,条件不生效。

关键字组合:{ {and code = [code,notNull,notEmpty]}},即参数code必填且不能为空字符串

本关键字与 notEmpty 的区别就是多了一个 trim 操作,即:str != null && str.trim().length() != 0。且本关键字主要用于字符串参数。

notBlank 的功能包含了 notEmpty,所以这两个关键字只需要存在一个即可。

本关键字也是作用于字符串参数,用于自动 trim 字符串前后空格。如:{ {and code = [code,trim]}},当参数code="  a  "时,生成语句:and code = 'a'。

关键字组合:{ {and code = [code,notNull,notBlank,trim]}}

本关键字的作用即自动将参数转换为大写,如参数 code = 'abc' 自动转换成 code = 'ABC'。

关键字组合:{ {and upper(code) = [code,notNull,notBlank,trim,upper]}}

本关键字即把参数自动转换成小写,与 upper 的作用相同。

参数类型关键字:即模糊查询,自动给参数加上 符号。分别对应:'%value%'、'%value'、'value%'。如:{ {and name like [name,like,trim,notEmpty]}}。当参数 name = '  张三  '时,生成语句:and name like '%张三%'。

注意:所有参数类型关键字需要放在标签 [ ] 中的第二个位置,第一个位置是参数名称。

参数类型关键字:本参数作用于 in 查询条件,参数必须是 Controller<?> 类的子类,如 List、ArrayList等。语法:{ {and id in ([ids,list])}}。当参数 ids = [1,2,3] 时,生成语句:and id in (1,2,3)。

注意:list 类型的参数,当参数 ids == null || ids.size() == 0 时,条件都不会生效。

关键字组合

1、default        当 list 与 default 组合使用时,设置默认值需要用 | 代替 , 号,如需要默认值 1,2,3,需要填写成 1|2|3。即:{{and id in (ids,list,default,1|2|3)}}

2、notEmpty、notBlank、trim、upper、lower        当 list 与 以上五个关键字组合使用时,以上五个关键字都是作用于 list 中的元素而不是list本身。如:{{and id in ([ids,trim,notBlank,upper])}},当参数ids=['a',' b ','  ', 'Ef'] 时,生成语句:and id in ('A','B','EF')。其中 trim 关键字去掉了第二个元素的前后空格,notBlank 过滤掉了第三个空字符串元素,upper 则将所有元素转换成大写。

3、notNull        本关键字与list组合即参数ids必填。

参数类型关键字:本关键字与mybatis中的 ${} 类似,即将参数当成sql语句放在sql中,本参数有Sql注入的风险。如:order by {{[orderby,const,notBlank,default,t.id]}}。当参数 orderby 为空或空字符串时,使用默认值生成语句:order by t.id,当参数 orderby = 't.code desc' 时,生成语句:order by t.code desc

注意:当默认值中有逗号 , 时,需要替换成 |, 即:t.id,t.code,t.name 需要写成:t.id|t.code|t.name,如:order by {{[orderby,default,t.id|t.code|t.name]}}

参数类型关键字:本关键字用于隐藏参数本身,即当参数生效时,生成的语句中不包含参数本身。如:{{order by id [orderby,hidden]}}, 当参数 orderby 不为空时(可以是任意值),生成语句:order by id。

四、实现逻辑

1、实现 mybatis 的 InnerInterceptor 拦截器,并在本拦截器中处理自定义语法的逻辑。

public class MyBatisSqlInnerInterceptor implements InnerInterceptor {
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        String sql = boundSql.getSql(); // 获取原始sql
        //TODO 处理自己的自定义语法逻辑和参数
        String newSql = .....
        // 逻辑处理完成,将新sql放回去
        PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
        mpBoundSql.sql(newSql);
        InnerInterceptor.super.beforeQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql);
    }
}

2、将自定义拦截器添加到 mybaits 中。(注意:自定义拦截器必须放在分页插件之前)

@Configuration
public class MybatisPlusConfig {
    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        interceptor.addInnerInterceptor(new MyBatisSqlInnerInterceptor()); // 自定义Sql语法拦截器
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL)); // 分页插件
        return interceptor;
    }
}

五、实现代码

1、新增关键字类型枚举类 

/**
 * 自定义Sql语法关键字类型枚举
 */
public enum SqlWrapperType {
    /** 普通参数,默认值 */
    Param {
        @Override
        String getKey() {
            return "param";
        }
        @Override
        Integer getType() {
            return 0;
        }
    },
    /** 必填参数 */
    NotNull {
        @Override
        String getKey() {
            return "notNull";
        }
        @Override
        Integer getType() {
            return 0;
        }
    },
    /** 参数不能为空或长度为0 */
    NotEmpty {
        @Override
        String getKey() {
            return "notEmpty";
        }
        @Override
        Integer getType() {
            return 0;
        }
    },
    /** 字符串不能为空或 trim() 后长度不能为0 */
    NotBlank {
        @Override
        String getKey() {
            return "notBlank";
        }
        @Override
        Integer getType() {
            return 0;
        }
    },
    /** 如果参数是字符串,则自动trum前后的空格 */
    Trim {
        @Override
        String getKey() {
            return "trim";
        }
        @Override
        Integer getType() {
            return 0;
        }
    },
    /** 参数值自动转换成大写 */
    Upper {
        @Override
        String getKey() {
            return "upper";
        }
        @Override
        Integer getType() {
            return 0;
        }
    },
    /** 参数值自动转换成小写 */
    Lower {
        @Override
        String getKey() {
            return "lower";
        }
        @Override
        Integer getType() {
            return 0;
        }
    },
    /** 模糊查询 %value% */
    Like {
        @Override
        String getKey() {
            return "like";
        }
        @Override
        Integer getType() {
            return 1;
        }
    },
    /** 模糊查询 %value */
    LLike {
        @Override
        String getKey() {
            return "llike";
        }
        @Override
        Integer getType() {
            return 2;
        }
    },
    /** 模糊查询 value% */
    RLike {
        @Override
        String getKey() {
            return "rlike";
        }
        @Override
        Integer getType() {
            return 3;
        }
    },
    /** in 查询 */
    List {
        @Override
        String getKey() {
            return "list";
        }
        @Override
        Integer getType() {
            return 4;
        }
    },
    /** sql语句参数,需要注意Sql注入 */
    Const {
        @Override
        String getKey() {
            return "const";
        }
        @Override
        Integer getType() {
            return 5;
        }
    },
    /** 设置参数默认值 */
    Default {
        @Override
        String getKey() {
            return "default";
        }
        @Override
        Integer getType() {
            return 6;
        }
    },
    /** 隐藏参数值 */
    Hidden {
        @Override
        String getKey() {
            return "hidden";
        }
        @Override
        Integer getType() {
            return 7;
        }
    };
    public static SqlWrapperType get(String key){
        for (SqlWrapperType type : SqlWrapperType.values()) {
            if(type.getKey().equals(key)){
                return type;
            }
        }
        throw new IllegalArgumentException("自定义Sql语法关键字[" + key + "]不存在");
    }
    abstract String getKey();
    abstract Integer getType();
    public boolean equal(SqlWrapperType type){
        return this.getType().equals(type.getType());
    }
}
2、新增参数对象实体类
import lombok.AllArgsConstructor;
import lombok.Data;
/**
 * 自定义Sql语法参数封装器参数对象
 */
@Data
@AllArgsConstructor
public class SqlParamEntity {
    private String key;
    private String sign;
    private SqlWrapperType type;
    private boolean hasDefault;
    private String defaultValue;
    private boolean notNull;
    private boolean notEmpty;
    private boolean notBlank;
    private boolean trim;
    private Integer change;
}

3、新增参数处理类

import com.zorgn.common.ExtList;
import com.zorgn.common.ExtUtil;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static com.zorgn.core.mybatis.SqlWrapper.*;
/**
 * 自定义Sql语法参数封装器
 */
public class SqlParamWrapper {
    private static final String LikeSign = "%";
    private static final String SplitSign = ",";
    private static final Pattern pattern = Pattern.compile("(?<=\\[)(.+?)(?=])");
    private String paramBody;
    private final int paramEntitiesSize;
    private final List<SqlParamEntity> paramEntities = new ArrayList<>();
    /**
     * 自定义Sql语法参数封装器
     */
    public SqlParamWrapper(String paramSql){
        this.paramBody = paramSql;
        initParamSql();
        paramEntitiesSize = paramEntities.size();
    }
    /**
     * 初始化自定义Sql语法参数
     */
    private void initParamSql(){
        Matcher matcher = pattern.matcher(paramBody);
        int count = 1;
        while (matcher.find()){
            String key = matcher.group(), parSign = SignPar + count++ + SignPar, defaultVal = null;
            this.paramBody = this.paramBody.replace(ParBgtStr + key + ParEndStr, parSign);
            SqlWrapperType type = SqlWrapperType.Param;
            boolean hasDefault = false,notNull = false,notEmpty = false,notBlank = false,trim = false;
            Integer change = 0;
            if(key.contains(SplitSign)){
                List<String> keyList = ExtList.splitStr(key, ",", ExtUtil::isBlank);
                key = keyList.get(0).toLowerCase();
                if(keyList.size() >= 2){
                    type = SqlWrapperType.get(keyList.get(1));
                }
                if(keyList.contains(SqlWrapperType.Default.getKey())){
                    int idx = keyList.indexOf(SqlWrapperType.Default.getKey()) + 1;
                    if(keyList.size() > idx){
                        String defstr = keyList.get(idx);
                        if(!isBlank(defstr)) {
                            hasDefault = true;
                            if(SqlWrapperType.List.equal(type)){
                                defaultVal = defstr;
                            }else{
                                defaultVal = defstr.replace("|",",");
                            }
                        }
                    }
                }
                if(keyList.contains(SqlWrapperType.NotNull.getKey())){
                    notNull = true;
                }
                if(keyList.contains(SqlWrapperType.NotEmpty.getKey())){
                    notEmpty = true;
                }
                if(keyList.contains(SqlWrapperType.NotBlank.getKey())){
                    notBlank = true;
                }
                if(keyList.contains(SqlWrapperType.Trim.getKey())){
                    trim = true;
                }
                if(keyList.contains(SqlWrapperType.Upper.getKey())){
                    change = 1;
                }
                if(keyList.contains(SqlWrapperType.Lower.getKey())){
                    change = 2;
                }
            }
            paramEntities.add(new SqlParamEntity(key, parSign, type, hasDefault, defaultVal, notNull, notEmpty, notBlank, trim, change));
        }
    }
    /**
     * 根据参数初始化自定义sql查询条件参数
     */
    public String getParamSql(Map<?,?> map) {
        String newParSql = this.paramBody;
        List<String> notValueKeys = new ArrayList<>();
        for (SqlParamEntity entity : paramEntities) {
            Object val = map.get(entity.getKey());
            if(null == val){
                if(entity.isHasDefault()){
                    val = initDefaultValue(entity);
                }else if(entity.isNotNull()){
                    if(paramEntitiesSize == 1){
                        throw new NullPointerException("参数[" + entity.getKey() + "]不能为空");
                    }else{
                        notValueKeys.add(entity.getKey());
                        continue;
                    }
                }else {
                    if(paramEntitiesSize == 1){
                        return "";
                    }else{
                        notValueKeys.add(entity.getKey());
                        continue;
                    }
                }
            }else if ((isEmpty(val.toString()) && entity.isNotEmpty()) || (isBlank(val.toString()) && entity.isNotBlank())){
                if (entity.isHasDefault()) {
                    val = initDefaultValue(entity);
                } else if (entity.isNotNull()) {
                    throw new NullPointerException("参数[" + entity.getKey() + "]不能为空");
                } else if (paramEntitiesSize > 1){
                    if(notValueKeys.size() < paramEntitiesSize - 1){
                        notValueKeys.add(entity.getKey());
                        continue;
                    } else if (paramEntitiesSize != notValueKeys.size() + 1){
                        notValueKeys.add(entity.getKey());
                        continue;
                    } else if(paramEntitiesSize == notValueKeys.size() + 1){
                        return "";
                    }
                } else {
                    return "";
                }
            }
            if(val instanceof String && entity.isTrim()){
                val = val.toString().trim();
            }
            if(entity.getChange() > 0 && val instanceof String){
                val = entity.getChange() == 1 ? val.toString().toUpperCase() : val.toString().toLowerCase();
            }
            if(SqlWrapperType.Hidden.equal(entity.getType())){
                newParSql = newParSql.replace(entity.getSign(), "");
            } else if(SqlWrapperType.List.equal(entity.getType())){
                if(val instanceof Collection){
                    Collection<?> collection = (Collection<?>) val;
                    if(collection.size() > 0){
                        StringJoiner sj = new StringJoiner(SplitSign);
                        for (Object o : collection) {
                            if(o instanceof String){
                                String str = (String) o;
                                if(entity.isNotBlank() && isBlank(str)){
                                    continue;
                                }
                                if(entity.isNotEmpty() && isEmpty(str)){
                                    continue;
                                }
                                if(entity.isTrim()){
                                    str = str.trim();
                                }
                                if(entity.getChange() > 0){
                                    str = entity.getChange() == 1 ? str.toUpperCase() : str.toLowerCase();
                                }
                                sj.add("'" + str + "'");
                            }else {
                                sj.add(initValue(o) + "");
                            }
                        }
                        newParSql = newParSql.replace(entity.getSign(), sj.toString());
                    }else{
                        notValueKeys.add(entity.getKey());
                    }
                    continue;
                } else {
                    throw new IllegalArgumentException("参数[" + entity.getKey() + "]必须为list集合");
                }
            } else if(SqlWrapperType.Like.equal(entity.getType())){
                val = LikeSign + val + LikeSign;
            } else if(SqlWrapperType.LLike.equal(entity.getType())){
                val = LikeSign + val;
            } else if(SqlWrapperType.RLike.equal(entity.getType())){
                val = val + LikeSign;
            }
            newParSql = newParSql.replace(entity.getSign(), (SqlWrapperType.Const.equal(entity.getType()) ? val : initValue(val)).toString());
        }
        if(notValueKeys.size() > 0 && notValueKeys.size() != paramEntitiesSize){
            //多参数时,所有参数必填
            throw new NullPointerException("参数" + notValueKeys + "不能为空");
        }else if(notValueKeys.size() == paramEntitiesSize){
            //多参数且全部为空,则条件不生产
            newParSql = "";
        }
        return newParSql;
    }
    private Object initDefaultValue(SqlParamEntity entity){
        if(SqlWrapperType.List.equal(entity.getType())){
            return Arrays.asList(entity.getDefaultValue().split("\\|"));
        }else{
            return entity.getDefaultValue();
        }
    }
    private Object initValue(Object value){
        if (null == value) return null;
        if(value instanceof String || value instanceof Character){
            return "'" + value + "'";
        }else {
            return value;
        }
    }
    /** 判断字符串是否为空或 trim() 后长度为0 */
    private static boolean isBlank(String str){
        return null == str || str.trim().length() == 0;
    }
    /** 判断字符串是否为空或长度为0 */
    private static boolean isEmpty(String str){
        return null == str || str.length() == 0;
    }
}

4、新增sql处理类

import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * 自定义Sql语法封装器
 */
public class SqlWrapper {
    static final String SqlBgtStr = "{{";
    static final String SqlEndStr = "}}";
    static final String ParBgtStr = "[";
    static final String ParEndStr = "]";
    static final String SignSql = "$";
    static final String SignPar = "?";
    private static final Pattern pattern = Pattern.compile("(?<=\\{\\{)(.+?)(?=}})");
    private String sqlBody;
    private final Map<String, SqlParamWrapper> paramWrappers = new HashMap<>();
    /**
     * 自定义Sql语法封装器
     */
    public SqlWrapper(String sql){
        this.sqlBody = sql;
        initSqlWrapper();
    }
    /**
     * 初始化自定义Sql语法
     */
    private void initSqlWrapper(){
        Matcher matcher = pattern.matcher(this.sqlBody);
        int count = 1;
        while (matcher.find()){
            String group = matcher.group();
            String parSign = SignSql + count++ + SignSql;
            paramWrappers.put(parSign, new SqlParamWrapper(group));
            this.sqlBody = this.sqlBody.replace(SqlBgtStr + group + SqlEndStr, " " + parSign + " ");
        }
    }
    /**
     * 根据参数初始化自定义sql查询条件
     */
    public String getSql(Map<?,?> map) {
        String newSqlBody = this.sqlBody.replace("  ", " ");
        if(null == map || map.size() == 0){
            for (Map.Entry<String, SqlParamWrapper> next : paramWrappers.entrySet()) {
                newSqlBody = newSqlBody.replace(next.getKey(), "");
            }
            return newSqlBody;
        }
        for (Map.Entry<String, SqlParamWrapper> next : paramWrappers.entrySet()) {
            newSqlBody = newSqlBody.replace(next.getKey(), next.getValue().getParamSql(map));
        }
        return newSqlBody;
    }
    /**
     * 判断sql是否包含自定义语法
     */
    public static boolean matcherSql(String sql){
        return pattern.matcher(sql).find();
    }
}

5、添加自定义语法拦截器

import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.zorgn.common.ExtMap;
import com.zorgn.core.mybatis.MyBaitsSqlInnerException;
import com.zorgn.core.mybatis.SqlWrapper;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.math.BigDecimal;
import java.sql.SQLException;
import java.util.Collection;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
/**
 * 自定义Sql语法拦截器
 */
public class MyBatisSqlInnerInterceptor implements InnerInterceptor {
    private static final Map<String, SqlWrapper> SQL_WRAPPER_MAP = new HashMap<>();
    private static final Logger logger = LoggerFactory.getLogger("SQL");
    private static final String regex = "[\\s]+"; // 替换空格、换行、tab缩进等;
    private static final String LogPreparing  = "{} - ==>  Preparing: {}";
    private static final String LogParameters = "{} - ==> Parameters: {}";
    private final boolean openThreeLevelParams; // 是否开启三级参数名,默认是二级:user.id,三级:user.creator.id
    /**
     * 构造器:自定义Sql语法拦截器,默认为二级参数名称:user.id
     */
    public MyBatisSqlInnerInterceptor(){
        this(false);
    }
    /**
     * 构造器:自定义Sql语法拦截器,并设置是否开启三级参数名称,三级参数名称:user.creator.id
     */
    public MyBatisSqlInnerInterceptor(Boolean openThreeLevelParams){
        this.openThreeLevelParams = openThreeLevelParams;
    }
    /**
     * Query 前置事件,处理自定义Sql逻辑,打印 Sql 语句和参数
     */
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        String sql = boundSql.getSql().replaceAll(regex, " ");
        String sqlId = ms.getId();
        if(SQL_WRAPPER_MAP.containsKey(sqlId) || SqlWrapper.matcherSql(sql)) {
            SqlWrapper sqlWrapper = SQL_WRAPPER_MAP.get(sqlId);
            if (null == sqlWrapper) {
                sqlWrapper = new SqlWrapper(sql);
                SQL_WRAPPER_MAP.put(sqlId, sqlWrapper);
                logger.info("初始化Sql模板, SqlId: {}, SQL: {}", sqlId, sql);
            }
            String newSql = null;
            try {
                PluginUtils.MPBoundSql mpBoundSql = PluginUtils.mpBoundSql(boundSql);
                Map<?,?> newParamMap = initParameter(parameter, sqlId);
                newSql = sqlWrapper.getSql(newParamMap);
                mpBoundSql.sql(newSql);
                logger.info(LogPreparing, sqlId, newSql);
                logger.info(LogParameters, sqlId, getSqlParameter(parameter));
            } catch (Exception e) {
                throw new MyBaitsSqlInnerException("Sql模板异常:" + e.getMessage(), sqlId, (null == newSql ? sql : newSql), e);
            }
        }else{
            logger.info(LogPreparing, sqlId, sql);
            logger.info(LogParameters, sqlId, getSqlParameter(parameter));
        }
        InnerInterceptor.super.beforeQuery(executor, ms, parameter, rowBounds, resultHandler, boundSql);
    }
    /**
     * Insert、Update、Delete 前置任务,打印 Sql 语句及参数
     */
    @Override
    public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
        BoundSql boundSql = ms.getBoundSql(parameter);
        String sqlId = ms.getId();
        logger.info(LogPreparing, sqlId, boundSql.getSql().replaceAll(regex, " "));
        logger.info(LogParameters, sqlId, getSqlParameter(parameter));
        InnerInterceptor.super.beforeUpdate(executor, ms, parameter);
    }
    /**
     * 初始化多参数
     */
    private Map<?,?> initParameter(Object parameter, String sqlId){
        if(null == parameter){
            return null;
        }
        // 多参数处理,去除分页参数
        if(parameter instanceof MapperMethod.ParamMap){
            MapperMethod.ParamMap<?> paramMap = (MapperMethod.ParamMap<?>) parameter;
            return initMapperMethodParamMap(paramMap);
        }
        if(parameter instanceof Map){
            return (Map<?,?>)parameter;
        }
        if(isBaseClassTypes(parameter)){
            Map<String,Object> newParamMap = new HashMap<>();
            newParamMap.put(getSqlIdMethodParamName(sqlId, parameter).toLowerCase(), parameter);
            return newParamMap;
        }
        return objectToMap(parameter);
    }
    /**
     * 多参数处理
     */
    private Map<String,Object> initMapperMethodParamMap( MapperMethod.ParamMap<?> paramMap){
        Map<String,Object> newParamMap = new HashMap<>();
        for (String key : paramMap.keySet()) {
            if(key.startsWith("param")){
                continue;
            }
            Object object = paramMap.get(key);
            if(null == object || object instanceof AbstractWrapper){
                // 自定义sql查询语句不需要 QueryWrapper 对象
                continue;
            }
            if(object instanceof Page){
                Page<?> page = (Page<?>) object;
                newParamMap.put("page.current", page.getCurrent());
                newParamMap.put("page.size", page.getSize());
                continue;
            }
            if(isBaseClassTypes(object)){
                newParamMap.put(key, object);
                continue;
            }
            Map<?,?> map = (object instanceof Map) ? (Map<?,?>) object : objectToMap(object);
            for (Map.Entry<?, ?> entry : map.entrySet()) {
                Object value = entry.getValue();
                if(null == value){
                    continue;
                }
                String sonKey = (key + "." + entry.getKey()).toLowerCase();
                if(!openThreeLevelParams || isBaseClassTypes(value)){
                    newParamMap.put(sonKey, value);
                }else{
                    initThreeLevelParams(newParamMap, sonKey, value);
                }
            }
        }
        return newParamMap;
    }
    /**
     * 处理三级参数名称
     */
    private void initThreeLevelParams(Map<String,Object> newParamMap, String sonKey, Object value){
        if(value instanceof Map){
            Map<?,?> sonMap = (Map<?,?>) value;
            for (Map.Entry<?, ?> son : sonMap.entrySet()) {
                Object sonValue = son.getValue();
                if(null == sonValue){
                    continue;
                }
                newParamMap.put(sonKey + "." + son.getKey(), sonValue);
            }
            return;
        }
        for (Map.Entry<?, ?> son : objectToMap(value).entrySet()) {
            Object sonValue = son.getValue();
            if(null == sonValue){
                continue;
            }
            newParamMap.put(sonKey + "." + son.getKey(), sonValue);
        }
    }
    /**
     * 根据 sqlId 获取 mapper 方法参数名称
     */
    private String getSqlIdMethodParamName(String sqlId, Object parameter){
        try {
            int index = sqlId.lastIndexOf(".");
            String mapperId = sqlId.substring(0, index);
            String methodId = sqlId.substring(++ index);
            return Class.forName(mapperId).getMethod(methodId, parameter.getClass()).getParameters()[0].getName();
        } catch (Exception ex){
            throw new RuntimeException("获取Sql模板参数名称失败", ex);
        }
    }
    /**
     * 判断参数类型是否是常用基本类型
     */
    private boolean isBaseClassTypes(Object parameter){
        return parameter instanceof String
                || parameter instanceof Integer
                || parameter instanceof Collection
                || parameter instanceof Long
                || parameter instanceof Date
                || parameter instanceof Double
                || parameter instanceof Boolean
                || parameter instanceof Float
                || parameter instanceof Short
                || parameter instanceof Character
                || parameter instanceof BigDecimal;
    }
    /**
     * 对象属性转map集合
     */
    private Map<?,?> objectToMap(Object parameter){
        return ((JSONObject)JSONObject.toJSON(parameter)).toJavaObject(Map.class);
    }
    /**
     * 获取sql参数,用于日志打印
     */
    private Object getSqlParameter(Object parameter){
        if(null == parameter){
            return null;
        }
        if(parameter instanceof Map){
            Map<Object,Object> map = new HashMap<>();
            Map<?,?> pm = (Map<?,?>) parameter;
            pm.forEach((k, v) -> {
                if(null == v || k.toString().startsWith("param")){
                    return;
                }
                if(v instanceof IPage) {
                    IPage<?> page = (IPage<?>) v;
                    map.put(k, ExtMap.parse("current", page.getCurrent(), "size", page.getSize()));
                    return;
                }
                if(v instanceof AbstractWrapper){
                    map.put(k, getAbstractWrapperInfo(v));
                    return;
                }
                map.put(k,v);
            });
            return map;
        }
        if(parameter instanceof IPage) {
            IPage<?> page = (IPage<?>) parameter;
            return ExtMap.parse("page", ExtMap.parse("current", page.getCurrent(), "size", page.getSize()));
        }
        if(parameter instanceof AbstractWrapper){
            return getAbstractWrapperInfo(parameter);
        }
        return parameter;
    }
    private ExtMap<Object, Object> getAbstractWrapperInfo(Object parameter){
        AbstractWrapper<?,?,?> wrapper = (AbstractWrapper<?,?,?>) parameter;
        ExtMap<Object, Object> vars = new ExtMap<>();
        vars.put("entity", JSONObject.toJSONString(wrapper.getEntity()));
        vars.put("sql_segment", wrapper.getCustomSqlSegment());
        vars.put("sql_vars", wrapper.getParamNameValuePairs());
        return vars;
    }
}

6、将自定义拦截器添加到mysql中,即可使用本自定义sql语法(详情请看 四 - 2 )

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文