前言
利用线程池和CountDownLatch,多线程并发处理批量数据,实现多线程事务回滚,事务补偿。
//定义两计数器
private CountDownLatch begin,end;
begin设置为1,用于发布开始命令,如果需要开始,则begin.countdown
end用于记录任务的执行情况。begin.countdown后,需end.await,等待任务都执行完。
当begin.countdown开始执行任务后,在最后需end.countdown
当end.countdown减到为0后,则切换到主线程,继续开始往下执行
基于回调函数
实现更灵活的去配置各业务数据操作场景,即:暴露excute方法执行线程任务,执行的具体执行任务交给回调函数实现。
基于spring上下文中获取事务管理器
封装获取spring上下文工具类
ApplicationContextProvider
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;
/**
* @Author by mocar小师兄
* @DESC: 从已有的spring上下文取得已实例化的bean
*/
@Component
public class ApplicationContextProvider implements ApplicationContextAware {
private static final Logger log = LoggerFactory.getLogger(ApplicationContextProvider.class);
private static ApplicationContext applicationContext;
/**
* 设置spring上下文
* @param applicationContext spring上下文
* @throws BeansException
*/
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
log.info("spring上下文applicationContext正在初始化,application:{}" ,applicationContext);
this.applicationContext = applicationContext;
log.info("spring上下文applicationContext初始化完成!");
}
public static ApplicationContext getApplicationContext() {
return applicationContext;
}
public static Object getBean(String name){
if(applicationContext==null){
log.warn("applicationContext是空的");
return null;
}
return getApplicationContext().getBean(name);
}
public static <T> T getBean(Class<T> clazz){
return getApplicationContext().getBean(clazz);
}
}
封装的工具类
package com.example.javademo.transaction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
public class TransactionMultipartExecutor<T> {
private static final Logger log = LoggerFactory.getLogger(TransactionMultipartExecutor.class);
/**
* 单个线程处理的数据量
*/
private int singleCount;
/**
* 处理的总数据量
*/
private int listSize;
/**
* 开启的线程数
*/
private int runSize;
/**
* 操作的数据集
*/
private List<T> list;
/**
* 计数器(拦截器)
*/
private CountDownLatch begin, end;
/**
* 线程池
*/
private ExecutorService executorService;
/**
* 是否存在异常
*/
private AtomicReference<Boolean> isError = new AtomicReference<>(false);
/**
* 回调函数
*/
private CallBack callBack;
/**
* 概率模拟报错
*/
private Random random = new Random();
/**
* 事务管理器
*/
private PlatformTransactionManager transactionManager;
public void setCallBack(CallBack callBack) {
this.callBack = callBack;
}
public TransactionMultipartExecutor(int singleCount, List<T> list) {
if (singleCount <= 0 || CollectionUtils.isEmpty(list)){
throw new RuntimeException("Illegal parameter");
}
//transactionManager = ContextLoader.getCurrentWebApplicationContext().getBean(PlatformTransactionManager.class);
transactionManager = ApplicationContextProvider.getBean(PlatformTransactionManager.class);
this.singleCount = singleCount;
this.list = list;
this.listSize = list.size();
this.runSize = (this.listSize%this.singleCount)==0 ? this.listSize/this.singleCount : this.listSize/this.singleCount + 1;
}
public void excute() throws InterruptedException {
// 创建固定线程数量的线程池
executorService = Executors.newFixedThreadPool(runSize);
begin = new CountDownLatch(1);
end = new CountDownLatch(runSize);
//创建线程
int startIndex = 0;
int endIndex = 0;
List<T> newList = null;
for (int i = 0; i < runSize; i++) {
//计算每个线程对应的数据
if (i < (runSize - 1)) {
startIndex = i * singleCount;
endIndex = (i + 1) * singleCount;
newList = list.subList(startIndex, endIndex);
} else {
startIndex = i * singleCount;
endIndex = listSize;
newList = list.subList(startIndex, endIndex);
}
//创建线程类处理数据
MyThread<T> myThread = new MyThread(newList, begin, end) {
@Override
public void method(List list) {
DefaultTransactionDefinition def = new DefaultTransactionDefinition();
def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);
TransactionStatus status = transactionManager.getTransaction(def);
//具体执行逻辑交给回调函数
try {
callBack.method(list);
/*if (random.nextInt(2) == 1) {
throw new RuntimeException("模拟异常抛出错误回滚");
}*/
log.warn("多线程事务批量操作执行成功,线程名:{},操作成功数量:{}",Thread.currentThread().getName(), list.size());
} catch (Exception e) {
// 接收异常,处理异常
isError.set(true);
//e.printStackTrace();
log.error("多线程事务批量操作抛错,线程名:{},操作失败数量:{},报错信息:{},{}",Thread.currentThread().getName(),list.size(),e.toString(), e);
}
//计数器减一
end.countDown();
try {
//等待所有线程任务完成,监控是否有异常,有则统一回滚
//log.warn("等待所有任务执行完成,当前时间:{},当前end计数:{}", LocalDateTime.now(), end.getCount());
end.await();
//log.warn("完成所有任务,当前时间:{},当前end计数:{}", LocalDateTime.now(), end.getCount());
if (isError.get()) {
// 事务回滚
transactionManager.rollback(status);
} else {
//事务提交
transactionManager.commit(status);
}
} catch (InterruptedException e) {
e.printStackTrace();
}
}
};
//执行线程
executorService.execute(myThread);
}
//计数器减一,开始执行任务 begin此时为0
begin.countDown();//
//等待任务全部执行完毕,变为0则任务全部完成
end.await();
//关闭线程池
executorService.shutdown();
//不抛错也是可以回滚的
/*if (isError.get()) {
// 主线程抛出自定义的异常
throw new RuntimeException("主线程抛出模拟异常");
}*/
}
//抽象线程类
public abstract class MyThread<T> implements Runnable {
//list:总数据分割后某线程负责执行的数据
private List<T> list;
private CountDownLatch begin, end;
public MyThread(List<T> list, CountDownLatch begin, CountDownLatch end) {
this.list = list;
this.begin = begin;
this.end = end;
}
@Override
public void run() {
try {
begin.await();
//执行程序
method(list);
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
//计数器减一
//end.countDown();
}
}
public abstract void method(List<T> list);
}
//回调接口定义
public interface CallBack<T> {
public void method(List<T> list);
}
public static void main(String[] agrs) {
List<String> list = new ArrayList<>();
for (int i = 0; i < 10; i++) {
list.add("hello" + i);
}
TransactionMultipartExecutor<String> tool = new TransactionMultipartExecutor(3, list);
tool.setCallBack(new CallBack<String>() {
@Override
public void method(List<String> list) {
//总数据分割后某线程负责执行的数据
for (int i = 0; i < list.size(); i++) {
System.out.print(Thread.currentThread().getId() + ":" + list.get(i) + " ");
}
System.out.println();
}
});
try {
tool.excute();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}