Bootstrap

Mybatis Plus快速重构真批量sql入库操作

Mybatis快速重构真批量sql入库操作

基本思路

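重构 MyBatis-Plus `IService` 默认方法 saveBatch 和 saveOrUpdateBatch 的实现：默认实现是在 BATCH 执行器中逐条执行 INSERT/UPDATE，这里改为把一批数据拼成一条多值 `INSERT ... VALUES (...),(...)` 语句（新增或更新则使用 MySQL 的 `ON DUPLICATE KEY UPDATE`），即“真批量”。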

基本步骤
  1. 真批量保存实现类InsertBatchMethod
  2. 真批量新增或更新（upsert）实现类MysqlInsertOrUpdateBath
  3. 注册InsertBatchMethod和MysqlInsertOrUpdateBath到EasySqlInjector
  4. 注册EasySqlInjector到Mybatis配置类
  5. 实现自定义RootMapper
  6. 要实现真批量操作的实体类的Mapper继承RootMapper
  7. 自定义MyServiceImpl实现类
  8. 实体类ServiceImpl继承MyServiceImpl
  9. 批量方法的使用
详细步骤
  1. 真批量保存实现类InsertBatchMethod

    import com.baomidou.mybatisplus.core.injector.AbstractMethod;
    import com.baomidou.mybatisplus.core.metadata.TableInfo;
    import lombok.extern.slf4j.Slf4j;
    import org.apache.ibatis.executor.keygen.NoKeyGenerator;
    import org.apache.ibatis.mapping.MappedStatement;
    import org.apache.ibatis.mapping.SqlSource;
    
    /**
     * Batch insert method: injects the {@code insertBatch} mapped statement, which writes a whole
     * list of entities with one multi-row {@code INSERT INTO table (cols) VALUES (...),(...)}.
     * <p>
     * The primary-key column/property pair is emitted only when the entity maps one, so the column
     * list and every row's value tuple always contain the same number of entries. A
     * {@link NoKeyGenerator} is used, so no generated keys are written back by this statement.
     */
    @Slf4j
    public class InsertBatchMethod extends AbstractMethod {

        /** Mapped statement id; must equal the custom method name declared in RootMapper. */
        private static final String METHOD_NAME = "insertBatch";

        @Override
        public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
            final String sql = "<script>insert into %s %s values %s</script>";
            final String fieldSql = prepareFieldSql(tableInfo);
            final String valueSql = prepareValuesSql(tableInfo);
            final String sqlResult = String.format(sql, tableInfo.getTableName(), fieldSql, valueSql);
            log.debug("sqlResult----->{}", sqlResult);
            SqlSource sqlSource = languageDriver.createSqlSource(configuration, sqlResult, modelClass);
            return this.addInsertMappedStatement(mapperClass, modelClass, METHOD_NAME, sqlSource, new NoKeyGenerator(), null, null);
        }

        /**
         * Builds the parenthesised column list, e.g. {@code (id,name,age)}.
         */
        private String prepareFieldSql(TableInfo tableInfo) {
            StringBuilder fieldSql = new StringBuilder();
            if (hasKey(tableInfo)) {
                fieldSql.append(tableInfo.getKeyColumn()).append(",");
            }
            tableInfo.getFieldList().forEach(x -> fieldSql.append(x.getColumn()).append(","));
            // drop the trailing comma
            fieldSql.delete(fieldSql.length() - 1, fieldSql.length());
            fieldSql.insert(0, "(");
            fieldSql.append(")");
            return fieldSql.toString();
        }

        /**
         * Builds a {@code <foreach>} over the {@code list} parameter that renders one
         * {@code (#{item.id},#{item.name},...)} tuple per entity, in the same order as the column list.
         */
        private String prepareValuesSql(TableInfo tableInfo) {
            final StringBuilder valueSql = new StringBuilder();
            valueSql.append("<foreach collection=\"list\" item=\"item\" index=\"index\" open=\"(\" separator=\"),(\" close=\")\">");
            if (hasKey(tableInfo)) {
                valueSql.append("#{item.").append(tableInfo.getKeyProperty()).append("},");
            }
            tableInfo.getFieldList().forEach(x -> valueSql.append("#{item.").append(x.getProperty()).append("},"));
            // drop the trailing comma
            valueSql.delete(valueSql.length() - 1, valueSql.length());
            valueSql.append("</foreach>");
            return valueSql.toString();
        }

        /**
         * Whether the entity maps a primary key. Shared by both builders so the column list and the
         * value tuple can never disagree about including the key.
         */
        private static boolean hasKey(TableInfo tableInfo) {
            String keyColumn = tableInfo.getKeyColumn();
            String keyProperty = tableInfo.getKeyProperty();
            return keyColumn != null && !keyColumn.isEmpty()
                    && keyProperty != null && !keyProperty.isEmpty();
        }
    }
    
  2. 真批量新增或更新（upsert）实现类MysqlInsertOrUpdateBath

    import com.baomidou.mybatisplus.core.injector.AbstractMethod;
    import com.baomidou.mybatisplus.core.metadata.TableInfo;
    import org.apache.ibatis.executor.keygen.NoKeyGenerator;
    import org.apache.ibatis.mapping.MappedStatement;
    import org.apache.ibatis.mapping.SqlSource;
    import org.springframework.util.StringUtils;
    
    /**
     * MySQL-specific batch upsert: injects the {@code mysqlInsertOrUpdateBath} mapped statement,
     * a single {@code INSERT INTO table (cols) VALUES (...),(...) ON DUPLICATE KEY UPDATE col=values(col),...}
     * built from the entity's TableInfo.
     * <p>
     * The primary key is included in the column list, the per-row value tuple and the update clause
     * only when the entity maps one. All three fragments use the same predicate, so they stay aligned.
     */
    public class MysqlInsertOrUpdateBath extends AbstractMethod {

        /** Mapped statement id; must equal the custom method name declared in RootMapper. */
        private static final String METHOD_NAME = "mysqlInsertOrUpdateBath";

        @Override
        public MappedStatement injectMappedStatement(Class<?> mapperClass, Class<?> modelClass, TableInfo tableInfo) {
            final String sql = "<script>insert into %s %s values %s ON DUPLICATE KEY UPDATE %s</script>";
            final String tableName = tableInfo.getTableName();
            final String fieldSql = prepareFieldSql(tableInfo);
            final String modelValuesSql = prepareModelValuesSql(tableInfo);
            final String duplicateKeySql = prepareDuplicateKeySql(tableInfo);
            final String sqlResult = String.format(sql, tableName, fieldSql, modelValuesSql, duplicateKeySql);
            SqlSource sqlSource = languageDriver.createSqlSource(configuration, sqlResult, modelClass);
            return this.addInsertMappedStatement(mapperClass, modelClass, METHOD_NAME, sqlSource, new NoKeyGenerator(), null, null);
        }

        /**
         * Builds the ON DUPLICATE KEY UPDATE assignments, {@code col=values(col)} for every mapped column.
         * <p>
         * NOTE(review): the key column is reassigned as well. If a duplicate can be triggered by a
         * secondary unique index, this overwrites the existing row's primary key with the incoming
         * value. Consider excluding the key if that is not intended.
         *
         * @param tableInfo metadata of the target table
         * @return comma-separated assignment list
         */
        private String prepareDuplicateKeySql(TableInfo tableInfo) {
            final StringBuilder duplicateKeySql = new StringBuilder();
            if (hasKey(tableInfo)) {
                duplicateKeySql.append(tableInfo.getKeyColumn()).append("=values(").append(tableInfo.getKeyColumn()).append("),");
            }
            tableInfo.getFieldList().forEach(x -> duplicateKeySql.append(x.getColumn())
                    .append("=values(")
                    .append(x.getColumn())
                    .append("),"));
            // drop the trailing comma
            duplicateKeySql.delete(duplicateKeySql.length() - 1, duplicateKeySql.length());
            return duplicateKeySql.toString();
        }

        /**
         * Builds the parenthesised column list, e.g. {@code (id,name,age)}.
         *
         * @param tableInfo metadata of the target table
         * @return column list wrapped in parentheses
         */
        private String prepareFieldSql(TableInfo tableInfo) {
            StringBuilder fieldSql = new StringBuilder();
            if (hasKey(tableInfo)) {
                fieldSql.append(tableInfo.getKeyColumn()).append(",");
            }
            tableInfo.getFieldList().forEach(x -> fieldSql.append(x.getColumn()).append(","));
            fieldSql.delete(fieldSql.length() - 1, fieldSql.length());
            fieldSql.insert(0, "(");
            fieldSql.append(")");
            return fieldSql.toString();
        }

        /**
         * Builds a {@code <foreach>} over the {@code list} parameter that renders one value tuple per
         * entity, in the same order as {@link #prepareFieldSql(TableInfo)}.
         *
         * @param tableInfo metadata of the target table
         * @return MyBatis dynamic-SQL fragment for the VALUES part
         */
        private String prepareModelValuesSql(TableInfo tableInfo) {
            final StringBuilder valueSql = new StringBuilder();
            valueSql.append("<foreach collection=\"list\" item=\"item\" index=\"index\" open=\"(\" separator=\"),(\" close=\")\">");
            if (hasKey(tableInfo)) {
                valueSql.append("#{item.").append(tableInfo.getKeyProperty()).append("},");
            }
            tableInfo.getFieldList().forEach(x -> valueSql.append("#{item.").append(x.getProperty()).append("},"));
            valueSql.delete(valueSql.length() - 1, valueSql.length());
            valueSql.append("</foreach>");
            return valueSql.toString();
        }

        /** Whether the entity maps a primary key (both its column and its property are known). */
        private static boolean hasKey(TableInfo tableInfo) {
            return StringUtils.hasLength(tableInfo.getKeyColumn())
                    && StringUtils.hasLength(tableInfo.getKeyProperty());
        }
    }
    
  3. 注册InsertBatchMethod和MysqlInsertOrUpdateBath到EasySqlInjector

    import com.baomidou.mybatisplus.core.injector.AbstractMethod;
    import com.baomidou.mybatisplus.core.injector.DefaultSqlInjector;
    import com.baomidou.mybatisplus.extension.injector.methods.AlwaysUpdateSomeColumnById;
    import com.baomidou.mybatisplus.extension.injector.methods.InsertBatchSomeColumn;
    import org.springframework.stereotype.Component;
    import java.util.List;
    
    
    /**
     * SQL injector that adds the project's extra mapper statements on top of MyBatis-Plus's
     * default CRUD set: MP's own {@code InsertBatchSomeColumn} and {@code AlwaysUpdateSomeColumnById},
     * plus the custom {@code insertBatch} and {@code mysqlInsertOrUpdateBath} statements that
     * RootMapper declares.
     */
    @Component
    public class EasySqlInjector extends DefaultSqlInjector {

        /**
         * Columns that {@code alwaysUpdateSomeColumnById} must never write. They are compared
         * case-insensitively, because column names may be declared in either case.
         */
        private static final String[] ALWAYS_UPDATE_EXCLUDED_COLUMNS = {
                "DELETED", "ID", "CREATE_TIME", "CREATE_USER", "CREATE_USER_NAME"
        };

        @Override
        public List<AbstractMethod> getMethodList(Class<?> mapperClass) {
            List<AbstractMethod> methodList = super.getMethodList(mapperClass);
            // MyBatis-Plus built-in true batch insert
            methodList.add(new InsertBatchSomeColumn());
            // update by id that always writes the selected columns, without skipping nulls
            methodList.add(new AlwaysUpdateSomeColumnById(
                    tableFieldInfo -> !isAlwaysUpdateExcluded(tableFieldInfo.getColumn())));
            // custom true batch insert and batch insert-or-update (upsert)
            methodList.add(new InsertBatchMethod());
            methodList.add(new MysqlInsertOrUpdateBath());
            return methodList;
        }

        /**
         * Tells whether {@code column} is in the exclusion list, ignoring case.
         *
         * @param column column name from TableFieldInfo; {@code null} is treated as not excluded
         * @return {@code true} if the column must be skipped by alwaysUpdateSomeColumnById
         */
        private static boolean isAlwaysUpdateExcluded(String column) {
            for (String excluded : ALWAYS_UPDATE_EXCLUDED_COLUMNS) {
                if (excluded.equalsIgnoreCase(column)) {
                    return true;
                }
            }
            return false;
        }
    }
    
  4. 注册EasySqlInjector到Mybatis配置类

    /**
     * Spring configuration that builds the MyBatis-Plus SqlSessionFactory by hand, so that a custom
     * MetaObjectHandler and EasySqlInjector are registered in the GlobalConfig. It then scans the
     * mapper package against that factory.
     * <p>
     * NOTE(review): the {@code dataSource} and {@code beanFactory} fields and the
     * EnvironmentAware / BeanFactoryAware callbacks are not shown in this snippet. Presumably they
     * are declared elsewhere in the class; confirm.
     * NOTE(review): mappers are registered both by {@code @MapperScan("com.**.dao")} and by the
     * manual ClassPathMapperScanner below. Verify that the two package sets do not overlap.
     */
    @Slf4j
    @Configuration
    @EnableNacosConfig
    @EnableConfigurationProperties
    @MapperScan("com.**.dao")
    public class OPAConfiguration implements EnvironmentAware,BeanFactoryAware{
    
    /**
     * Builds the xxxMybatisContext bean: a MybatisSqlSessionFactoryBean over {@code this.dataSource},
     * loading the project's *Mapper.xml files and using a copy of the MyBatis-Plus GlobalConfig
     * that additionally carries xxxMetaObjectHandler and EasySqlInjector. It then stores the
     * resulting SqlSessionFactory and a SqlSessionTemplate in the context, and registers the mapper
     * interfaces of the mapping package against that template.
     *
     * @param xxxProperties project properties (not read in the visible body)
     * @param properties    MyBatis-Plus Boot properties; only their GlobalConfig is read here
     * @return context holding the SqlSessionFactory and SqlSessionTemplate
     * @throws Exception if resolving mapper resources or building the factory fails
     */
    @Bean
        public xxxMybatisContext xxxMybatisContext(xxxProperties xxxProperties, MybatisPlusProperties properties) throws Exception {
        	xxxMybatisContext xxxMybatisContext=new xxxMybatisContext();
    
            // TODO: use MybatisSqlSessionFactoryBean instead of MyBatis's SqlSessionFactoryBean;
            // setGlobalConfig(...) below is called on this MyBatis-Plus factory bean.
            MybatisSqlSessionFactoryBean factory = new MybatisSqlSessionFactoryBean();
            factory.setDataSource(this.dataSource);
            factory.setVfs(SpringBootVFS.class);
            // mapper XML files of the project (package path is a placeholder here)
            factory.setMapperLocations(new PathMatchingResourcePatternResolver().getResources("classpath*:com/xxx/xxx/xxx/mapping/*Mapper.xml"));
            // TODO: must be non-null here, because it is the source object of the property copy below
            GlobalConfig globalConfig = properties.getGlobalConfig();
            // Work on a copy of the GlobalConfig from the properties, then add the custom
            // meta-object handler and EasySqlInjector (insertBatch / mysqlInsertOrUpdateBath).
            GlobalConfig newGlobalConfig=new GlobalConfig();
            BeanUtils.copyProperties(globalConfig, newGlobalConfig);
            newGlobalConfig.setMetaObjectHandler(new xxxMetaObjectHandler());
    		newGlobalConfig.setSqlInjector(new EasySqlInjector());
            factory.setGlobalConfig(newGlobalConfig);
            SqlSessionFactory sqlSessionFactory=factory.getObject();
    
            // store the built factory on the copied GlobalConfig and expose it (plus a template) via the context
            newGlobalConfig.setSqlSessionFactory(sqlSessionFactory);
            xxxMybatisContext.setSessionFactory(sqlSessionFactory);
            xxxMybatisContext.setSqlSessionTemplate(new SqlSessionTemplate(sqlSessionFactory));
            // register the mapper interfaces of the mapping package as bean definitions bound to this template
            ClassPathMapperScanner scanner = new ClassPathMapperScanner((BeanDefinitionRegistry) beanFactory);
            scanner.setSqlSessionTemplate(xxxMybatisContext.getSqlSessionTemplate());
         //   scanner.setBeanNameGenerator(new xxxBeanNameGenerator());
            scanner.registerFilters();
            scanner.scan(
                StringUtils.tokenizeToStringArray("com.xxx.xxx.xxx.mapping", ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS));
    
        	return xxxMybatisContext;
        }
        
    }
    
  5. 实现自定义RootMapper

    import com.baomidou.mybatisplus.core.mapper.BaseMapper;
    import org.apache.ibatis.annotations.Param;
    import java.util.List;
    
    
    /**
     * Root mapper that table mappers extend instead of {@link BaseMapper}. Besides the standard CRUD
     * methods, it declares the custom batch statements registered by EasySqlInjector; each method
     * name here must equal the id of the injected mapped statement.
     * <p>
     * See also {@link com.baomidou.mybatisplus.extension.service.IService} and
     * {@link com.baomidou.mybatisplus.extension.service.impl.ServiceImpl}.
     *
     * @param <T> entity type
     */
    public interface RootMapper<T> extends BaseMapper<T> {

        /**
         * Inserts all given entities with one multi-row INSERT statement.
         * <p>
         * For auto-fill to apply, the {@code @Param} name must be one of
         * {@code list}, {@code collection} or {@code array}.
         *
         * @param entities rows to insert
         * @return the update count returned by MyBatis
         */
        int insertBatch(@Param("list") List<T> entities);

        /**
         * Inserts the given entities, or updates rows that already exist, with one multi-row
         * {@code INSERT ... ON DUPLICATE KEY UPDATE} statement (MySQL).
         * <p>
         * For auto-fill to apply, the {@code @Param} name must be one of
         * {@code list}, {@code collection} or {@code array}.
         *
         * @param entities rows to insert or update
         * @return the update count returned by MyBatis
         */
        int mysqlInsertOrUpdateBath(@Param("list") List<T> entities);
    }
    
  6. 要实现真批量操作的实体类的Mapper继承RootMapper

    @Mapper
    public interface BudgetVSActualMapper extends RootMapper<BudgetVSActual> {
    
  7. 自定义MyServiceImpl实现类

    import com.baomidou.mybatisplus.core.conditions.Wrapper;
    import com.baomidou.mybatisplus.core.enums.SqlMethod;
    import com.baomidou.mybatisplus.core.metadata.TableInfo;
    import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
    import com.baomidou.mybatisplus.core.toolkit.*;
    import com.baomidou.mybatisplus.extension.service.IService;
    import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
    import com.pwc.sdc.OPA.mapping.RootMapper;
    import org.apache.ibatis.binding.MapperMethod;
    import org.apache.ibatis.logging.Log;
    import org.apache.ibatis.logging.LogFactory;
    import org.apache.ibatis.session.SqlSession;
    import org.mybatis.spring.SqlSessionUtils;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.core.ResolvableType;
    import org.springframework.transaction.annotation.Transactional;
    import java.io.Serializable;
    import java.util.*;
    import java.util.function.BiConsumer;
    import java.util.function.Consumer;
    import java.util.function.Function;
    import java.util.stream.Collectors;
    import java.util.stream.Stream;
    
    
    public  class MyServiceImpl<M extends RootMapper<T>, T> implements IService<T> {
        /** MyBatis logging facade, named after the concrete service class. */
        protected Log log = LogFactory.getLog(this.getClass());
        /** Entity mapper injected by Spring; bounded by RootMapper so the batch statements are available. */
        @Autowired
        protected M baseMapper;
        /** Entity class, resolved from the second generic argument ({@code T}) of MyServiceImpl. */
        protected Class<T> entityClass = this.currentModelClass();
        // NOTE(review): this holds the *mapper* class (first generic argument, M), so its element
        // type should really be M rather than T. It only compiles because of the raw cast in
        // currentMapperClass(). Both initialisers call overridable methods during construction.
        protected Class<T> mapperClass = this.currentMapperClass();

        /** Creates the service; entity and mapper types are resolved by the field initialisers above. */
        public MyServiceImpl() {
        }
    
        /**
         * Returns the mapper injected into this service.
         *
         * @return the entity mapper (a RootMapper subtype)
         */
        public M getBaseMapper() {
            return baseMapper;
        }
    
        /**
         * Returns the entity class resolved from this service's generic signature.
         *
         * @return the entity type {@code T}
         */
        public Class<T> getEntityClass() {
            return entityClass;
        }
    
        /**
         * Converts a row count into a success flag by delegating to SqlHelper.retBool.
         *
         * @param result row count returned by a mapper call
         * @return the value of SqlHelper.retBool(result)
         * @deprecated call SqlHelper.retBool directly.
         */
        @Deprecated
        protected boolean retBool(Integer result) {
            final boolean success = SqlHelper.retBool(result);
            return success;
        }
    
        /**
         * Resolves the mapper type, i.e. the first generic argument ({@code M}) of MyServiceImpl.
         * <p>
         * NOTE(review): declared as {@code Class<T>} although the value is the mapper class.
         *
         * @return the mapper class, cast to {@code Class<T>}
         */
        @SuppressWarnings("unchecked")
        protected Class<T> currentMapperClass() {
            ResolvableType serviceType = this.getResolvableType().as(MyServiceImpl.class);
            return (Class<T>) serviceType.getGeneric(0).getType();
        }
    
        /**
         * Resolves the entity type, i.e. the second generic argument ({@code T}) of MyServiceImpl.
         *
         * @return the entity class
         */
        @SuppressWarnings("unchecked")
        protected Class<T> currentModelClass() {
            ResolvableType serviceType = this.getResolvableType().as(MyServiceImpl.class);
            return (Class<T>) serviceType.getGeneric(1).getType();
        }
    
        /**
         * Builds the Spring ResolvableType of the concrete service class, as returned by
         * ClassUtils.getUserClass (presumably unwrapping proxy subclasses; confirm for your MP version).
         *
         * @return resolvable type used to read the generic arguments
         */
        protected ResolvableType getResolvableType() {
            Class<?> userClass = ClassUtils.getUserClass(getClass());
            return ResolvableType.forClass(userClass);
        }
    
        /**
         * Obtains a SqlSession for the entity class through SqlHelper.sqlSessionBatch.
         *
         * @return the session returned by SqlHelper
         * @deprecated prefer the executeBatch(...) helpers of this class.
         */
        @Deprecated
        protected SqlSession sqlSessionBatch() {
            return SqlHelper.sqlSessionBatch(entityClass);
        }
    
        /**
         * Closes a SqlSession through SqlSessionUtils, using the session factory that
         * GlobalConfigUtils reports for the entity class.
         *
         * @param sqlSession session to close
         * @deprecated prefer the executeBatch(...) helpers of this class.
         */
        @Deprecated
        protected void closeSqlSession(SqlSession sqlSession) {
            SqlSessionUtils.closeSqlSession(
                    sqlSession,
                    GlobalConfigUtils.currentSessionFactory(entityClass));
        }
    
        /**
         * Looks up the mapped-statement id for {@code sqlMethod} via the entity's TableInfo.
         *
         * @param sqlMethod built-in MyBatis-Plus method
         * @return fully qualified statement id
         * @deprecated use getSqlStatement(SqlMethod), which resolves through the mapper class.
         */
        @Deprecated
        protected String sqlStatement(SqlMethod sqlMethod) {
            TableInfo table = SqlHelper.table(entityClass);
            String methodName = sqlMethod.getMethod();
            return table.getSqlStatement(methodName);
        }
    
        /**
         * Inserts {@code entityList} in chunks, writing each chunk with a single multi-row INSERT via
         * RootMapper.insertBatch, within this method's transaction.
         *
         * @param entityList rows to insert; {@code null} or empty is a no-op
         * @param batchSize  maximum rows per INSERT statement; non-positive values fall back to 2000
         * @return always {@code true}; the per-statement row counts are not checked
         */
        @Transactional(
                rollbackFor = {Exception.class}
        )
        public boolean saveBatch(Collection<T> entityList, int batchSize) {
            if (CollectionUtils.isEmpty(entityList)) {
                return true;
            }
            // Copy once so each chunk is an O(1) subList view instead of re-streaming from the start.
            List<T> entities = new ArrayList<>(entityList);
            int total = entities.size();
            // Honour the caller's batch size; capping at total keeps countStep's arithmetic in range.
            int chunkSize = Math.min(batchSize > 0 ? batchSize : 2000, total);
            int chunks = countStep(total, chunkSize);
            for (int i = 0; i < chunks; i++) {
                int from = i * chunkSize;
                this.baseMapper.insertBatch(entities.subList(from, Math.min(from + chunkSize, total)));
            }
            return true;
        }
    
        /**
         * Returns the mapped-statement id of a built-in MyBatis-Plus method on this service's mapper.
         *
         * @param sqlMethod built-in method, e.g. UPDATE_BY_ID
         * @return statement id as computed by SqlHelper.getSqlStatement
         */
        protected String getSqlStatement(SqlMethod sqlMethod) {
            final Class<T> mapperType = this.mapperClass;
            return SqlHelper.getSqlStatement(mapperType, sqlMethod);
        }
    
        /**
         * Updates {@code entity} by id when its id is set and a row with that id exists; otherwise inserts it.
         *
         * @param entity entity to persist; {@code null} returns {@code false}
         * @return result of updateById or save
         */
        @Transactional(rollbackFor = {Exception.class})
        public boolean saveOrUpdate(T entity) {
            if (entity == null) {
                return false;
            }
            TableInfo info = TableInfoHelper.getTableInfo(this.entityClass);
            Assert.notNull(info, "error: can not execute. because can not find cache of TableInfo for entity!", new Object[0]);
            String idProperty = info.getKeyProperty();
            Assert.notEmpty(idProperty, "error: can not execute. because can not find column for id from entity!", new Object[0]);
            Object id = ReflectionKit.getFieldValue(entity, idProperty);
            // the lookup only runs when an id is present (short-circuit)
            boolean exists = !StringUtils.checkValNull(id) && this.getById((Serializable) id) != null;
            return exists ? this.updateById(entity) : this.save(entity);
        }
    
        /**
         * Upserts {@code entityList} in chunks, sending each chunk as one
         * {@code INSERT ... ON DUPLICATE KEY UPDATE} statement via RootMapper.mysqlInsertOrUpdateBath
         * (MySQL only), within this method's transaction.
         * <p>
         * Whether a row is inserted or updated is decided by the table's primary/unique keys in the
         * database, not by an id lookup in Java as in saveOrUpdate(T).
         *
         * @param entityList rows to upsert; {@code null} or empty is a no-op
         * @param batchSize  maximum rows per statement; non-positive values fall back to 2000
         * @return always {@code true}; the per-statement row counts are not checked
         */
        @Transactional(
                rollbackFor = {Exception.class}
        )
        public boolean saveOrUpdateBatch(Collection<T> entityList, int batchSize) {
            if (CollectionUtils.isEmpty(entityList)) {
                return true;
            }
            // Copy once so each chunk is an O(1) subList view instead of re-streaming from the start.
            List<T> entities = new ArrayList<>(entityList);
            int total = entities.size();
            // Honour the caller's batch size; capping at total keeps countStep's arithmetic in range.
            int chunkSize = Math.min(batchSize > 0 ? batchSize : 2000, total);
            int chunks = countStep(total, chunkSize);
            for (int i = 0; i < chunks; i++) {
                int from = i * chunkSize;
                this.baseMapper.mysqlInsertOrUpdateBath(entities.subList(from, Math.min(from + chunkSize, total)));
            }
            return true;
        }
        /**
         * Number of chunks of at most {@code maxNum} elements needed to cover {@code size} elements
         * (ceiling division). It is written as quotient plus a remainder check rather than
         * {@code (size + maxNum - 1) / maxNum}, so it cannot overflow for sizes near Integer.MAX_VALUE.
         *
         * @param size   element count, expected to be non-negative
         * @param maxNum chunk size, must be positive
         * @return {@code ceil(size / maxNum)}
         */
        private static Integer countStep(Integer size, Integer maxNum) {
            int quotient = size / maxNum;
            return size % maxNum == 0 ? quotient : quotient + 1;
        }
    
    
        /**
         * Updates each entity by id, running the UPDATE_BY_ID statement once per entity through
         * executeBatch in groups of {@code batchSize}.
         *
         * @param entityList entities to update
         * @param batchSize  group size passed to executeBatch
         * @return result of executeBatch
         */
        @Transactional(rollbackFor = {Exception.class})
        public boolean updateBatchById(Collection<T> entityList, int batchSize) {
            final String statementId = this.getSqlStatement(SqlMethod.UPDATE_BY_ID);
            return this.executeBatch(entityList, batchSize, (session, item) -> {
                MapperMethod.ParamMap<T> params = new MapperMethod.ParamMap<>();
                params.put("et", item);
                session.update(statementId, params);
            });
        }
    
        /**
         * Fetches a single entity matching {@code queryWrapper}.
         *
         * @param queryWrapper query conditions
         * @param throwEx      {@code true} to delegate to the mapper's selectOne; {@code false} to
         *                     select a list and let SqlHelper.getObject pick the result
         * @return the matching entity, as returned by the chosen path
         */
        public T getOne(Wrapper<T> queryWrapper, boolean throwEx) {
            if (throwEx) {
                return this.baseMapper.selectOne(queryWrapper);
            }
            List<T> candidates = this.baseMapper.selectList(queryWrapper);
            return SqlHelper.getObject(this.log, candidates);
        }
    
        /**
         * Fetches one row matching {@code queryWrapper} as a column-name to value map, chosen from
         * selectMaps by SqlHelper.getObject.
         *
         * @param queryWrapper query conditions
         * @return the selected row map
         */
        public Map<String, Object> getMap(Wrapper<T> queryWrapper) {
            List<Map<String, Object>> rows = this.baseMapper.selectMaps(queryWrapper);
            return SqlHelper.getObject(this.log, rows);
        }
    
        /**
         * Fetches one mapped value: runs listObjs with {@code mapper} and lets SqlHelper.getObject
         * pick the result.
         *
         * @param queryWrapper query conditions
         * @param mapper       converts the raw column value
         * @param <V>          result type
         * @return the selected value
         */
        public <V> V getObj(Wrapper<T> queryWrapper, Function<? super Object, V> mapper) {
            List<V> values = this.listObjs(queryWrapper, mapper);
            return SqlHelper.getObject(this.log, values);
        }
    
        /**
         * Runs {@code consumer} against a batch session via SqlHelper.executeBatch for the entity class.
         *
         * @param consumer work to perform with the session
         * @return result of SqlHelper.executeBatch
         * @deprecated prefer the Collection-based executeBatch overloads.
         */
        @Deprecated
        protected boolean executeBatch(Consumer<SqlSession> consumer) {
            final Class<T> entityType = this.entityClass;
            return SqlHelper.executeBatch(entityType, this.log, consumer);
        }
    
        /**
         * Applies {@code consumer} to every element of {@code list} through SqlHelper.executeBatch,
         * in groups of {@code batchSize}.
         *
         * @param list      elements to process
         * @param batchSize group size forwarded to SqlHelper
         * @param consumer  per-element action receiving the session and the element
         * @param <E>       element type
         * @return result of SqlHelper.executeBatch
         */
        protected <E> boolean executeBatch(Collection<E> list, int batchSize, BiConsumer<SqlSession, E> consumer) {
            Class<T> entityType = this.entityClass;
            return SqlHelper.executeBatch(entityType, log, list, batchSize, consumer);
        }
    
        /**
         * Same as executeBatch(list, batchSize, consumer) with a fixed group size of 1000.
         *
         * @param list     elements to process
         * @param consumer per-element action receiving the session and the element
         * @param <E>      element type
         * @return result of the three-argument overload
         */
        protected <E> boolean executeBatch(Collection<E> list, BiConsumer<SqlSession, E> consumer) {
            final int defaultBatchSize = 1000;
            return executeBatch(list, defaultBatchSize, consumer);
        }
    
  8. 实体类ServiceImpl继承MyServiceImpl

    public class BudgetVSActualServiceImpl extends MyServiceImpl<BudgetVSActualMapper, BudgetVSActual> implements BudgetVSActualService {
    
  9. 批量方法的使用

    @Autowired
    private IService<T> iService;
    
    //批量保存数据
    iService.saveBatch(DataList);
    //批量更新数据
    iService.saveOrUpdateBatch(DataList);
    
;