多线程分工处理list数据

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.CountDownLatch;

/**
 * A {@link Runnable} that passes one slice of a list to a {@link ProcessorRunnable} and then,
 * regardless of the outcome, counts down the supplied {@link CountDownLatch} when one is set.
 *
 * @param <T> element type of the slice
 */
public class ProcessorThread<T> implements Runnable {

    private static final Logger LOGGER = LoggerFactory.getLogger(ProcessorThread.class);

    /** Callback that does the actual work on the slice. */
    private final ProcessorRunnable<T> runnable;

    /** Slice of data handed to {@link #runnable}. */
    private List<T> waitingProcessList;

    /** Completion signal; may be {@code null}, in which case nothing is counted down. */
    private CountDownLatch countDownLatch;

    /**
     * Creates a task for one slice of data.
     *
     * @param runnable           callback invoked from {@link #run()}
     * @param waitingProcessList slice of data passed to the callback
     * @param countDownLatch     latch counted down once per {@link #run()} call; may be {@code null}
     */
    public ProcessorThread(ProcessorRunnable<T> runnable, List<T> waitingProcessList, CountDownLatch countDownLatch) {
        this.waitingProcessList = waitingProcessList;
        this.countDownLatch = countDownLatch;
        this.runnable = runnable;
    }

    /** Replaces the slice that the next {@link #run()} will hand to the callback. */
    public void setWaitingProcessList(List<T> list) {
        waitingProcessList = list;
    }

    /** Replaces the latch that {@link #run()} counts down on completion; {@code null} disables it. */
    public void setCountDownLatch(CountDownLatch latch) {
        countDownLatch = latch;
    }

    /**
     * Invokes the callback on the current slice. Any {@link Exception} it throws is logged and not
     * rethrown, so callers receive no failure signal; the latch is counted down in every case.
     */
    @Override
    public void run() {
        try {
            runnable.handle(waitingProcessList);
        } catch (Exception ex) {
            LOGGER.error(ex.getLocalizedMessage(), ex);
        } finally {
            signalDone();
        }
    }

    /** Counts down the completion latch, if one is configured. */
    private void signalDone() {
        CountDownLatch latch = countDownLatch;
        if (latch == null) {
            return;
        }
        latch.countDown();
    }

    /**
     * Work to perform on one slice of data.
     *
     * @param <T> element type
     */
    public interface ProcessorRunnable<T> {

        /**
         * Processes the given data.
         *
         * @param waitingProcessList the slice of data waiting to be processed
         */
        void handle(List<T> waitingProcessList);

    }

}
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.springframework.util.Assert;

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Multi-threaded list processing utility: splits a list into contiguous slices and hands each
 * slice to a {@link ProcessorThread.ProcessorRunnable} on its own pool thread.
 *
 * <p>A fresh thread pool sized to the actual number of slices is created for every {@link #run}
 * call and shut down before it returns, so an instance may be run more than once. The list must
 * not be structurally modified while {@link #run} is executing, because workers receive
 * {@link List#subList} views of it.
 *
 * @param <T> element type of the list
 * @author oneal
 */
public class MultiThreadProcessor<T> {

    /** Default maximum number of slices / worker threads. */
    private static final int DEFAULT_THREAD_NUM = 10;

    private final int threadNum;
    private final ThreadFactory namedThreadFactory = new ThreadFactoryBuilder().setNameFormat("multi-thread-processor-runner-%d").build();
    private final List<T> waitingProcessList;

    /**
     * Creates a processor that uses at most {@value #DEFAULT_THREAD_NUM} threads.
     *
     * @param waitingProcessList data to process; validated when {@link #run} is called
     */
    public MultiThreadProcessor(List<T> waitingProcessList) {
        this(waitingProcessList, DEFAULT_THREAD_NUM);
    }

    /**
     * Creates a processor that uses at most {@code threadNum} threads.
     *
     * @param waitingProcessList data to process; validated when {@link #run} is called
     * @param threadNum          maximum number of slices / worker threads; must be positive
     * @throws IllegalArgumentException if {@code threadNum <= 0}
     */
    public MultiThreadProcessor(List<T> waitingProcessList, int threadNum) {
        // Previously the pool was built in a field initializer, which runs before this body,
        // so threadNum never affected the pool size. The pool is now created in run().
        Assert.isTrue(threadNum > 0, "threadNum must be positive: " + threadNum);
        this.waitingProcessList = waitingProcessList;
        this.threadNum = threadNum;
    }

    /**
     * Splits the list into slices and processes them concurrently, blocking until every slice
     * has been handled.
     *
     * <p>Each slice holds {@code ceil(size / threadNum)} items except possibly the last, so
     * fewer than {@code threadNum} slices may be used. With {@code threadNum = 10}:
     * <ul>
     *   <li>100 items: 10 slices of 10;</li>
     *   <li>101 items: 10 slices, the first nine of 11 and the last of 2;</li>
     *   <li>13 items: 7 slices, the first six of 2 and the last of 1;</li>
     *   <li>7 items: 7 slices of 1.</li>
     * </ul>
     *
     * @param runnable callback applied to each slice
     * @throws IllegalArgumentException if the list is {@code null} or empty
     * @throws InterruptedException     if interrupted while waiting for the slices to finish
     */
    public void run(ProcessorThread.ProcessorRunnable<T> runnable) throws InterruptedException {
        Assert.isTrue(this.waitingProcessList != null && !this.waitingProcessList.isEmpty(), "需要处理的数据为空.");
        int size = this.waitingProcessList.size();
        int perSize = ceilDiv(size, this.threadNum);
        int sliceCount = ceilDiv(size, perSize);

        CountDownLatch countDownLatch = new CountDownLatch(sliceCount);
        ExecutorService executorService = new ThreadPoolExecutor(sliceCount, sliceCount, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>(), this.namedThreadFactory);
        try {
            for (int i = 0; i < sliceCount; i++) {
                int start = i * perSize;
                // Written as start + min(...) so the bound cannot overflow near Integer.MAX_VALUE.
                int end = start + Math.min(perSize, size - start);
                ProcessorThread<T> thread = new ProcessorThread<>(runnable, this.waitingProcessList.subList(start, end), countDownLatch);
                executorService.execute(thread);
            }
            countDownLatch.await();
        } finally {
            // Always release the pool threads, including when await() is interrupted.
            executorService.shutdown();
        }
    }

    /** Returns {@code ceil(dividend / divisor)} for {@code dividend >= 0, divisor > 0} without overflow. */
    private static int ceilDiv(int dividend, int divisor) {
        return dividend / divisor + (dividend % divisor == 0 ? 0 : 1);
    }

}
/**
 * Manual smoke test for {@link MultiThreadProcessor}: prints each element of the list
 * {@code 0..99} together with the name of the worker thread that handled it.
 */
public class MultiThreadProcessorTest {

    private static final Logger LOGGER = LoggerFactory.getLogger(MultiThreadProcessorTest.class);

    /** Number of elements fed to the processor. */
    private static final int ELEMENT_COUNT = 100;

    public static void main(String[] args) {
        testProcessor();
    }

    /** Builds the sample list and runs it through a default {@link MultiThreadProcessor}. */
    public static void testProcessor() {
        // ArrayList gives random access for the subList slices handed to each worker.
        List<Integer> iList = Lists.newArrayListWithCapacity(ELEMENT_COUNT);
        for (int i = 0; i < ELEMENT_COUNT; i++) {
            iList.add(i);
        }
        MultiThreadProcessor<Integer> multiThreadProcessor = new MultiThreadProcessor<>(iList);
        try {
            multiThreadProcessor.run(waitingProcessList -> {
                for (Integer integer : waitingProcessList) {
                    System.out.println(Thread.currentThread().getName() + "---------" + integer);
                }
            });
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            LOGGER.error(e.getLocalizedMessage(), e);
        }

    }

}
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容