Bootstrap

多线程批量导入大量数据时,如何实现事务整体回滚

场景:

将一个包含约4万条订单的Excel表格导入订单表,数据校验通过后落库。单线程处理时响应时间约8分钟,远超2分钟的超时限制;改用批量插入后,仍然超过2分钟。

于是改用多线程处理。随之而来的问题是:只要有一个线程执行失败,接口就应返回处理失败,而其他已经插入成功的线程也需要回滚。但每个线程执行的是各自独立的事务,这该如何处理?

原因:Spring把数据库连接保存在ThreadLocal中。多线程场景下,各线程拿到的是不同的数据库连接,也就分属不同的事务。

这里提供两种解决方案:

  • 使用CompletableFuture:获取每个执行线程的结果,等待所有线程处理完毕后遍历结果,只要存在FALSE就回滚全部数据。需要注意的是,要实现整体回滚,必须关闭SQL自动提交,并通过同一个SqlSession获取插入用的mapper,使所有线程共用同一个数据库连接(也就是同一个事务)。由于SqlSession不是线程安全的,对它的调用需要串行化。
  • 使用自定义切面注解,配合Spring的@Async注解,由切面统一控制各子线程事务的提交/回滚,避免繁琐的手动提交/回滚。

工具类:

将大批量数据拆分为多个子集合,每个线程处理其中一个。

    /**
     * Splits a list into consecutive chunks of at most {@code count} elements, so that each
     * chunk can be handed to one worker thread.
     *
     * <p>Every returned chunk is a fresh {@link ArrayList}. No chunk aliases {@code resList}, so
     * mutating a chunk never affects the input (or vice versa). The input is read through
     * {@link List#subList}, so this stays linear even for lists without fast random access
     * (e.g. {@code LinkedList}).
     *
     * @param <T>     element type
     * @param resList the list to split
     * @param count   maximum number of elements per chunk; must be at least 1
     * @return the chunks in original order (only the last one may be shorter), or {@code null}
     *     if {@code resList} is {@code null} or {@code count < 1}. An empty input yields a single
     *     empty chunk.
     */
    public static <T> List<List<T>> split(List<T> resList, int count) {
        if (resList == null || count < 1) {
            return null;
        }
        int size = resList.size();
        if (size <= count) {
            // Copy rather than returning resList itself, so this case matches the others.
            List<List<T>> single = new ArrayList<>(1);
            single.add(new ArrayList<>(resList));
            return single;
        }
        List<List<T>> ret = new ArrayList<>(size / count + 1);
        int from = 0;
        while (from < size) {
            // size - from is always positive here, so this never overflows.
            int to = from + Math.min(count, size - from);
            ret.add(new ArrayList<>(resList.subList(from, to)));
            from = to;
        }
        return ret;
    }

第一种:

@Service
public class StudentServiceImpl implements StudentService {
    @Autowired
    private SqlSessionTemplate sqlSessionTemplate;
    
    //自定义线程池
    private static final ThreadPoolExecutor THREAD_POOL_EXECUTOR = new ThreadPoolExecutor(
            Runtime.getRuntime().availableProcessors() * 2,
            Runtime.getRuntime().availableProcessors() * 4, 
            60L, 
            TimeUnit.SECONDS, 
            new LinkedBlockingDeque<>(256),
            new ThreadPoolExecutor.CallerRunsPolicy()
    );

    @Override
    @Transactional
    public Result importExcel(ArrayList<StudentDto> studentDtoArrayList) throws Exception {
        //例如这里有1w的数据StudentDto

        //每个线程处理1k的数据
        List<List<StudentDto>> insertSplit = CommonUtils.split(studentDtoArrayList, 1000);

        //插入任务列表
        List<CompletableFuture<Boolean>> tasks = new ArrayList<>();
        //根据sqlSessionTemplate获取SqlSession工厂
        SqlSessionFactory sqlSessionFactory = sqlSessionTemplate.getSqlSessionFactory();
        SqlSession sqlSession = sqlSessionFactory.openSession();
        //获取Connection来手动控制事务
        Connection connection = sqlSession.getConnection();

        try {
            //关闭事务自动提交
            connection.setAutoCommit(false);

            //获取mapper对象
            OutOrderDetailMapper studentMapper = sqlSession.getMapper(StudentMapper.class);

            for (int i = 0; i < insertSplit.size(); i++) {
                List<StudentDto> studentDtos = insertSplit.get(i);
                //如果失败有异常,返回False,添加到自定义线程池,并将结果加入到结果集tasks
                CompletableFuture<Boolean> task = CompletableFuture.supplyAsync(() -> {
                    try {
                        studentMapper.insertBatch(studentDtos);
                    } catch (Exception e) {
                        return Boolean.FALSE;
                    }
                    return Boolean.TRUE;
                }, THREAD_POOL_EXECUTOR);
                tasks.add(task);
            }

            // 等待所有任务完成
            CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0])).join();

            // 如果有任意一个任务抛出异常,则回滚所有数据
            for (CompletableFuture<Boolean> task : tasks) {
                if (Boolean.FALSE.equals(task.get())) {
                    log.error("插入异常,事务回滚");
                    connection.rollback();
                    return Result.error("新增失败");
                }
            }
            
            //提交事务
            connection.commit();
        } catch (Exception e) {
            connection.rollback();
            e.printStackTrace();
        }
        return Result.ok();
    }
}

第二种:

参考:https://mp.weixin.qq.com/s/BVDxoERqxXv5lJ6H85DuFg

@Configuration
@EnableAsync
@Slf4j
public class ThreadPoolTaskConfig {

    /**
     * Executor used by methods annotated {@code @Async("threadPoolTaskExecutor")}.
     *
     * <p>The pool has 16 core and 64 max threads and a queue of 1024 tasks. Threads above the
     * core size are reclaimed after 30 seconds idle. When pool and queue are full, the task runs
     * on the calling thread (CallerRunsPolicy). On shutdown, running and queued tasks are allowed
     * to finish rather than being interrupted.
     *
     * <p>NOTE(review): with the MainTransaction/SonTransaction aspect, a son task that falls back
     * to the caller thread would run on the main thread itself. It would then wait for that same
     * thread to release it, and so block indefinitely. Check whether CallerRunsPolicy is
     * appropriate for that use.
     */
    @Bean("threadPoolTaskExecutor")
    public ThreadPoolTaskExecutor threadPoolTaskExecutor() {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();

        // Sizing.
        executor.setCorePoolSize(16);
        executor.setMaxPoolSize(64);
        executor.setQueueCapacity(1024);
        executor.setKeepAliveSeconds(30);

        // Naming, overflow handling and shutdown behavior.
        executor.setThreadNamePrefix("自定义异步线程 - ");
        executor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        executor.setWaitForTasksToCompleteOnShutdown(true);

        executor.initialize();
        log.info("=======ThreadPoolTaskConfig========");
        return executor;
    }
}
/**
 * Marks the calling ("main") method of a multi-threaded transaction.
 *
 * <p>Put it on the method that invokes the {@code @Async} sub-thread methods. Those sub-thread
 * methods carry {@code SonTransaction}.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface MainTransaction {

    /**
     * Number of sub-threads started from the annotated method. This is required.
     *
     * <p>Example: if the annotated method calls two methods that each start a sub-thread via
     * {@code @Async}, declare {@code @MainTransaction(2)}.
     */
    int value();
}
/**
 * Marks a called method that starts a sub-thread (typically together with {@code @Async}) and
 * takes part in the enclosing main method's multi-threaded transaction.
 *
 * <p>No arguments are needed.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface SonTransaction {

    /** Unused; defaults to an empty string. */
    String value() default "";
}
/**
 * Coordinates one "main" method ({@code @MainTransaction}) with the {@code @Async} "son"
 * methods ({@code @SonTransaction}) it starts. The sons' separate transactions then either all
 * commit or all roll back.
 *
 * <p>Protocol:
 * <ol>
 *   <li>Each son runs its body inside its own REQUIRES_NEW transaction.</li>
 *   <li>The son counts down {@code sonDownLatch} and then blocks on {@code mainDownLatch}.</li>
 *   <li>Once every son has reported, the main method releases {@code mainDownLatch}.</li>
 *   <li>Each son then commits, or rolls back if any participant failed.</li>
 * </ol>
 *
 * <p>Limitations inherent to this design:
 * <ul>
 *   <li>This is not two-phase commit. A son whose commit fails after the main method has
 *       returned cannot undo the other sons' commits.</li>
 *   <li>If {@code MainTransaction#value()} exceeds the number of sons actually started, the main
 *       thread waits forever while the started sons keep their transactions open.</li>
 *   <li>If the main method body itself throws, the aspect cannot know how many sons were already
 *       submitted, so it does not wait. A son that starts after cleanup finds no context and runs
 *       uncoordinated.</li>
 *   <li>A son that executes on the main thread itself (e.g. via a caller-runs rejection policy)
 *       blocks forever on {@code mainDownLatch}.</li>
 * </ul>
 * Sons find their main method's state through the {@code Thread} passed as their last argument.
 */
@Aspect
@Component
public class TransactionAop {

    /**
     * Coordination state per main thread. Entries are removed when the main method finishes.
     * The map is concurrent because main and son threads access it at the same time.
     */
    private static final Map<Thread, Context> CONTEXTS = new java.util.concurrent.ConcurrentHashMap<>();

    @Resource
    private PlatformTransactionManager transactionManager;

    /**
     * Runs the main method, waits until every son has executed its body, then releases the sons
     * to commit or roll back.
     *
     * @return the main method's own return value
     * @throws Throwable the main method's exception, or the most recently recorded son failure
     */
    @Around("@annotation(mainTransaction)")
    public Object mainIntercept(ProceedingJoinPoint joinPoint, MainTransaction mainTransaction) throws Throwable {
        Thread mainThread = Thread.currentThread();
        // The son count from @MainTransaction sizes sonDownLatch.
        Context context = new Context(mainTransaction.value());
        CONTEXTS.put(mainThread, context);
        try {
            Object result;
            try {
                result = joinPoint.proceed();
            } catch (Throwable e) {
                // Tell every son to roll back and release those already waiting.
                context.recordFailure(e);
                context.mainDownLatch.countDown();
                throw e;
            }

            try {
                // Wait until every son has run its body (not yet committed). Wait even when a son
                // has already failed: the context must stay registered until every son has read it.
                context.sonDownLatch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                context.recordFailure(e);
            } finally {
                // Always release the sons so none stays blocked holding its transaction open.
                context.mainDownLatch.countDown();
            }

            Throwable failure = context.exceptions.peekFirst();
            if (failure != null) {
                throw failure;
            }
            return result;
        } finally {
            // Clean up on success as well; previously entries leaked whenever nothing failed.
            CONTEXTS.remove(mainThread);
        }
    }

    /**
     * Runs a son method in its own transaction and reports to its main method. The son then
     * waits for the main method's verdict before committing or rolling back.
     *
     * <p>A son failure is recorded for the main method to rethrow and is not propagated from
     * here. In that case the advice returns {@code null}.
     */
    @Around("@annotation(com.example.batch.annotation.SonTransaction)")
    public Object sonIntercept(ProceedingJoinPoint joinPoint) throws Throwable {
        Object[] args = joinPoint.getArgs();
        Thread mainThread = (Thread) args[args.length - 1];
        Context context = CONTEXTS.get(mainThread);
        if (context == null) {
            // No active @MainTransaction for that thread: run the son uncoordinated.
            return joinPoint.proceed();
        }

        // Another participant already failed: skip the work, but still report in.
        if (context.rollBackFlag.get()) {
            context.sonDownLatch.countDown();
            return null;
        }

        DefaultTransactionDefinition def = new DefaultTransactionDefinition();
        // Each son needs its own physical transaction, independent of anything on this thread.
        def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);

        TransactionStatus status = null;
        // Guards against counting sonDownLatch down twice for this son.
        boolean reported = false;
        try {
            status = transactionManager.getTransaction(def);
            Object result = joinPoint.proceed();
            context.sonDownLatch.countDown();
            reported = true;
            // Block until the main method has heard from every son.
            context.mainDownLatch.await();
            if (context.rollBackFlag.get()) {
                transactionManager.rollback(status);
            } else {
                transactionManager.commit(status);
            }
            return result;
        } catch (Throwable e) {
            context.recordFailure(e);
            try {
                // A failed commit/rollback has already completed the status, and calling
                // rollback again would throw. Only roll back a transaction that is still open.
                if (status != null && !status.isCompleted()) {
                    transactionManager.rollback(status);
                }
            } finally {
                // Always unblock the main thread and the other sons, even if the rollback throws.
                context.mainDownLatch.countDown();
                if (!reported) {
                    context.sonDownLatch.countDown();
                }
                if (e instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                }
            }
            return null;
        }
    }

    /** Coordination state shared by one main method and the sons it starts. */
    private static final class Context {
        /** Released (1 -> 0) once the sons may commit or must roll back. */
        final CountDownLatch mainDownLatch = new CountDownLatch(1);
        /** Counted down once per son after its body ran, was skipped, or failed. */
        final CountDownLatch sonDownLatch;
        /** Set as soon as any participant fails; sons then roll back instead of committing. */
        final AtomicBoolean rollBackFlag = new AtomicBoolean(false);
        /** Recorded failures, most recent first. */
        final java.util.Deque<Throwable> exceptions = new java.util.concurrent.ConcurrentLinkedDeque<>();

        Context(int sonCount) {
            this.sonDownLatch = new CountDownLatch(sonCount);
        }

        /** Records a failure and flags the whole unit for rollback. */
        void recordFailure(Throwable e) {
            exceptions.addFirst(e);
            rollBackFlag.set(true);
        }
    }
}

测试:

需要注意的是,@MainTransaction的参数必须与实际开启的子线程数量一致。否则主线程会一直等待子线程计数器归零,已开启的子线程事务也一直无法提交或回滚,从而长时间锁表。当业务逻辑复杂时(例如有多个if分支,每个分支开启的子线程数量不同),可以放弃使用@MainTransaction。

@Service
@Slf4j
public class RewriteSqlServiceImpl implements RewriteSqlService {
    @Autowired
    private RewriteSqlMapper rewriteSqlMapper;


    @Override
    @Transactional(rollbackFor = Exception.class)
    @Async("threadPoolTaskExecutor")
    @SonTransaction
    public void threadInsert(List<RewriteSqlDO> user, Thread thread) throws Exception {
        log.error("子线程1启动");
        int insertBatch = rewriteSqlMapper.insertBatch(user);
        if(insertBatch == 0){
            throw new Exception("error");
        }
    }

    @Override
    @Transactional(rollbackFor = Exception.class)
    @Async("threadPoolTaskExecutor")
    @SonTransaction
    public void threadUpdate(Integer id, Thread thread) throws Exception {
        log.error("子线程2启动");
        int update = rewriteSqlMapper.update(null, new LambdaUpdateWrapper<RewriteSqlDO>()
                .eq(RewriteSqlDO::getId, id)
                .set(RewriteSqlDO::getStuName, "llllll")
        );
        if(update == 0){
            throw new Exception("error");
        }
    }

    @Override
    @Transactional(rollbackFor = Exception.class)
    public void threadMain() throws Exception {
        log.error("主线程启动");
        List<RewriteSqlDO> rewriteSqlDOList = new ArrayList<>();
        for (int i = 0; i < 10000; i++) {
            RewriteSqlDO rewriteSqlDO = new RewriteSqlDO();
            rewriteSqlDO.setId(i);
            rewriteSqlDO.setAge(i);
            rewriteSqlDO.setStuName(String.valueOf(i));
            rewriteSqlDOList.add(rewriteSqlDO);
        }
        threadInsert(rewriteSqlDOList, Thread.currentThread());
        threadUpdate(1, Thread.currentThread());
    }
}
;