
The components of mybatis-plus/mybatis: interceptors, field fillers, type handlers, table-name replacement, and SqlInjector (composite primary key handling)

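I recently wrote a small practice project: configure two data sources, read from one, and write to the other. It works in the end, but it does not support transactions. Treat this post as a record of how to use the mybatis-plus/mybatis components; the example is built on mybatis-plus.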

A quick recap of the core MyBatis objects:

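  • Configuration: holds the basic configuration built at initialization, such as type aliases, plus important objects like plugins, mappers, ObjectFactory, and TypeHandler instances. All MyBatis configuration lives in the Configuration object.
  • SqlSessionFactory: the factory for SqlSession.
  • SqlSession: the main top-level API of MyBatis. It represents a session with the database and provides the insert, delete, update, and query operations.
  • Executor: the MyBatis executor and the core of its scheduling. It is responsible for generating SQL statements and maintaining the query cache.
  • StatementHandler: wraps JDBC Statement operations, such as setting parameters and converting the Statement result set into a List.
  • ParameterHandler: converts the parameters passed by the user into the parameters required by the JDBC Statement.
  • ResultSetHandler: converts the ResultSet returned by JDBC into a List.
  • TypeHandler: maps and converts between Java types and JDBC types.
  • MappedStatement: wraps a single <select|update|delete|insert> node.
  • SqlSource: dynamically generates the SQL statement from the parameterObject passed by the user, wraps the information in a BoundSql object, and returns it.
  • BoundSql: the dynamically generated SQL statement together with its parameter information.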

Components

Interceptors

MyBatis lets you intercept calls on specific objects while a statement is being executed. There are four such objects:

  • Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed): handles inserts, updates, deletes, and queries
  • ParameterHandler (getParameterObject, setParameters): sets the prepared-statement parameters
  • ResultSetHandler (handleResultSets, handleOutputParameters): handles results
  • StatementHandler (prepare, parameterize, batch, update, query): handles SQL precompilation and parameter setting

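These four objects can be intercepted. The general approach is to implement the MyBatis Interceptor interface and annotate the implementation to declare which methods it intercepts.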

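Below are the methods of the Interceptor interface. setProperties can be used for initialization, and plugin wraps the target object so the interceptor can process it. The mechanism is based on dynamic proxies: the Plugin class is the proxy handler that wraps targets on behalf of an Interceptor implementation, and each interceptor you implement is added to the interceptor chain.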

 Object intercept(Invocation var1) throws Throwable;

    default Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    default void setProperties(Properties properties) {
    }

The Plugin.wrap method
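
The idea behind Plugin.wrap can be sketched with a plain JDK dynamic proxy. This is only an illustration, not the MyBatis source: Greeter, SimpleGreeter, and WrapSketch are hypothetical names, and the set of method names stands in for what MyBatis derives from the @Signature annotations.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;

// Hypothetical stand-in for a MyBatis interface such as StatementHandler.
interface Greeter {
    String greet(String name);
    String bye(String name);
}

class SimpleGreeter implements Greeter {
    @Override public String greet(String name) { return "hello " + name; }
    @Override public String bye(String name) { return "bye " + name; }
}

// Simplified imitation of the Plugin idea: a JDK dynamic proxy whose handler
// runs extra logic only for the method names it was asked to intercept.
class WrapSketch implements InvocationHandler {
    static final List<String> LOG = new ArrayList<>();
    private final Object target;
    private final Set<String> intercepted;

    private WrapSketch(Object target, Set<String> intercepted) {
        this.target = target;
        this.intercepted = intercepted;
    }

    // Returns a proxy that implements Greeter and delegates to the target.
    static Greeter wrap(Greeter target, Set<String> intercepted) {
        return (Greeter) Proxy.newProxyInstance(
                Greeter.class.getClassLoader(),
                new Class<?>[]{Greeter.class},
                new WrapSketch(target, intercepted));
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        if (intercepted.contains(method.getName())) {
            // A real interceptor would run its intercept(...) logic here.
            LOG.add("intercepted " + method.getName());
        }
        return method.invoke(target, args);
    }

    static String demo() {
        LOG.clear();
        Greeter proxy = wrap(new SimpleGreeter(), Collections.singleton("greet"));
        return proxy.greet("a") + "|" + proxy.bye("a") + "|" + LOG;
    }
}
```

Calling WrapSketch.demo() shows that both methods still reach the target, while only greet passes through the extra logic.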

The interceptor chain:

public class InterceptorChain {
    private final List<Interceptor> interceptors = new ArrayList();

    public InterceptorChain() {
    }

    public Object pluginAll(Object target) {
        Interceptor interceptor;
        for(Iterator var2 = this.interceptors.iterator(); var2.hasNext(); target = interceptor.plugin(target)) {
            interceptor = (Interceptor)var2.next();
        }

        return target;
    }

    public void addInterceptor(Interceptor interceptor) {
        this.interceptors.add(interceptor);
    }

    public List<Interceptor> getInterceptors() {
        return Collections.unmodifiableList(this.interceptors);
    }
}

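These interceptors are added where the handlers are created: pluginAll is called there and returns an object that has been wrapped by the whole proxy chain.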
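
One consequence of the pluginAll loop is the wrapping order: each interceptor wraps the result of the previous one, so the interceptor added last ends up outermost. The following sketch shows this with plain lambdas instead of proxies; Step and ChainSketch are hypothetical names used only for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical target type; in MyBatis this would be an Executor or a handler.
interface Step {
    String run();
}

class ChainSketch {
    private final List<UnaryOperator<Step>> interceptors = new ArrayList<>();

    void addInterceptor(UnaryOperator<Step> interceptor) {
        interceptors.add(interceptor);
    }

    // Same loop shape as InterceptorChain.pluginAll: wrap, then wrap the wrapper.
    Step pluginAll(Step target) {
        for (UnaryOperator<Step> interceptor : interceptors) {
            target = interceptor.apply(target);
        }
        return target;
    }

    // An "interceptor" that records its name around the inner call.
    static UnaryOperator<Step> named(String name) {
        return inner -> () -> name + "(" + inner.run() + ")";
    }

    static String demo() {
        ChainSketch chain = new ChainSketch();
        chain.addInterceptor(named("first"));
        chain.addInterceptor(named("second"));
        return chain.pluginAll(() -> "target").run();
    }
}
```

ChainSketch.demo() returns second(first(target)), so the registration order matters when one interceptor depends on the SQL already rewritten by another.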

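After implementing the interface, add annotations to declare which methods are intercepted; the methods are those of the four objects listed above. The annotation below specifies the object type and the method to intercept, and args lists the parameter types of the intercepted method.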

public @interface Signature {
    Class<?> type();

    String method();

    Class<?>[] args();
}

For example:

        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})

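The Signature annotation corresponds to the interface, method, and parameters above. The interceptor class then carries an @Intercepts annotation whose value is an array of Signature annotations.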

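With interceptors available, my first idea was to route by method: use the read data source for select, and the write data source for insert, update, and delete. The principle is very close to the code-level read/write splitting I wrote about earlier: a ThreadLocal stores the data source for the current thread, an interceptor decides which data source to use, and AbstractRoutingDataSource picks the data source based on the ThreadLocal value.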
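
The ThreadLocal part of this idea can be sketched without Spring. The names below (DbKind, HolderSketch, RoutingSketch) are hypothetical; in the real setup the lookup would live in a subclass of Spring's AbstractRoutingDataSource that overrides determineCurrentLookupKey(), and the map values would be DataSource objects rather than strings.

```java
import java.util.EnumMap;
import java.util.Map;

// Hypothetical enum; the post's own enum is called DBType.
enum DbKind { PGSQL, MYSQL }

// Holds the data source key chosen for the current thread.
class HolderSketch {
    private static final ThreadLocal<DbKind> CURRENT = new ThreadLocal<>();

    static void set(DbKind kind) { CURRENT.set(kind); }
    static DbKind get() { return CURRENT.get(); }
    // Clearing matters with thread pools, where threads are reused.
    static void clear() { CURRENT.remove(); }
}

// Stand-in for the routing data source: resolves a target by the thread-local key.
class RoutingSketch {
    private final Map<DbKind, String> targets = new EnumMap<>(DbKind.class);
    private final DbKind defaultKind;

    RoutingSketch(DbKind defaultKind) { this.defaultKind = defaultKind; }

    void register(DbKind kind, String dataSourceName) { targets.put(kind, dataSourceName); }

    String resolve() {
        DbKind key = HolderSketch.get();
        return targets.get(key != null ? key : defaultKind);
    }

    static String demo() {
        RoutingSketch routing = new RoutingSketch(DbKind.PGSQL);
        routing.register(DbKind.PGSQL, "pg");
        routing.register(DbKind.MYSQL, "mysql");
        String before = routing.resolve();   // no key set: default
        HolderSketch.set(DbKind.MYSQL);      // what the interceptor would do for a write
        String during = routing.resolve();
        HolderSketch.clear();
        String after = routing.resolve();
        return before + "," + during + "," + after;
    }
}
```

Because the key is bound to a thread rather than to a connection or transaction, switching it in the middle of a transaction does not move the open connection, which is one reason this style of routing and transactions do not mix easily.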

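There is a problem, though: when converting between two data sources, table names and column names are not necessarily the same. For example, data may move from a pgsql table called user_info to a mysql table called user, and none of the column names match.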

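My approach is to name the entity fields after the columns of the table being queried, then annotate each field with the corresponding column name in the table being written. If the source table lacks a column that the target table has, the select either defaults to select null for it, or the queried columns are restricted in code. First, three annotations are defined, covering the data sources, the table names, and the fields for reading and writing.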

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface TranDB {
    DBType from();
    DBType to();
    Class object();
}

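// RUNTIME retention is needed because the interceptor reads this annotation via reflection
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface TranField {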
    String from() default "";

    String to();

    String empty = "null";
}

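// RUNTIME retention is needed because the table-name handler reads this annotation via reflection
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface TranTable {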
    String from();
    String to();
}

The User class

@TranTable(from = "user_info", to = "user")
public class User {

    @TranField(to = "id")
    @TableId
    private Integer userId;
    @TranField(to = "wx_nickname")
    private String userAccount;
    @TranField(to = "roles")
    private String mobile;
    @TranField(from=TranField.empty,to="create_time")
    private Date createTime;
    @TranField(from=TranField.empty,to="update_time")
    private Date updateTime;
    @TranField(from=TranField.empty,to="bonus")
    private Integer bonus;
    @TranField(to="wx_id")
    private String[] test;
}
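
To make the role of the field annotation concrete, here is a small reflection sketch that turns annotated fields into a property-to-target-column map. ColumnMapping, DemoUser, and MappingSketch are hypothetical stand-ins, not the post's classes; the important detail is the RUNTIME retention, without which getAnnotation would return null.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for TranField.
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@interface ColumnMapping {
    String from() default "";
    String to();
}

// Hypothetical entity with two mapped fields and one unmapped field.
class DemoUser {
    @ColumnMapping(to = "id")
    private Integer userId;
    @ColumnMapping(to = "wx_nickname")
    private String userAccount;
    private String notMapped;
}

class MappingSketch {
    // Maps each annotated property name to its column name in the target table.
    static Map<String, String> targetColumns(Class<?> type) {
        Map<String, String> result = new LinkedHashMap<>();
        for (Field field : type.getDeclaredFields()) {
            ColumnMapping mapping = field.getAnnotation(ColumnMapping.class);
            if (mapping != null) {   // fields without the annotation are skipped
                result.put(field.getName(), mapping.to());
            }
        }
        return result;
    }
}
```

Building this map once per entity class and caching it would avoid repeating the reflection on every intercepted statement.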

UserMapper

@TranDB(from = DBType.PGSQL,to=DBType.MYSQL,object=User.class)
public interface UserMapper extends BaseMapper<User> {
}

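Next comes a class that caches mapper information so it is easy to use from the interceptors. One of its fields stores the TranDB annotation for each mapper name: an interceptor derives the mapper name from the intercepted method, then looks up its TranDB annotation in this cache. The annotation carries the entity class, from which the field annotations and the table-name annotation can be read. The other field holds the table-name handlers discussed later. The class implements two interfaces: one to receive the Spring resource loader, and one to initialize the bean.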

package com.trendy.task.transport.config;

import com.baomidou.mybatisplus.extension.parsers.ITableNameHandler;
import com.trendy.task.transport.annotations.TranDB;
import com.trendy.task.transport.handler.SelfTableNameHandler;
import com.trendy.task.transport.util.CamelHumpUtils;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ResourceLoaderAware;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternUtils;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;

import java.util.*;

/**
 * Caches mapper information
 */
public class MapperAuxFeatureMap implements ResourceLoaderAware, InitializingBean {


    private static ResourceLoader resourceLoader;

    @Value("${tran.mapperlocation}")
    public   String MAPPER_LOCATION ;

    public static final String TABLEPREFIX="t_";

    // table-name handlers
    public  Map<String, ITableNameHandler> tableNameHandlerMap;

    // TranDB annotation of each mapper
    public  Map<String, TranDB> mapperTranDbMap;

    // derive the mapper name from the statement id (the fully qualified method name)
    public static String getMapperNameFromMethodName(String source){
        int end = source.lastIndexOf(".") + 1;
        String mapper = source.substring(0, end - 1);
        mapper = mapper.substring(mapper.lastIndexOf(".") + 1);
        return mapper;
    }
    
    @Override
    public void setResourceLoader(ResourceLoader resourceLoader) {
       MapperAuxFeatureMap.resourceLoader=resourceLoader;
    }
    @Override
    public void afterPropertiesSet() throws Exception {
        ResourcePatternResolver resolver = ResourcePatternUtils.getResourcePatternResolver(resourceLoader);
        MetadataReaderFactory metaReader = new CachingMetadataReaderFactory(resourceLoader);
        Resource[] resources = resolver.getResources("classpath*:"+MAPPER_LOCATION.replace(".","/")+"/**/*.class");
        mapperTranDbMap = new HashMap<>();
        tableNameHandlerMap = new HashMap<>();
        for (Resource r : resources) {
            MetadataReader reader = metaReader.getMetadataReader(r);
            String className = reader.getClassMetadata().getClassName();
            Class<?> c = Class.forName(className);
            if (c.isAnnotationPresent(TranDB.class)) {
                String name = c.getSimpleName();
                TranDB tranDB = c.getAnnotation(TranDB.class);
                mapperTranDbMap.put(name, tranDB);
                String value = tranDB.object().getSimpleName();
                tableNameHandlerMap.put(TABLEPREFIX+ CamelHumpUtils.humpToLine(value),new SelfTableNameHandler(tranDB.object()));
            }
        }
    }
}
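
Two string helpers carry most of the lookup logic in this class. The sketch below reproduces the mapper-name extraction and shows one possible camel-case to underscore conversion. CamelHumpUtils is not shown in the post, so humpToLine here is an assumption about its behavior, and the statement id in the usage note is a made-up example.

```java
class NamingSketch {
    // "com.example.mapper.UserMapper.selectById" -> "UserMapper"
    static String mapperName(String statementId) {
        // Drop the method name, then keep only the simple class name.
        String mapperClass = statementId.substring(0, statementId.lastIndexOf('.'));
        return mapperClass.substring(mapperClass.lastIndexOf('.') + 1);
    }

    // Assumed behavior of CamelHumpUtils.humpToLine: "userAccount" -> "user_account".
    // A leading capital does not produce a leading underscore in this sketch.
    static String humpToLine(String camel) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < camel.length(); i++) {
            char c = camel.charAt(i);
            if (Character.isUpperCase(c)) {
                if (i > 0) {
                    sb.append('_');
                }
                sb.append(Character.toLowerCase(c));
            } else {
                sb.append(c);
            }
        }
        return sb.toString();
    }
}
```

With these definitions, NamingSketch.mapperName("com.example.mapper.UserMapper.selectById") yields UserMapper, and "t_" + NamingSketch.humpToLine("User") yields t_user, which matches the key format built with TABLEPREFIX above.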

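Here is the part that switches data sources. It intercepts query and update (which covers insert, update, and delete): write methods use the mysql data source and read methods use the pgsql data source.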

package com.trendy.task.transport.dyma;


import com.trendy.task.transport.annotations.TranDB;
import com.trendy.task.transport.config.MapperAuxFeatureMap;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.util.Properties;

/**
 * @author: lele
 * @date: 2019/10/23 4:24 PM
 */
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class,
                RowBounds.class, ResultHandler.class})
})
public class DynamicDataSourceInterceptor implements Interceptor {

    private MapperAuxFeatureMap mapperAuxFeatureMap;

    public DynamicDataSourceInterceptor(MapperAuxFeatureMap mapperAuxFeatureMap) {
        this.mapperAuxFeatureMap = mapperAuxFeatureMap;
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // reads use the From database, everything else uses the To database
        DBType db =null;
        Object[] objects = invocation.getArgs();
        MappedStatement statement = (MappedStatement) objects[0];
        String mapper = MapperAuxFeatureMap.getMapperNameFromMethodName(statement.getId());
        TranDB tranDB = mapperAuxFeatureMap.mapperTranDbMap.get(mapper);
        if (statement.getSqlCommandType().equals(SqlCommandType.SELECT)) {
            db = tranDB.from();
        } else {
            db = tranDB.to();
        }
        DynamicDataSourceHolder.setDbType(db);
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object o) {
        if (o instanceof Executor) {
            return Plugin.wrap(o, this);
        } else {
            return o;
        }
    }

    @Override
    public void setProperties(Properties properties) {

    }
}

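Next is the interceptor that rewrites the fields. Why extend AbstractSqlParserHandler? Because its methods can be reused, and it prepares for the table-name replacement class added later. The flow is to take the original field name and replace it with the value stored in to of TranField.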

package com.trendy.task.transport.handler;

import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.extension.handlers.AbstractSqlParserHandler;
import com.trendy.task.transport.annotations.TranDB;
import com.trendy.task.transport.annotations.TranField;
import com.trendy.task.transport.config.MapperAuxFeatureMap;
import com.trendy.task.transport.util.CamelHumpUtils;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.Statement;
import java.util.HashMap;
import java.util.Map;

/**
 * @author: lele
 * @date: 2019/10/23 5:12 PM
 */

@Intercepts({
        @Signature(
                type = StatementHandler.class,
                method = "prepare",
                args = {Connection.class, Integer.class}
        ),
        @Signature(
                type = StatementHandler.class,
                method = "update",
                args = {Statement.class}
        ),
        @Signature(
                type = StatementHandler.class,
                method = "batch",
                args = {Statement.class}
        )
})
public class FieldHandler extends AbstractSqlParserHandler implements Interceptor {
    private MapperAuxFeatureMap mapperAuxFeatureMap;

    public FieldHandler(MapperAuxFeatureMap mapperAuxFeatureMap) {
        this.mapperAuxFeatureMap = mapperAuxFeatureMap;
    }

    @Override
    public Object plugin(Object target) {
        return target instanceof StatementHandler ? Plugin.wrap(target, this) : target;
    }

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler =  PluginUtils.realTarget(invocation.getTarget());
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
        super.sqlParser(metaObject);
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        Boolean select = mappedStatement.getSqlCommandType().equals(SqlCommandType.SELECT);
        if (!select) {
            // look up the annotation in the cache by mapper name
            String mapperName = MapperAuxFeatureMap.getMapperNameFromMethodName(mappedStatement.getId());
            TranDB tranDB = mapperAuxFeatureMap.mapperTranDbMap.get(mapperName);
           // collect all fields of the class, including inherited ones
            Class clazz = tranDB.object();
            Map<String, Field> mapField = new HashMap<>(clazz.getFields().length);
            while (!clazz.equals(Object.class)) {
                Field[] fields = clazz.getDeclaredFields();
                for (Field field : fields) {
                    field.setAccessible(true);
                    mapField.put(field.getName(), field);
                }
                clazz = clazz.getSuperclass();
            }
            // rewrite the sql
            String sql = boundSql.getSql();
            for (Map.Entry<String, Field> entry : mapField.entrySet()) {
                String sqlFieldName = CamelHumpUtils.humpToLine(entry.getKey());
                TranField tranField = entry.getValue().getAnnotation(TranField.class);
                // skip fields without @TranField instead of failing with a NullPointerException
                if (tranField != null && sql.contains(sqlFieldName)) {
                    sql = sql.replaceAll(sqlFieldName, tranField.to());
                }
            }
            metaObject.setValue("delegate.boundSql.sql", sql);
        }
        return invocation.proceed();
    }
}

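One problem remains: table-name replacement. There is a small pitfall, because this feature works much like the sql replacement above. Take insert into user(user_info,user_id) values ... and suppose the table name user should be replaced with user_info: the replacement hits every occurrence of user in the insert statement. The official recommendation is to change the table name with the @TableName annotation to avoid this.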
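
The pitfall comes from plain substring replacement. A hedged alternative, shown only as an illustration and not as what mybatis-plus does internally, is to replace whole identifiers with a word-boundary regex. RenameSketch is a hypothetical helper, and the sample SQL in the usage note is made up.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class RenameSketch {
    // Replaces `from` only where it appears as a whole identifier.
    // '_' counts as a word character, so "user" does not match inside "user_id".
    static String renameIdentifier(String sql, String from, String to) {
        Pattern pattern = Pattern.compile("\\b" + Pattern.quote(from) + "\\b");
        return pattern.matcher(sql).replaceAll(Matcher.quoteReplacement(to));
    }

    // The naive variant, for comparison: every occurrence is replaced.
    static String renameNaively(String sql, String from, String to) {
        return sql.replace(from, to);
    }
}
```

For "insert into user(user_name,user_id) values (?,?)", the naive variant also rewrites the column names, while renameIdentifier changes only the table name. This still does not parse SQL, so an identical word inside a string literal or used as a column name would be replaced too.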

Table-name handler

Usage: implement ITableNameHandler and have the interface method return a table name.

package com.trendy.task.transport.handler;

import com.baomidou.mybatisplus.extension.parsers.ITableNameHandler;
import com.trendy.task.transport.annotations.TranTable;
import org.apache.ibatis.reflection.MetaObject;

/**
 * @author: lele
 * @date: 2019/10/24 2:39 PM
 * Handler for table-name replacement
 */
public class SelfTableNameHandler implements ITableNameHandler {
    private final Class clazz;

    public SelfTableNameHandler(Class clazz) {
        this.clazz = clazz;
    }

    @Override
    public String dynamicTableName(MetaObject metaObject, String sql, String tableName) {
        TranTable t = (TranTable) clazz.getAnnotation(TranTable.class);
        if (sql.toLowerCase().startsWith("select")) {
            return t.from();
        } else {
            return t.to();
        }
    }
}

The parser can also be injected into the pagination interceptor that ships with MP.

 @Bean
    public FieldHandler fieldHandler() {
        FieldHandler f = new FieldHandler(mapperAuxFeatureMap());
        DynamicTableNameParser t = new DynamicTableNameParser();
        t.setTableNameHandlerMap(mapperAuxFeatureMap().tableNameHandlerMap);
        f.setSqlParserList(Collections.singletonList(t));
        return f;
    }

Field filler

The mysql table has creation time, modification time, and bonus columns that the pgsql table lacks. To insert default values for them, use a field filler.

How: implement the MetaObjectHandler interface and override its methods, then declare on the field, through the fill attribute of @TableField, when it should be filled.

Example: fill createTime, updateTime, and bonus with defaults, using the getFieldValByName and setFieldValByName methods to read and assign values.

package com.trendy.task.transport.handler;

import com.baomidou.mybatisplus.core.handlers.MetaObjectHandler;
import org.apache.ibatis.reflection.MetaObject;
import java.util.Date;


/**
 * @author lulu
 * @Date 2019/10/24 23:15
 * Handler for automatic filling
 */

public class DefaultFieldValueHandler implements MetaObjectHandler {

    public static final String CREATETIME = "createTime";
    public static final String UPDATETIME = "updateTime";
    public static final String BOUNS = "bonus";

    private void handle(String name, MetaObject metaObject, Object target) {
        Object o = getFieldValByName(name, metaObject);
        if (o == null) {
            setFieldValByName(name, target, metaObject);
        }
    }

    @Override
    public void insertFill(MetaObject metaObject) {
        handle(CREATETIME, metaObject, new Date());
        handle(UPDATETIME, metaObject, new Date());
        handle(BOUNS, metaObject, 500);
    }

    @Override
    public void updateFill(MetaObject metaObject) {
        handle(UPDATETIME, metaObject, new Date());
    }
}

Then add the @TableField annotation to the User fields:

  @TranField(from=TranField.empty,to="create_time")
    @TableField(fill = FieldFill.INSERT)
    private Date createTime;
    @TranField(from=TranField.empty,to="update_time")
    @TableField(fill = FieldFill.INSERT_UPDATE)
    private Date updateTime;
    @TranField(from=TranField.empty,to="bonus")
    @TableField(fill= FieldFill.INSERT)
    private Integer bonus;

Then register the field filler in the global configuration:

 @Bean
    public GlobalConfig globalConfig() {
        GlobalConfig globalConfig = new GlobalConfig();
        globalConfig.setBanner(false);
        globalConfig.setMetaObjectHandler(defaultFieldValueHandler());
        return globalConfig;
    }

 Set the global configuration on the factory: sqlSessionFactory.setGlobalConfig(globalConfig());

Type handler

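A type handler converts between a JavaType and a JdbcType. It is used to set parameter values on a PreparedStatement and to read a value from a ResultSet or CallableStatement. For example, to store String, Integer, Long, or Double arrays in the database as comma-separated text, you can write a custom type handler.

The handlers here cover arrays of those four types. The setXxx method processes the parameter and writes the result to the database, while the getXxx methods decide how data read from the database is converted. A lambda is accepted as the conversion function (string -> target type). An abstract class holds the shared logic, and subclasses supply the concrete conversion.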

package com.trendy.task.transport.handler;

import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.MappedJdbcTypes;
import org.apache.ibatis.type.MappedTypes;

import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.function.Function;

/**
 * @author lulu
 * @Date 2019/10/25 21:46
 */
//@MappedJdbcTypes({}) declares which JDBC types are handled
@MappedTypes(Object[].class)// declares which Java type is handled
public abstract class AbstractArrayTypeHandler<T> extends BaseTypeHandler<Object[]> {

    // accepts a lambda that performs the conversion
   private final Function<String,T> method;

    public AbstractArrayTypeHandler(Function<String,T> method){
        this.method=method;
    }

    @Override
    public void setNonNullParameter(PreparedStatement preparedStatement, int i, Object[] objects, JdbcType jdbcType) throws SQLException {
            // join the elements with commas; an empty array yields an empty string
            StringBuilder sb=new StringBuilder();
            for(int k=0;k<objects.length;k++){
                if(k>0){
                    sb.append(',');
                }
                sb.append(objects[k]);
            }
            preparedStatement.setString(i,sb.toString());
    }

    @Override
    public T[] getNullableResult(ResultSet resultSet, String s) throws SQLException {
        return getArray(resultSet.getString(s));
    }

    @Override
    public T[] getNullableResult(ResultSet resultSet, int i) throws SQLException {
        return getArray(resultSet.getString(i));
    }
    @Override
    public Object[] getNullableResult(CallableStatement callableStatement, int i) throws SQLException {
        return  getArray(callableStatement.getString(i));
    }

    protected  T[] getArray(String source){
        if(source==null){
            return null;
        }
        String[] resString=source.split(",");
        if(this.method==null){
            return (T[])resString;
        }
        T[] resArray= (T[]) new Object[resString.length];
        for (int i = 0; i < resString.length; i++) {
            resArray[i]=method.apply(resString[i]);
        }
        return resArray;
    }

}
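
One detail of getArray is worth illustrating: (T[]) new Object[n] is still an Object[] at runtime, so assigning the result to an Integer[] or Long[] field can fail with a ClassCastException. A common way around this, sketched here with hypothetical names and independent of MyBatis, is to pass an array constructor such as Integer[]::new alongside the parser.

```java
import java.util.function.Function;
import java.util.function.IntFunction;

class ArrayCodecSketch {
    // Array -> "a,b,c"; an empty array becomes an empty string.
    static String join(Object[] values) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(values[i]);
        }
        return sb.toString();
    }

    // "1,2,3" -> typed array. newArray creates an array of the real element
    // type, so the result can be assigned to Integer[], Long[], and so on.
    static <T> T[] split(String source, Function<String, T> parser, IntFunction<T[]> newArray) {
        if (source == null) {
            return null;
        }
        if (source.isEmpty()) {
            return newArray.apply(0);
        }
        String[] parts = source.split(",");
        T[] result = newArray.apply(parts.length);
        for (int i = 0; i < parts.length; i++) {
            result[i] = parser.apply(parts[i]);
        }
        return result;
    }
}
```

A call such as ArrayCodecSketch.split("1,2,3", Integer::valueOf, Integer[]::new) returns a real Integer[]. Values that themselves contain commas are not supported by this encoding.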

Then define a factory interface that holds the subclasses:

package com.trendy.task.transport.handler;

import java.util.function.Function;

/**
 * @author lulu
 * @Date 2019/10/25 22:33
 */
public interface ArrayTypeHandlerFactory {

     class IntegerArrayTypeHandler extends AbstractArrayTypeHandler<Integer>{
        public IntegerArrayTypeHandler() {
            super(Integer::parseInt);
        }
    }
     class DoubleArrayTypeHandler extends AbstractArrayTypeHandler<Double>{
        public DoubleArrayTypeHandler(){
            super(Double::parseDouble);
        }
    }
     class LongArrayTypeHandler extends AbstractArrayTypeHandler<Long>{
        public LongArrayTypeHandler(){
            super(Long::parseLong);
        }
    }
     class StringArrayTypeHandler extends AbstractArrayTypeHandler<String>{
        public StringArrayTypeHandler(){
            super(null);
        }
    }



}


Specify the type handler on the field:

    @TableField(typeHandler = ArrayTypeHandlerFactory.StringArrayTypeHandler.class)

Now let's test it:

 @Test
    public void selectById() {
        List<User> userList = userService.list(new LambdaQueryWrapper<User>().select(User::getUserId,User::getMobile,User::getUserAccount,User::getTest));
        userService.saveBatch(userList);
    }

That is about it. The full code is on GitHub: GitHub - 97lele/transport

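This approach has shortcomings: it does not support transactions, and saveOrUpdate is not supported either, because that method first queries to see whether the row exists and then updates or inserts, while the two data sources differ. These problems are still unsolved, so treat this as a small example recording how the components are used.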

Update, 2021.02.05

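The default BaseMapper of mybatis-plus does not support batch insert or batch update. Sometimes a class extends ServiceImpl only to get these batch methods; it contains no other business code, which makes the class look somewhat "anemic".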

For this case I came across an open-source project, onemall, that extends BaseMapper to support batch insert. I built a few more extensions on top of that idea.

Here is the flow as I understand it:

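The core is a custom SQL injector. mybatis-plus scans the mapper classes and builds MappedStatement objects (each roughly one SQL statement of a mapper) for every mapper class, then adds them to the configuration (see (2) below). This runs only once, at initialization. What we need to do is inject the template methods we want into specific mappers.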

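The injection happens in MybatisMapperAnnotationBuilder, which injects the dynamic SQL: getSqlInjector obtains the SQL injector, which then injects MappedStatement objects (the SQL) for the current MapperBuilderAssistant (roughly, the mapper).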

(1) Injection starts with inspectInject.

inspectInject injects every AbstractMethod one by one:

public void inspectInject(MapperBuilderAssistant builderAssistant, Class<?> mapperClass) {
        Class<?> modelClass = extractModelClass(mapperClass);
        if (modelClass != null) {
            String className = mapperClass.toString();
            Set<String> mapperRegistryCache = GlobalConfigUtils.getMapperRegistryCache(builderAssistant.getConfiguration());
            if (!mapperRegistryCache.contains(className)) {
                List<AbstractMethod> methodList = this.getMethodList(mapperClass);
                if (CollectionUtils.isNotEmpty(methodList)) {
                    TableInfo tableInfo = TableInfoHelper.initTableInfo(builderAssistant, modelClass);
                    // inject the custom methods in a loop
                    methodList.forEach(m -> m.inject(builderAssistant, mapperClass, modelClass, tableInfo));
                } else {
                    logger.debug(mapperClass.toString() + ", No effective injection method was found.");
                }
                mapperRegistryCache.add(className);
            }
        }
    }

(2) The abstract method-injection classes; the result ends up in the Configuration cache as a SQL statement.

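Look at the implementations of AbstractMethod: the default DefaultSqlInjector class contains the default methods that BaseMapper exposes. We want to add two more methods, batch update and batch insert.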

Now for the code.

First extend the AbstractMethod class and implement the SQL template we want to inject.

For batch insert, use insert into table columns values (values),(values),(values)

For batch update, use update table set xx=xx,xx=xx where xx=xx;update table set xx=xx,xx=xx where xx=xx;update table set xx=xx,xx=xx where xx=xx; with the statements separated by semicolons

Because the tag-parsing feature is needed, the SQL has to be wrapped in script.


The templates are initialized as follows and are filled in later:

public enum CustomSqlMethodEnum {
    /**
     * Batch insert
     */
    INSERT_BATCH("insertBatch",
            "Batch insert",
            "<script>\n"
                    + "INSERT INTO %s %s VALUES \n"
                    + "<foreach collection=\"collection\"  item=\"item\" separator=\",\"> %s\n </foreach>\n"
                    + "</script>"),
    /**
     * Batch update
     */
    UPDATE_BATCH("updateBatchByIds",
            "Batch update",
            "<script>\n" +
                    "<foreach collection=\"collection\" item=\"item\" separator=\";\"> update %s set %s where %s </foreach>\n"
                    + "</script>"
    ),

    /**
     * Query by composite key
     */
    SELECT_BY_COMPOSEKEYS("selectByComposeKeys",
            "Query by composite primary key",
            "<script>" +
                    " select <choose> <when test=\"ew!=null and ew.sqlSelect != null and ew.sqlSelect != ''\"> ${ew.sqlSelect} </when> <otherwise> * </otherwise> </choose> from %s "
                    + "<where> <foreach collection=\"collection\" item=\"item\" open=\"(\" close=\")\" separator=\"or\"> ( %s ) </foreach> <if test=\"ew!=null and ew.sqlSegment != null and ew.sqlSegment != ''\">\n" +
                    "AND ${ew.sqlSegment}\n" +
                    "</if> </where> </script>"
    ),

    /**
     * Query ids by composite key
     */
    SELECT_IDS_BY_COMPOSEKEYS("selectIdsByComposeKeys",
            "Query ids by composite primary key",
            "<script>" +
                    " select %s from %s "
                    + "<where> <foreach collection=\"collection\" item=\"item\" open=\"(\" close=\")\" separator=\"or\"> ( %s ) </foreach> </where> </script>"
    ),

    /**
     * Delete by composite primary key
     */
    DELETE_BY_COMPOSEKEYS("deleteByComposeKeys",
            "Delete by composite primary key",
            "<script>" +
                    " delete from %s "
                    + "<where> <foreach collection=\"collection\" item=\"item\" open=\"(\" close=\")\" separator=\"or\"> ( %s ) </foreach> </where> </script>"
    ),
    /**
     * Update by composite primary key
     */
    UPDATE_BY_COMPOSEKEYS("updateByComposeKeys"
            , "Batch update by composite primary key",
            "<script>\n" +
                    "<foreach collection=\"collection\" item=\"item\" separator=\";\"> update %s set %s where %s </foreach>\n"
                    + "</script>"
    );


    private final String method;
    private final String desc;
    private final String sql;

    CustomSqlMethodEnum(String method, String desc, String sql) {
        this.method = method;
        this.desc = desc;
        this.sql = sql;
    }

    public String getMethod() {
        return method;
    }

    public String getDesc() {
        return desc;
    }

    public String getSql() {
        return sql;
    }
}
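
To see what a filled template looks like, the sketch below renders the INSERT_BATCH template with String.format. The class name, the sample table, and the sample columns are hypothetical, and the #{item.xxx} form assumes that is what SqlScriptUtils.safeParam produces for a property path.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class InsertTemplateSketch {
    // Same text as the INSERT_BATCH template above.
    static final String INSERT_BATCH =
            "<script>\n"
            + "INSERT INTO %s %s VALUES \n"
            + "<foreach collection=\"collection\"  item=\"item\" separator=\",\"> %s\n </foreach>\n"
            + "</script>";

    // Builds "(a,b,c)": append a comma after every entry, then overwrite the
    // last comma with the closing bracket. Assumes a non-empty list.
    static String bracketed(List<String> parts) {
        StringBuilder sb = new StringBuilder("(");
        for (String part : parts) {
            sb.append(part).append(',');
        }
        sb.setCharAt(sb.length() - 1, ')');
        return sb.toString();
    }

    static String render(String table, List<String> columns, List<String> properties) {
        List<String> params = new ArrayList<>();
        for (String property : properties) {
            params.add("#{item." + property + "}");
        }
        return String.format(INSERT_BATCH, table, bracketed(columns), bracketed(params));
    }

    static String demo() {
        return render("user",
                Arrays.asList("id", "wx_nickname"),
                Arrays.asList("userId", "userAccount"));
    }
}
```

The rendered script contains INSERT INTO user (id,wx_nickname) VALUES followed by a foreach that emits one (#{item.userId},#{item.userAccount}) group per element of the collection.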

Below is the logic that builds the batch insert and batch update SQL.

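Batch insert: the main work is building the table name, the insert columns, and the insert values. To avoid overwriting database default values you could add if test null checks that leave a column out of the insert, but that is more work, so it is not done here.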

public class InsertBatch extends AbstractMethod {
    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        KeyGenerator keyGenerator = new NoKeyGenerator();

        CustomSqlMethodEnum sqlMethod = CustomSqlMethodEnum.INSERT_BATCH;

        // ==== build the sql template ==============
        StringBuilder columnScriptBuilder = new StringBuilder(LEFT_BRACKET);
        StringBuilder valuesScriptBuilder = new StringBuilder(LEFT_BRACKET);
        // primary key
        if (StringUtils.isNotBlank(tableInfo.getKeyColumn())) {
            //(x1,x2,x3
            columnScriptBuilder.append(tableInfo.getKeyColumn()).append(COMMA);
            //(item.xx,item.xx
            valuesScriptBuilder.append(SqlScriptUtils.safeParam(SQLConditionWrapper.ITEM + DOT + tableInfo.getKeyProperty())).append(COMMA);
        }
        // regular fields
        List<TableFieldInfo> fieldList = tableInfo.getFieldList();
        for (TableFieldInfo fieldInfo : fieldList) {
            // fields that are filled only on update are not assigned
            if (!fieldInfo.isWithInsertFill() && fieldInfo.isWithUpdateFill()) {
                continue;
            }
            columnScriptBuilder.append(fieldInfo.getColumn()).append(COMMA);
            valuesScriptBuilder.append(SqlScriptUtils.safeParam(SQLConditionWrapper.getCondition(fieldInfo.getProperty()).toString())).append(COMMA);
        }
        // replace the trailing comma with a closing bracket
        //(x1,x2)
        columnScriptBuilder.setCharAt(columnScriptBuilder.length() - 1, ')');
        //(item.xx,item.xx2)
        valuesScriptBuilder.setCharAt(valuesScriptBuilder.length() - 1, ')');
        // fill the placeholders of the sql template
        String columnScript = columnScriptBuilder.toString();
        String valuesScript = valuesScriptBuilder.toString();
        String sql = String.format(sqlMethod.getSql(), tableInfo.getTableName(), columnScript, valuesScript);


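        // === MyBatis primary-key handling: key generation strategy and key backfill =======
        String keyColumn = null;
        String keyProperty = null;
        // handle tables that have a primary key; without one the key is treated as a regular field
        if (StringUtils.isNotBlank(tableInfo.getKeyProperty())) {
            if (tableInfo.getIdType() == IdType.AUTO) {
                /** auto-increment primary key */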
                keyGenerator = new Jdbc3KeyGenerator();
                keyProperty = tableInfo.getKeyProperty();
                keyColumn = tableInfo.getKeyColumn();
            } else {
                if (null != tableInfo.getKeySequence()) {
                    keyGenerator = TableInfoHelper.genKeyGenerator(sqlMethod.getMethod(), tableInfo, builderAssistant);
                    keyProperty = tableInfo.getKeyProperty();
                    keyColumn = tableInfo.getKeyColumn();
                }
            }
        }
        // register the template
        SqlSource sqlSource = languageDriver.createSqlSource(configuration, sql, modelClass);
        return this.addInsertMappedStatement(mapperClass, modelClass, sqlMethod.getMethod(), sqlSource, keyGenerator, keyProperty, keyColumn);
    }
}

Building the update SQL. Pay attention to the field strategy (FieldStrategy) declared on each field. The helper class is included below.

public class UpdateBatchByIds extends AbstractMethod {
    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        CustomSqlMethodEnum sqlMethod = CustomSqlMethodEnum.UPDATE_BATCH;
        /**
         * update table set <if test ="item.pro!=null">key1=#{item.pro}</if>, key2=#{item.pro2} where keycolom = %s
         * */
        StringBuilder withChevScript = new StringBuilder();
        StringBuilder lastFiledScriptBuilder=new StringBuilder();
        StringBuilder keyScriptBuilder = new StringBuilder();
        if (StringUtils.isNotBlank(tableInfo.getKeyColumn())) {
            keyScriptBuilder.append(tableInfo.getKeyColumn()).append(EQUALS).append(SqlScriptUtils.safeParam(SQLConditionWrapper.ITEM + DOT + tableInfo.getKeyProperty()));
        }
        List<TableFieldInfo> fieldList = tableInfo.getFieldList();
        for (int i = 0; i < fieldList.size(); i++) {
            TableFieldInfo fieldInfo = fieldList.get(i);
            Boolean isLast=(i==(fieldList.size()-1));
            // fields that are filled only on insert are not assigned
            if (fieldInfo.isWithInsertFill() && !fieldInfo.isWithUpdateFill()) {
                continue;
            }
            boolean change = false;
            if (Objects.equals(fieldInfo.getUpdateStrategy(), FieldStrategy.NOT_NULL) && !fieldInfo.isWithUpdateFill()) {
                SQLConditionWrapper.appendNotNull(withChevScript, fieldInfo.getProperty());
                change = true;
            }
            if (Objects.equals(fieldInfo.getUpdateStrategy(), FieldStrategy.NOT_EMPTY) && !fieldInfo.isWithUpdateFill()) {
                SQLConditionWrapper.appendNotEmpty(withChevScript, fieldInfo.getProperty());
                change = true;
            }
            if (change) {
                withChevScript.append(fieldInfo.getColumn()).append(EQUALS).append(SqlScriptUtils.safeParam(SQLConditionWrapper.getCondition(fieldInfo.getProperty()).toString()));
                withChevScript.append(COMMA).append("</if>");
                if(isLast&&lastFiledScriptBuilder.length()==0){
                    // if no other field follows, add a placeholder, otherwise the SQL has a syntax error
                    withChevScript.append(tableInfo.getKeyColumn()).append(EQUALS).append(tableInfo.getKeyColumn());
                }
            }else{
                lastFiledScriptBuilder.append(fieldInfo.getColumn()).append(EQUALS).append(SqlScriptUtils.safeParam(SQLConditionWrapper.getCondition(fieldInfo.getProperty()).toString())).append(COMMA);
            }
        }
        // handle the extra comma
        if(!StringUtils.isBlank(lastFiledScriptBuilder)&&!StringUtils.isBlank(withChevScript)){
            int leftChevIndex= withChevScript.lastIndexOf(LEFT_CHEV);
            if(withChevScript.charAt(leftChevIndex-1)!=','){
                withChevScript.replace(leftChevIndex,leftChevIndex+1,",<");
            }
            if(lastFiledScriptBuilder.lastIndexOf(COMMA)==lastFiledScriptBuilder.length()-1){
                lastFiledScriptBuilder.deleteCharAt(lastFiledScriptBuilder.length()-1);
            }
        }
        withChevScript.append(lastFiledScriptBuilder);
        // fill the placeholders of the sql template
        String sql = String.format(sqlMethod.getSql(), tableInfo.getTableName(), withChevScript, keyScriptBuilder);
        // register the template
        SqlSource sqlSource = languageDriver.createSqlSource(configuration, sql, modelClass);
        return this.addUpdateMappedStatement(mapperClass, modelClass, sqlMethod.getMethod(), sqlSource);
    }
}
public class SQLConditionWrapper {
    public final static String ITEM = "item";

    public static StringBuilder appendNotNull(StringBuilder builder, String property) {
        return appendEnd(appendStart(builder, getCondition(property)));
    }

    public static StringBuilder appendNotEmpty(StringBuilder builder, String property) {
        StringBuilder condition = getCondition(property);
        StringBuilder append = appendStart(builder, condition)
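                // OGNL keywords are lowercase, so use a literal "and" (assuming AND is the uppercase StringPool constant, OGNL would reject it)
                .append(" and ").append(condition).append(EXCLAMATION_MARK).append(EQUALS).append(" ").append("''");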
        return appendEnd(append);
    }

    private static StringBuilder appendEnd(StringBuilder builder) {
        builder.append("\"")
                .append(RIGHT_CHEV);
        return builder;
    }

    private static StringBuilder appendStart(StringBuilder builder, StringBuilder item) {
        builder.append(LEFT_CHEV)
                .append("if test=\"")
                .append(item).append(EXCLAMATION_MARK).append(EQUALS).append(NULL);
        return builder;
    }

    public static StringBuilder getCondition(String property) {
        return new StringBuilder()
                .append(ITEM).append(DOT).append(property);
    }
}
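
To make it easier to see what SQLConditionWrapper actually emits, here is a minimal standalone sketch that mirrors its string building with plain literals instead of the MyBatis-Plus constants. The class name `IfFragmentSketch` and the `setEntry` helper are made up for illustration; they are not part of the project or of MyBatis-Plus.

```java
// Standalone sketch: mirrors the string building of SQLConditionWrapper with plain literals.
final class IfFragmentSketch {
    static final String ITEM = "item";

    // Builds "item.<property>", the path to one element of the batch collection.
    static String condition(String property) {
        return ITEM + "." + property;
    }

    // Opening <if> tag that passes when the property is not null.
    static String notNull(String property) {
        return "<if test=\"" + condition(property) + "!=null\">";
    }

    // Opening <if> tag that additionally requires a non-empty string.
    static String notEmpty(String property) {
        String c = condition(property);
        return "<if test=\"" + c + "!=null and " + c + "!=''\">";
    }

    // One complete conditional SET entry: <if ...>column=#{item.property},</if>
    static String setEntry(String column, String property) {
        return notNull(property) + column + "=#{" + condition(property) + "},</if>";
    }

    public static void main(String[] args) {
        // <if test="item.name!=null">name=#{item.name},</if>
        System.out.println(setEntry("name", "name"));
        // <if test="item.secret!=null and item.secret!=''">
        System.out.println(notEmpty("secret"));
    }
}
```

Note the comma inside the `<if>` body: it is what makes the comma trimming and the placeholder logic in the update methods necessary.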

Next, define the custom mapper and the SQL injector:

public interface CustomMapper<T> extends BaseMapper<T> {
    /**
     * Batch insert.
     * @param collection the records to insert
     * @return ignore
     */
    int insertBatch(@Param("collection") Collection<T> collection);

    /**
     * Batch update by primary key.
     * @param collection the records to update
     * @return the row count reported by the JDBC driver
     */
    int updateBatchByIds(@Param("collection") Collection<T> collection);

}

The SQL injector has to check that a mapper extends this interface before it adds the extra methods:

@ConditionalOnExpression("${mybatis-plus.custom-mapper.enabled:true}")
@Component
public class CustomSqlInject extends DefaultSqlInjector {
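    // Called while the mappers are being initialized
    @Override
    public List<AbstractMethod> getMethodList(Class<?> mapperClass) {
        List<AbstractMethod> methodList = super.getMethodList(mapperClass);
        // Only mappers that directly extend one of our own mapper interfaces get the extra methods.
        // Cast to the public java.lang.reflect.ParameterizedType interface instead of the JDK-internal
        // ParameterizedTypeImpl; this still assumes the first declared interface is a generic one.
        Type rawType = ((ParameterizedType) mapperClass.getGenericInterfaces()[0]).getRawType();
        if (rawType.equals(CustomMapper.class)) {
            methodList.add(new InsertBatch());
            methodList.add(new UpdateBatchByIds());
        }
        if (rawType.equals(ComposeKeyMapper.class)) {
            methodList.add(new SelectByComposeKeys());
            methodList.add(new SelectIdsByComposeKeys());
            methodList.add(new UpdateByComposeKeys());
            // deleteByComposeKeys is declared on ComposeKeyMapper, so its method must be registered too
            methodList.add(new DeleteByComposeKeys());
            methodList.add(new InsertBatch());
        }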
        return methodList;
    }
}
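
The check in the injector looks only at the first declared interface and casts it unconditionally, so a mapper whose first interface is not generic would fail with a ClassCastException. A more defensive variant might look like the sketch below. It uses only `java.lang.reflect`; the interfaces `BaseLike`, `CustomLike`, `ComposeLike` and the two mappers are hypothetical stand-ins for BaseMapper, CustomMapper, ComposeKeyMapper and real mappers, and `directlyExtends` is a helper name I made up.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;

// Hypothetical stand-ins for BaseMapper / CustomMapper / ComposeKeyMapper and two concrete mappers.
interface BaseLike<T> {}
interface CustomLike<T> extends BaseLike<T> {}
interface ComposeLike<T> extends CustomLike<T> {}
interface PlainMapper extends CustomLike<String> {}
interface KeyedMapper extends ComposeLike<String> {}

final class DirectParentCheck {
    // True when mapperClass lists target among its directly declared interfaces,
    // whether the interface is declared with type arguments or as a raw type.
    static boolean directlyExtends(Class<?> mapperClass, Class<?> target) {
        for (Type type : mapperClass.getGenericInterfaces()) {
            Type raw = (type instanceof ParameterizedType)
                    ? ((ParameterizedType) type).getRawType()
                    : type;
            if (target.equals(raw)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println(directlyExtends(PlainMapper.class, CustomLike.class));  // true
        System.out.println(directlyExtends(KeyedMapper.class, CustomLike.class));  // false
        System.out.println(directlyExtends(KeyedMapper.class, ComposeLike.class)); // true
    }
}
```

Matching only direct parents is deliberate here: a composite-key mapper also inherits from the custom mapper indirectly, and an `isAssignableFrom` check would register both method sets for it.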

You can register it directly as a Spring bean, or set it through the global config:

        GlobalConfigUtils.getGlobalConfig(sqlSessionFactory.getConfiguration()).setSqlInjector(xx)

To test it, define an entity (the typeHandler attributes below can be removed):

@TableName(
        value = "demo"
)
@Data
@Accessors(chain = true)
public class Demo extends BaseEntity {
    private static final long serialVersionUID = 1L;
    @TableId("id")
    private Long id;

    @TableField("name")
    private String name;
    @TableField(
            value = "age",
            typeHandler = LongFormatAESEncryptHandler.class
    )
    private Long age;
    @TableField(
            value = "secret",
            typeHandler = StringFormatAESEncryptHandler.class
    )
    private String secret;
    @TableField(
            value = "CREATED_ID",
            fill = FieldFill.INSERT
    )
    private Long createdId;
    @TableField(
            value = "CREATED_BY",
            fill = FieldFill.INSERT
    )
    private String createdBy;
    @TableField(
            value = "CREATION_DATE",
            fill = FieldFill.INSERT
    )
    private Date creationDate;
    @TableField(
            value = "CREATED_BY_IP",
            fill = FieldFill.INSERT
    )
    private String createdByIp;
    @TableField(
            value = "LAST_UPDATED_ID",
            fill = FieldFill.UPDATE
    )
    private Long lastUpdatedId;
    @TableField(
            value = "LAST_UPDATED_BY",
            fill = FieldFill.UPDATE
    )
    private String lastUpdatedBy;
    @TableField(
            value = "LAST_UPDATE_DATE",
            fill = FieldFill.INSERT_UPDATE
    )
    private Date lastUpdateDate;
    @TableField(
            value = "LAST_UPDATED_BY_IP",
            fill = FieldFill.UPDATE
    )
    private String lastUpdatedByIp;


}

Extend the custom mapper:

public interface TempMapper extends CustomMapper<Demo> {
}

Two controller methods; these do support transactions:

@GetMapping("/testBatch")
    @Transactional(rollbackFor = Exception.class)
    public void testBatch() {
        List<Demo> demos = new LinkedList<>();
        for (int i = 0; i < 10; i++) {
            Demo de = new Demo().setAge(Long.valueOf(i))
                    .setId(Long.valueOf(i + 1))
                    .setName("test" + i)
                    .setSecret(UUID.randomUUID().toString());
            demos.add(de);
        }
        tempMapper.insertBatch(demos);
    }

    @GetMapping("/testBatch2")
    @Transactional(rollbackFor = Exception.class)
    public void testBatch2() {
        List<Demo> demos = new LinkedList<>();
        for (int i = 0; i < 5; i++) {
            Demo de = new Demo().setAge(Long.valueOf(i))
                    .setId(Long.valueOf(i + 1))
                    .setName("test-change" + i)
                    .setSecret(UUID.randomUUID().toString());
            demos.add(de);
        }
        tempMapper.updateBatchByIds(demos);
        //int i=1/0;
    }

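With custom SQL injection you can also extract and add other common methods, such as logical delete, version-number update, or batch update and query by composite key. A composite-key approach is given below.

The following builds a query driven by a wrapper and a composite key. The core is again to define the SQL and inject it, but here a custom annotation is needed to mark which fields make up the composite key.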

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface ComposeKey {

}
public class SelectByComposeKeys extends AbstractMethod {
    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        CustomSqlMethodEnum sqlMethod = CustomSqlMethodEnum.SELECT_BY_COMPOSEKEYS;
        //select #{ew.sqlSelect} from table_name where  (composekey1=xx and composekey2=xx) or (composekey1=xx1 and composekey2=xx1)
        String sqlTemplate = sqlMethod.getSql();
        String tableName = tableInfo.getTableName();
        List<TableFieldInfo> composeKeys = tableInfo.getFieldList().stream().filter(e -> e.getField().isAnnotationPresent(ComposeKey.class))
                .collect(Collectors.toList());
        if(CollectionUtils.isEmpty(composeKeys)){
            throw new ApiException("no @ComposeKey field found in class: "+modelClass.getName());
        }
        StringBuilder builder=new StringBuilder();
        for (int i = 0; i < composeKeys.size(); i++) {
            TableFieldInfo composeKey = composeKeys.get(i);
            builder.append(composeKey.getColumn()).append(EQUALS).append(SqlScriptUtils.safeParam(SQLConditionWrapper.ITEM + DOT + composeKey.getProperty()));
            if(i!=composeKeys.size()-1){
                builder.append(" ").append(AND).append(" ");
            }
        }
        String finalSql = String.format(sqlTemplate, tableName, builder);
        SqlSource sqlSource = languageDriver.createSqlSource(configuration, finalSql, modelClass);
        return this.addSelectMappedStatementForTable(mapperClass,sqlMethod.getMethod(),sqlSource,tableInfo);
    }
}
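
As a rough illustration of how the annotation drives the WHERE fragment, here is a minimal standalone sketch using plain reflection instead of MyBatis-Plus's TableInfo. Everything in it is hypothetical: the `KeyPart` annotation stands in for `@ComposeKey`, `OrderLine` is an invented entity, and the field name doubles as the column name, which the real code does not assume.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for the @ComposeKey marker annotation.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@interface KeyPart {}

// Hypothetical entity with a two-column composite key.
final class OrderLine {
    @KeyPart Long orderId;
    @KeyPart Integer lineNo;
    String remark;
}

final class ComposeKeyWhereSketch {
    // Builds the per-item condition, e.g. (a=#{item.a} AND b=#{item.b}).
    static String whereFragment(Class<?> modelClass) {
        List<String> parts = new ArrayList<>();
        for (Field field : modelClass.getDeclaredFields()) {
            if (field.isAnnotationPresent(KeyPart.class)) {
                parts.add(field.getName() + "=#{item." + field.getName() + "}");
            }
        }
        if (parts.isEmpty()) {
            throw new IllegalStateException("no key field found in class: " + modelClass.getName());
        }
        // Reflection does not guarantee field order, so sort to keep the output deterministic.
        Collections.sort(parts);
        return "(" + String.join(" AND ", parts) + ")";
    }

    public static void main(String[] args) {
        // (lineNo=#{item.lineNo} AND orderId=#{item.orderId})
        System.out.println(whereFragment(OrderLine.class));
    }
}
```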

The update operation:

public class UpdateByComposeKeys extends AbstractMethod {
    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        CustomSqlMethodEnum sqlMethod = CustomSqlMethodEnum.UPDATE_BY_COMPOSEKEYS;
        /**
         * update table set <if test ="item.pro!=null">key1=#{item.pro}</if>, key2=#{item.pro2} where keycolom = %s
         */
        StringBuilder withChevScript = new StringBuilder();
        StringBuilder lastFiledScriptBuilder = new StringBuilder();
        StringBuilder keyScriptBuilder = new StringBuilder();
        StringBuilder placeHolder = new StringBuilder();

        List<TableFieldInfo> fieldList = tableInfo.getFieldList();
        for (int i = 0; i < fieldList.size(); i++) {
            TableFieldInfo fieldInfo = fieldList.get(i);
            Boolean isLast = (i == (fieldList.size() - 1));
            if (fieldInfo.getField().isAnnotationPresent(ComposeKey.class)) {
                keyScriptBuilder.append(" ").append(fieldInfo.getColumn()).append(EQUALS).append(SqlScriptUtils.safeParam(SQLConditionWrapper.ITEM + DOT + fieldInfo.getProperty()))
                .append(" ").append(AND);
                placeHolder.append(fieldInfo.getColumn()).append(EQUALS).append(fieldInfo.getColumn()).append(COMMA);
                continue;
            }
            // Fields that are filled only on insert (or are key fields) are left out of the SET clause
            if (fieldInfo.isWithInsertFill() && !fieldInfo.isWithUpdateFill()) {
                continue;
            }
            boolean change = false;
            if (Objects.equals(fieldInfo.getUpdateStrategy(), FieldStrategy.NOT_NULL) && !fieldInfo.isWithUpdateFill()) {
                SQLConditionWrapper.appendNotNull(withChevScript, fieldInfo.getProperty());
                change = true;
            }
            if (Objects.equals(fieldInfo.getUpdateStrategy(), FieldStrategy.NOT_EMPTY) && !fieldInfo.isWithUpdateFill()) {
                SQLConditionWrapper.appendNotEmpty(withChevScript, fieldInfo.getProperty());
                change = true;
            }
            if (change) {
                withChevScript.append(fieldInfo.getColumn()).append(EQUALS).append(SqlScriptUtils.safeParam(SQLConditionWrapper.getCondition(fieldInfo.getProperty()).toString()));
                withChevScript.append(COMMA).append("</if>");
                if (isLast && lastFiledScriptBuilder.length() == 0) {
                    // If no unconditional field follows, add no-op placeholders; otherwise the trailing comma yields invalid SQL
                    if (placeHolder.length() > 0) {
                        // Drop the trailing comma
                        placeHolder.deleteCharAt(placeHolder.length() - 1);
                        withChevScript.append(placeHolder);
                    }

                }
            } else {
                lastFiledScriptBuilder.append(fieldInfo.getColumn()).append(EQUALS).append(SqlScriptUtils.safeParam(SQLConditionWrapper.getCondition(fieldInfo.getProperty()).toString())).append(COMMA);
            }
        }
        if (placeHolder.length() == 0) {
            throw new ApiException("no @ComposeKey field found in class: " + modelClass.getName());
        }
        // Trim the redundant commas
        if (!StringUtils.isBlank(lastFiledScriptBuilder) && !StringUtils.isBlank(withChevScript)) {
            int leftChevIndex = withChevScript.lastIndexOf(LEFT_CHEV);
            if (withChevScript.charAt(leftChevIndex - 1) != ',') {
                withChevScript.replace(leftChevIndex, leftChevIndex + 1, ",<");
            }
            if (lastFiledScriptBuilder.lastIndexOf(COMMA) == lastFiledScriptBuilder.length() - 1) {
                lastFiledScriptBuilder.deleteCharAt(lastFiledScriptBuilder.length() - 1);
            }
        }
        if(!StringUtils.isBlank(keyScriptBuilder)){
            int i = keyScriptBuilder.lastIndexOf(AND);
            keyScriptBuilder.delete(i,i+AND.length());
        }
        withChevScript.append(lastFiledScriptBuilder);
        // Fill in the SQL template placeholders
        String sql = String.format(sqlMethod.getSql(), tableInfo.getTableName(), withChevScript, keyScriptBuilder);
        // Build the SqlSource from the script
        SqlSource sqlSource = languageDriver.createSqlSource(configuration, sql, modelClass);
        return this.addUpdateMappedStatement(mapperClass, modelClass, sqlMethod.getMethod(), sqlSource);
    }
}
public class SelectIdsByComposeKeys extends AbstractMethod {
    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        CustomSqlMethodEnum sqlMethod = CustomSqlMethodEnum.SELECT_IDS_BY_COMPOSEKEYS;
        //select composekeys from table_name where  (composekey1=xx and composekey2=xx) or (composekey1=xx1 and composekey2=xx1)
        String sqlTemplate = sqlMethod.getSql();
        String tableName = tableInfo.getTableName();
        List<TableFieldInfo> composeKeys = tableInfo.getFieldList().stream().filter(e -> e.getField().isAnnotationPresent(ComposeKey.class))
                .collect(Collectors.toList());
        if(CollectionUtils.isEmpty(composeKeys)){
            throw new ApiException("no @ComposeKey field found in class: "+modelClass.getName());
        }
        StringBuilder builder=new StringBuilder();
        StringBuilder select=new StringBuilder();

        for (int i = 0; i < composeKeys.size(); i++) {
            TableFieldInfo composeKey = composeKeys.get(i);
            select.append(composeKey.getColumn());
            builder.append(composeKey.getColumn()).append(EQUALS).append(SqlScriptUtils.safeParam(SQLConditionWrapper.ITEM + DOT + composeKey.getProperty()));
            if(i!=composeKeys.size()-1){
                builder.append(" ").append(AND).append(" ");
                select.append(COMMA);
            }
        }
        String finalSql = String.format(sqlTemplate,select, tableName, builder);
        SqlSource sqlSource = languageDriver.createSqlSource(configuration, finalSql, modelClass);
        return this.addSelectMappedStatementForTable(mapperClass,sqlMethod.getMethod(),sqlSource,tableInfo);
    }
}
public class DeleteByComposeKeys extends AbstractMethod {
    @Override
    public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
        CustomSqlMethodEnum sqlMethod = CustomSqlMethodEnum.DELETE_BY_COMPOSEKEYS;
        //delete  from table_name where  (composekey1=xx and composekey2=xx) or (composekey1=xx1 and composekey2=xx1)
        String sqlTemplate = sqlMethod.getSql();
        String tableName = tableInfo.getTableName();
        List<TableFieldInfo> composeKeys = tableInfo.getFieldList().stream().filter(e -> e.getField().isAnnotationPresent(ComposeKey.class))
                .collect(Collectors.toList());
        if(CollectionUtils.isEmpty(composeKeys)){
            throw new ApiException("no @ComposeKey field found in class: "+modelClass.getName());
        }
        StringBuilder builder=new StringBuilder();
        for (int i = 0; i < composeKeys.size(); i++) {
            TableFieldInfo composeKey = composeKeys.get(i);
            builder.append(composeKey.getColumn()).append(EQUALS).append(SqlScriptUtils.safeParam(SQLConditionWrapper.ITEM + DOT + composeKey.getProperty()));
            if(i!=composeKeys.size()-1){
                builder.append(" ").append(AND).append(" ");
            }
        }
        String finalSql = String.format(sqlTemplate, tableName, builder);
        SqlSource sqlSource = languageDriver.createSqlSource(configuration, finalSql, modelClass);
        return this.addDeleteMappedStatement(mapperClass, sqlMethod.getMethod(), sqlSource);
    }

}
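
One detail of the update methods above deserves a closer look: every conditional SET entry carries its own trailing comma, so when all remaining fields are conditional, a no-op `key=key` entry is appended to keep the statement valid. The sketch below only simulates how such a script would render for one record; it does not use MyBatis, `SetClausePlaceholderSketch` and `renderSet` are invented names, and the `?` markers stand in for bound parameters.

```java
import java.util.LinkedHashMap;
import java.util.Map;

final class SetClausePlaceholderSketch {
    // Simulates how the generated <if> blocks render for one record:
    // every non-null value contributes "column=?," (note the trailing comma),
    // and the no-op "key=key" entry always closes the clause.
    static String renderSet(String keyColumn, Map<String, Object> optionalValues) {
        StringBuilder set = new StringBuilder("SET ");
        for (Map.Entry<String, Object> entry : optionalValues.entrySet()) {
            if (entry.getValue() != null) {
                set.append(entry.getKey()).append("=?,");
            }
        }
        return set.append(keyColumn).append('=').append(keyColumn).toString();
    }

    public static void main(String[] args) {
        Map<String, Object> values = new LinkedHashMap<>();
        values.put("name", "test-change0");
        values.put("secret", null);
        System.out.println(renderSet("id", values)); // SET name=?,id=id
        values.put("name", null);
        System.out.println(renderSet("id", values)); // SET id=id
    }
}
```

Without the placeholder, the first case would end in a dangling comma and the second would produce an empty SET clause, both of which are syntax errors.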

Define the mapper interface; the built-in id-based methods are not supported here:

public interface ComposeKeyMapper<T> extends CustomMapper<T> {
    /**
     * @param params  entities carrying the composite-key values to look up
     * @param wrapper supplies the selected columns via ew.sqlSelect
     * @return the matching records
     */
    List<T> selectByComposeKeys(@Param("collection") Collection<T> params, @Param(Constants.WRAPPER) Wrapper<T> wrapper);

    /**
     * Batch update by composite key.
     *
     * @param params entities to update, each carrying its composite-key values
     */
    void updateByComposeKeys(@Param("collection") Collection<T> params);

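    /**
     * Select only the composite-key columns.
     *
     * @param params entities carrying the composite-key values to look up
     * @return entities with only the key columns populated
     */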
    List<T> selectIdsByComposeKeys(@Param("collection") Collection<T> params);

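    /**
     * Delete by composite key.
     * @param params entities carrying the composite-key values to delete
     * @return the row count reported by the JDBC driver
     */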
    int deleteByComposeKeys(@Param("collection")Collection<T> params);

    /**
     * The id-based methods below are not supported.
     * @param entity ignored
     * @return never returns normally
     */
    @Override
    default int updateById(T entity) {
        throw new UnsupportedOperationException();
    }

    @Override
    default int updateBatchByIds(Collection<T> collection) {
        throw new UnsupportedOperationException();
    }

    @Override
    default T selectById(Serializable id) {
        throw new UnsupportedOperationException();
    }

    @Override
    default List<T> selectBatchIds(Collection<? extends Serializable> idList) {
        throw new UnsupportedOperationException();
    }

    @Override
    default int deleteBatchIds(Collection<? extends Serializable> idList){
        throw new UnsupportedOperationException();
    }

    @Override
    default int deleteById(Serializable id){
        throw new UnsupportedOperationException();
    }
}

 

As you can see, the query is generated as expected, and wrapper-based queries are supported as well.

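Note that statements defined in XML files take precedence over the injected ones. I have not given single-record update, select, or delete methods here; you can write those according to your own business needs. The project did not use them, so I left them out.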

You can also make a few small optimizations to how the SqlSession is opened.

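One executor type is REUSE: a statement is prepared once and then reused. Use case: running the same SQL repeatedly with different parameter values.

The default, SIMPLE, prepares a new statement for every execution.

BATCH is roughly equivalent to executing multiple SQL statements within one session.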

A REUSE example:

 public static void executeReuse(Consumer<SqlSession> consumer) {
        SqlSessionFactory factory = SpringContextHolder.getBean(SqlSessionFactory.class);
        SqlSession sqlSession = factory.openSession(ExecutorType.REUSE, true);
        try {
            consumer.accept(sqlSession);
        } finally {
            sqlSession.close();
        }
    }
