场景:
4w单的EXECL表格,导入到订单表,校验数据后落库,响应速度超过2min,大约在8min左右,这是单线程的情况下,在使用批量插入后,仍超时2min。
于是使用多线程去处理,可能出现的问题是,如果一个线程出现失败,那么应该返回响应处理失败,已插入就绪的线程任务需要回滚,但是每个线程执行的是自己的事务,那么该如何处理呢。
原因:spring中对数据库连接是放在threadLocal里面,多线程场景下,拿到的数据库连接是不一样的,即是属于不同事务。
这里提供两种解决方案:
- 使用completeFuture,获取到每个执行线程的结果,然后等待所有线程处理完,for循环对结果进行判断,如果存在FALSE则全部线程回滚。需要注意的是,这里的回滚,需要将sql自动提交关闭,并且获取插入的mapper。
- 使用切面注解,基于springboot的
@Async
注解,避免繁琐的手动提交/回滚事务
工具类:
将大批量的数据,分为多个集合,每个线程处理一个。
/**
* 拆分集合
* @param <T>
* @param resList 要拆分的集合
* @param count 每个集合的元素个数
* @return 返回拆分后的各个集合
*/
public static <T> List<List<T>> split(List<T> resList, int count) {
if (resList == null || count < 1) {
return null;
}
List<List<T>> ret = new ArrayList<List<T>>();
int size = resList.size();
if (size <= count) {
ret.add(resList);
return ret;
}
int pre = size / count;
int last = size % count;
for (int i = 0; i < pre; i++) {
List<T> itemList = new ArrayList<T>();
for (int j = 0; j < count; j++) {
itemList.add(resList.get(i * count + j));
}
ret.add(itemList);
}
if (last > 0) {
List<T> itemList = new ArrayList<T>();
for (int i = 0; i < last; i++) {
itemList.add(resList.get(pre * count + i));
}
ret.add(itemList);
}
return ret;
}
第一种:
@Service
public class StudentServiceImpl implements StudentService {
@Autowired
private SqlSessionTemplate sqlSessionTemplate;
//自定义线程池
private static final ThreadPoolExecutor THREAD_POOL_EXECUTOR = new ThreadPoolExecutor(
Runtime.getRuntime().availableProcessors() * 2,
Runtime.getRuntime().availableProcessors() * 4,
60L,
TimeUnit.SECONDS,
new LinkedBlockingDeque<>(256),
new ThreadPoolExecutor.CallerRunsPolicy()
);
@Override
@Transactional
public Result importExcel(ArrayList<StudentDto> studentDtoArrayList) throws Exception {
//例如这里有1w的数据StudentDto
//每个线程处理1k的数据
List<List<StudentDto>> insertSplit = CommonUtils.split(studentDtoArrayList, 1000);
//插入任务列表
List<CompletableFuture<Boolean>> tasks = new ArrayList<>();
//根据sqlSessionTemplate获取SqlSession工厂
SqlSessionFactory sqlSessionFactory = sqlSessionTemplate.getSqlSessionFactory();
SqlSession sqlSession = sqlSessionFactory.openSession();
//获取Connection来手动控制事务
Connection connection = sqlSession.getConnection();
try {
//关闭事务自动提交
connection.setAutoCommit(false);
//获取mapper对象
OutOrderDetailMapper studentMapper = sqlSession.getMapper(StudentMapper.class);
for (int i = 0; i < insertSplit.size(); i++) {
List<StudentDto> studentDtos = insertSplit.get(i);
//如果失败有异常,返回False,添加到自定义线程池,并将结果加入到结果集tasks
CompletableFuture<Boolean> task = CompletableFuture.supplyAsync(() -> {
try {
studentMapper.insertBatch(studentDtos);
} catch (Exception e) {
return Boolean.FALSE;
}
return Boolean.TRUE;
}, THREAD_POOL_EXECUTOR);
tasks.add(task);
}
// 等待所有任务完成
CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0])).join();
// 如果有任意一个任务抛出异常,则回滚所有数据
for (CompletableFuture<Boolean> task : tasks) {
if (Boolean.FALSE.equals(task.get())) {
log.error("插入异常,事务回滚");
connection.rollback();
return Result.error("新增失败");
}
}
//提交事务
connection.commit();
} catch (Exception e) {
connection.rollback();
e.printStackTrace();
}
return Result.ok();
}
}
第二种:
参考:https://mp.weixin.qq.com/s/BVDxoERqxXv5lJ6H85DuFg
@Configuration
@EnableAsync
@Slf4j
public class ThreadPoolTaskConfig {
@Bean("threadPoolTaskExecutor")
public ThreadPoolTaskExecutor threadPoolTaskExecutor() {
//启动类上加入@EnableAsync 开启线程池
//在子线程上使用@Async("threadPoolTaskExecutor")自定义参数名称
ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
threadPoolTaskExecutor.setCorePoolSize(16);
threadPoolTaskExecutor.setQueueCapacity(1024);
threadPoolTaskExecutor.setMaxPoolSize(64);
threadPoolTaskExecutor.setKeepAliveSeconds(30);
threadPoolTaskExecutor.setThreadNamePrefix("自定义异步线程 - ");
threadPoolTaskExecutor.setWaitForTasksToCompleteOnShutdown(true);
threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
threadPoolTaskExecutor.initialize();
log.info("=======ThreadPoolTaskConfig========");
return threadPoolTaskExecutor;
}
}
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface MainTransaction {
/**
* 子线程数量
* @MainTransaction注解 用在调用方,其参数为必填,参数值为本方法中调用的方法开启的线程数,
* * 如:在这个方法中调用的方法中有2个方法用@Async注解开启了子线程,则参数为@MainTransaction(2)
*/
int value();
}
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface SonTransaction {
//用在被调用方(开启线程的方法),无需传入参数
String value() default "";
}
@Aspect
@Component
public class TransactionAop {
//用来存储各线程计数器数据(每次执行后会从map中删除)
private static final Map<String, Object> map = new HashMap<>();
@Resource
private PlatformTransactionManager transactionManager;
@Around("@annotation(mainTransaction)")
public void mainIntercept(ProceedingJoinPoint joinPoint, MainTransaction mainTransaction) throws Throwable {
//当前线程名称
Thread thread = Thread.currentThread();
String threadName = thread.getName();
//初始化计数器
CountDownLatch mainDownLatch = new CountDownLatch(1);
//@MainTransaction注解中的参数, 为子线程的数量
CountDownLatch sonDownLatch = new CountDownLatch(mainTransaction.value());
// 用来记录子线程的运行状态,只要有一个失败就变为true
AtomicBoolean rollBackFlag = new AtomicBoolean(false);
// 用来存每个子线程的异常,把每个线程的自定义异常向vector的首位置插入,其余异常向末位置插入,避免线程不安全,所以使用vector代替list
Vector<Throwable> exceptionVector = new Vector<>();
map.put(threadName + "mainDownLatch", mainDownLatch);
map.put(threadName + "sonDownLatch", sonDownLatch);
map.put(threadName + "rollBackFlag", rollBackFlag);
map.put(threadName + "exceptionVector", exceptionVector);
try {
//执行方法
joinPoint.proceed();
} catch (Throwable e) {
exceptionVector.add(0, e);
//子线程回滚
rollBackFlag.set(true);
//放行所有子线程
mainDownLatch.countDown();
}
if (!rollBackFlag.get()) {
try {
// sonDownLatch等待,直到所有子线程执行完插入操作,但此时还没有提交事务
sonDownLatch.await();
// 根据rollBackFlag状态放行子线程的await处,告知是回滚还是提交
mainDownLatch.countDown();
} catch (Exception e) {
rollBackFlag.set(true);
exceptionVector.add(0, e);
}
}
if (CollectionUtils.isNotEmpty(exceptionVector)) {
map.remove(threadName + "mainDownLatch");
map.remove(threadName + "sonDownLatch");
map.remove(threadName + "rollBackFlag");
map.remove(threadName + "exceptionVector");
throw exceptionVector.get(0);
}
}
@Around("@annotation(com.example.batch.annotation.SonTransaction)")
public void sonIntercept(ProceedingJoinPoint joinPoint) throws Throwable {
Object[] args = joinPoint.getArgs();
Thread thread = (Thread) args[args.length - 1];
String threadName = thread.getName();
CountDownLatch mainDownLatch = (CountDownLatch) map.get(threadName + "mainDownLatch");
if (mainDownLatch == null) {
//主事务未加注解时, 直接执行子事务
//这里最好的方式是:交由上面的thread来调用此方法,但我没有找寻到对应api,只能直接放弃事务
joinPoint.proceed();
return;
}
CountDownLatch sonDownLatch = (CountDownLatch) map.get(threadName + "sonDownLatch");
AtomicBoolean rollBackFlag = (AtomicBoolean) map.get(threadName + "rollBackFlag");
Vector<Throwable> exceptionVector = (Vector<Throwable>) map.get(threadName + "exceptionVector");
//如果这时有一个子线程已经出错,那当前线程不需要执行
if (rollBackFlag.get()) {
sonDownLatch.countDown();
return;
}
// 开启事务
DefaultTransactionDefinition def = new DefaultTransactionDefinition();
// 设置事务隔离级别
def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);
TransactionStatus status = transactionManager.getTransaction(def);
try {
//执行方法
joinPoint.proceed();
// 对sonDownLatch-1
sonDownLatch.countDown();
// 如果mainDownLatch不是0,线程会在此阻塞,直到mainDownLatch变为0
mainDownLatch.await();
// 如果能执行到这一步说明所有子线程都已经执行完毕判断如果atomicBoolean是true就回滚false就提交
if (rollBackFlag.get()) {
transactionManager.rollback(status);
} else {
transactionManager.commit(status);
}
} catch (Throwable e) {
exceptionVector.add(0, e);
// 回滚
transactionManager.rollback(status);
// 并把状态设置为true
rollBackFlag.set(true);
mainDownLatch.countDown();
sonDownLatch.countDown();
}
}
}
测试:
需要注意的是,当业务逻辑复杂的时候,例如多个if语句,每个if语句开启子线程的数量不同,可以放弃使用@MainTransaction避免锁表
@Service
@Slf4j
public class RewriteSqlServiceImpl implements RewriteSqlService {
@Autowired
private RewriteSqlMapper rewriteSqlMapper;
@Override
@Transactional(rollbackFor = Exception.class)
@Async("threadPoolTaskExecutor")
@SonTransaction
public void threadInsert(List<RewriteSqlDO> user, Thread thread) throws Exception {
log.error("子线程1启动");
int insertBatch = rewriteSqlMapper.insertBatch(user);
if(insertBatch == 0){
throw new Exception("error");
}
}
@Override
@Transactional(rollbackFor = Exception.class)
@Async("threadPoolTaskExecutor")
@SonTransaction
public void threadUpdate(Integer id, Thread thread) throws Exception {
log.error("子线程2启动");
int update = rewriteSqlMapper.update(null, new LambdaUpdateWrapper<RewriteSqlDO>()
.eq(RewriteSqlDO::getId, id)
.set(RewriteSqlDO::getStuName, "llllll")
);
if(update == 0){
throw new Exception("error");
}
}
@Override
@Transactional(rollbackFor = Exception.class)
public void threadMain() throws Exception {
log.error("主线程启动");
List<RewriteSqlDO> rewriteSqlDOList = new ArrayList<>();
for (int i = 0; i < 10000; i++) {
RewriteSqlDO rewriteSqlDO = new RewriteSqlDO();
rewriteSqlDO.setId(i);
rewriteSqlDO.setAge(i);
rewriteSqlDO.setStuName(String.valueOf(i));
rewriteSqlDOList.add(rewriteSqlDO);
}
threadInsert(rewriteSqlDOList, Thread.currentThread());
threadUpdate(1, Thread.currentThread());
}
}