Java CountDownLatch、CyclicBarrier、Semaphore源码分析(基于API 29 JDK8)

CountDownLatch

在多线程的情况下,主线程需要等待子线程执行完毕之后才能进行接下来的操作,在CountDownLatch出现之前,一般通过join来实现,但是join不够灵活,不能满足丰富场景下的需求,所以CountDownLatch类诞生了。举个例子:

public class ExampleTest {
    @Test
    public void main() {
        ExecutorService service = Executors.newFixedThreadPool(3);
        for (int i = 0; i < 3; i++) {
            service.execute(runnable);
        }
        try {
            latch.await();
            System.out.println("All Thread finish");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    private CountDownLatch latch = new CountDownLatch(3);
    private Runnable runnable = new Runnable() {
        @Override
        public void run() {
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName()+" : finish work");
            latch.countDown();
        }
    };
}

/***************************************
输出:
pool-1-thread-1 : finish work
pool-1-thread-2 : finish work
pool-1-thread-3 : finish work
All Thread finish
***************************************/

接下来看CountDownLatch的源码,从构造函数开始:

public class CountDownLatch {
    ...
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
    ...
}

只有带一个int参数的构造函数,这个Sync类是它的一个内部类,查看源码:

private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;
        // 构造函数调用了setState,是父类的方法
        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // 循环进行CAS,直到当前线程成功完成CAS是计数器值(state)减1并更新到state
            for (;;) {
                // 获取volatile变量的state
                int c = getState();
                // state为0直接返回
                if (c == 0)
                    return false;
                int nextc = c - 1;
                // cas让state-1
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

先看一下AbstractQueuedSynchronizer的setState/getState方法做了些什么:

public abstract class AbstractQueuedSynchronizer
    extends AbstractOwnableSynchronizer
    implements java.io.Serializable {
    ...
    private volatile int state;
    protected final void setState(int newState) {
        state = newState;
    }
    protected final int getState() {
        return state;
    }
    ...
}

从上面可以看到CountDownLatch的初始化设置了一个volatile的变量state,接下来看countDown方法做了什么:

    // CountDownLatch.java
    public void countDown() {
        sync.releaseShared(1);
    }

    // AbstractQueuedSynchronizer.java
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            // 释放资源
            doReleaseShared();
            return true;
        }
        return false;
    }

    private void doReleaseShared() {
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !h.compareAndSetWaitStatus(0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

可以看到releaseShared是调用了tryReleaseShared,在循环进行CAS,直到当前线程成功完成CAS是计数器值(state)减1并更新到state,CAS成功之后调用doReleaseShared释放资源。之后看await方法做了些什么:

    // CountDownLatch.java
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    // AbstractQueuedSynchronizer.java
    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        // state不为0的时候进入等待队列
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }

    private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        try {
            // 阻塞
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    // 当state为0时,返回
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        return;
                    }
                }
                // 当节点获取失败或者中断的时候抛出异常
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } catch (Throwable t) {
            cancelAcquire(node);
            throw t;
        }
    }

调用的Sync的acquireSharedInterruptibly方法,当state不为0的时候,进入doAcquireSharedInterruptibly阻塞,当state为0时返回,或中断时抛出异常。再看带了超时参数的await方法:

    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        return tryAcquireShared(arg) >= 0 ||
            doAcquireSharedNanos(arg, nanosTimeout);
    }

    private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
            throws InterruptedException {
        if (nanosTimeout <= 0L)
            return false;
        final long deadline = System.nanoTime() + nanosTimeout;
        final Node node = addWaiter(Node.SHARED);
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        return true;
                    }
                }
                nanosTimeout = deadline - System.nanoTime();
                if (nanosTimeout <= 0L) {
                    cancelAcquire(node);
                    return false;
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    nanosTimeout > SPIN_FOR_TIMEOUT_THRESHOLD)
                    LockSupport.parkNanos(this, nanosTimeout);
                if (Thread.interrupted())
                    throw new InterruptedException();
            }
        } catch (Throwable t) {
            cancelAcquire(node);
            throw t;
        }
    }

与await无参方法不一样的是,doAcquireSharedNanos方法多了一个nanosTimeout参数,当nanosTimeout小于0的时候,释放资源并返回false,程序主线程将继续运行。

CyclicBarrier

从上面的分析可以看到,当CountDownLatch执行countDown到state为0的时候,就结束了,没有重置的办法,因此CyclicBarrier来了,CyclicBarrier是回环屏障的意思,它可以让一组线程全部达到一个状态后再全部同步执行。这里之所以叫回环是因为当所有等待线程执行完毕,重置CyclicBarrier的状态后可以被重用。写一个测试用例:

public class ExampleTest {
    @Test
    public void main() {
        ExecutorService service = Executors.newFixedThreadPool(3);
        for (int i = 0; i < 3; i++) {
            service.execute(runnable);
        }
    }
    
    private CyclicBarrier barrier = new CyclicBarrier(3, new Runnable() {
        @Override
        public void run() {
            System.out.println(Thread.currentThread().getName() + " : finish work");
        }
    });

    private Runnable runnable = new Runnable() {
        @Override
        public void run() {
            try {
                System.out.println(Thread.currentThread().getName() + " : start work");
                barrier.await();
                System.out.println(Thread.currentThread().getName() + " : barrier out 1");
                barrier.await();
                System.out.println(Thread.currentThread().getName() + " : barrier out 2");
            } catch (BrokenBarrierException | InterruptedException e) {
                e.printStackTrace();
            }
        }
    };
}

/***************************************
输出:
pool-1-thread-1 : start work
pool-1-thread-2 : start work
pool-1-thread-3 : start work
pool-1-thread-3 : finish work
pool-1-thread-3 : barrier out 1
pool-1-thread-2 : barrier out 1
pool-1-thread-1 : barrier out 1
pool-1-thread-1 : finish work
pool-1-thread-1 : barrier out 2
pool-1-thread-2 : barrier out 2
pool-1-thread-3 : barrier out 2
***************************************/

测试用例新建了一个CyclicBarrier对象,传递参数为计数器初始值和当计数器为0时执行的runnable。一开始计数器的值为3,当第一个线程调用await方法时,计数器减1,此时计数器不为0,线程阻塞,直到3个线程全部执行await,计数器为0,最后一个进入await的线程执行CyclicBarrier中的runnable,执行完毕后结束阻塞,并唤醒其他线程,执行完barrier out 1的任务之后再次阻塞在await方法,这是单个CountDownLatch无法完成的。分析源码实现,首先是构造函数:

public class CyclicBarrier {
    private final ReentrantLock lock = new ReentrantLock();
    private Generation generation = new Generation();
    private final Condition trip = lock.newCondition();

    private final Runnable barrierCommand;
    private final int parties;
    private int count;

    private static class Generation {
        boolean broken;         // initially false
    }

    public CyclicBarrier(int parties) {
        this(parties, null);
    }

    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }
}

初始化了parties、count和barrierCommand,接着看await方法:

    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }

    public int await(long timeout, TimeUnit unit)
        throws InterruptedException,
               BrokenBarrierException,
               TimeoutException {
        return dowait(true, unit.toNanos(timeout));
    }
 
    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        // 获取ReentrantLock
        final ReentrantLock lock = this.lock;
        // 上锁
        lock.lock();
        try {
            // 默认为false
            final Generation g = generation;

            if (g.broken)
                throw new BrokenBarrierException();
            // 中断
            if (Thread.interrupted()) {
                // 结束回环并抛出异常
                breakBarrier();
                throw new InterruptedException();
            }
            // count自减1
            int index = --count;
            if (index == 0) {  // tripped
                // 执行barrierCommand
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null)
                        command.run();
                    ranAction = true;
                    // 更新状态并唤醒所有处于锁定状态的线程
                    nextGeneration();
                    return 0;
                } finally {
                    // 如果没有正常返回,则结束回环
                    if (!ranAction)
                        breakBarrier();
                }
            }

            // 循环直到count=0,屏障破坏,中断,或者超时
            for (;;) {
                try {
                    if (!timed)
                        // 阻塞,直到收到通知
                        trip.await();
                    else if (nanos > 0L)
                        // 阻塞,直到超时
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    // 发生中断
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        Thread.currentThread().interrupt();
                    }
                }

                if (g.broken)
                    throw new BrokenBarrierException();

                if (g != generation)
                    return index;

                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            // 解锁
            lock.unlock();
        }
    }

    private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }

从上面可以看到,当进入await的时候,会通过ReentrantLock对代码段进行一个上锁,操作count自减1,当减为0的时候,执行barrierCommand并调用trip.signalAll()来唤醒所有阻塞中的线程,并将count重新初始化为parties,否则在后续中通过调用trip.await()或者trip.awaitNanos(nanos)进入阻塞状态并释放锁,直到收到通知信号加入锁的竞争中,获取到锁之后在finally中释放锁,其他线程依次如此,最终所有线程往下继续运行。

Semaphore

Semaphore也是java的一个同步器,与CountDownLatch、CyclicBarrier不同的是,它不需要在初始化的时候指定同步线程的个数,而是在需要同步的地方调用acquire方法时,指定需要同步的线程数。

public class ExampleTest {
    private Semaphore semaphore = new Semaphore(0);

    @Test
    public void main() throws InterruptedException {
        ExecutorService service = Executors.newFixedThreadPool(3);
        for (int i = 0; i < 3; i++) {
            service.execute(runnable);
        }
        System.out.println(Thread.currentThread().getName() + " : acquire");
        semaphore.acquire(3);
        System.out.println(Thread.currentThread().getName() + " : release");
    }
    
    private Runnable runnable = new Runnable() {
        @Override
        public void run() {
            System.out.println(Thread.currentThread().getName() + " : start work");
            semaphore.release();
        }
    };
}
/***************************************
输出:
main : acquire
pool-1-thread-1 : start work
pool-1-thread-2 : start work
pool-1-thread-3 : start work
main : release
***************************************/

上述代码创建了一个Semaphore实例,构造函数传参为0,说明当前信号量的计数器的值为0,然后在main中向线程池添加了3个线程任务,在子线程中调用release方法,在main的最后调用acquire方法,传入参数为线程数3,之后进入阻塞状态,等待信号量的计数变为3。接下来看源码,从构造函数开始:

public class Semaphore implements java.io.Serializable {
    private final Sync sync;
    public Semaphore(int permits) {
        sync = new NonfairSync(permits);
    }
    public Semaphore(int permits, boolean fair) {
        sync = fair ? new FairSync(permits) : new NonfairSync(permits);
    }
}

从上面可以看到Semaphore提供了两种构造方法,其中permits用于初始化信号量,fair用于确定sync的实例类型,默认是非公平,Sync是内部类继承于AbstractQueuedSynchronizer ,FairSync和NonfairSync是Sync的子类。

    abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 1192457210091910933L;

        Sync(int permits) {
            setState(permits);
        }

        final int getPermits() {
            return getState();
        }

        final int nonfairTryAcquireShared(int acquires) {
            for (;;) {
                int available = getState();
                int remaining = available - acquires;
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }

        protected final boolean tryReleaseShared(int releases) {
            for (;;) {
                int current = getState();
                int next = current + releases;
                if (next < current) // overflow
                    throw new Error("Maximum permit count exceeded");
                if (compareAndSetState(current, next))
                    return true;
            }
        }

        final void reducePermits(int reductions) {
            for (;;) {
                int current = getState();
                int next = current - reductions;
                if (next > current) // underflow
                    throw new Error("Permit count underflow");
                if (compareAndSetState(current, next))
                    return;
            }
        }

        final int drainPermits() {
            for (;;) {
                int current = getState();
                if (current == 0 || compareAndSetState(current, 0))
                    return current;
            }
        }
    }

    /**
     * NonFair version
     */
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = -2694183684443567898L;

        NonfairSync(int permits) {
            super(permits);
        }

        protected int tryAcquireShared(int acquires) {
            return nonfairTryAcquireShared(acquires);
        }
    }

    /**
     * Fair version
     */
    static final class FairSync extends Sync {
        private static final long serialVersionUID = 2014338818796000944L;

        FairSync(int permits) {
            super(permits);
        }

        protected int tryAcquireShared(int acquires) {
            for (;;) {
                if (hasQueuedPredecessors())
                    return -1;
                int available = getState();
                int remaining = available - acquires;
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }
    }

AbstractQueuedSynchronizer的内容之前已经分析过了,先从Semaphore的acquire方法入手,

    public void acquire(int permits) throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireSharedInterruptibly(permits);
    }

    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        // 调用sync子类方法尝试获取,默认使用非公平策略
        if (tryAcquireShared(arg) < 0)
            // 如果获取失败则添加到阻塞队列,然后再次尝试,继续失败则调用part方法挂起当前线程
            doAcquireSharedInterruptibly(arg);
    }
    // 非公平策略
    protected int tryAcquireShared(int acquires) {
        return nonfairTryAcquireShared(acquires);
    }

    final int nonfairTryAcquireShared(int acquires) {
        for (;;) {
            // 当前信号量
            int available = getState();
            // 剩余值
            int remaining = available - acquires;
            // 如果当前剩余值小于0或者CAS设置成功则返回
            if (remaining < 0 ||
                compareAndSetState(available, remaining))
                return remaining;
        }
    }

    // 公平策略
    protected int tryAcquireShared(int acquires) {
        for (;;) {
            // 查看当前线程的前驱节点是否也在等待获取资源
            // 如果是则放弃获取并加入AQS阻塞队列,否则就去获取资源
            if (hasQueuedPredecessors())
                return -1;
            int available = getState();
            int remaining = available - acquires;
            if (remaining < 0 ||
                compareAndSetState(available, remaining))
                return remaining;
        }
    }

    public final boolean hasQueuedPredecessors() {
        Node t = tail; // Read fields in reverse initialization order
        Node h = head;
        Node s;
        return h != t &&
            ((s = h.next) == null || s.thread != Thread.currentThread());
    }

acquire会调用tryAcquireShared,该方法在构造函数中存在公平与非公平两种策略,其中非公平策略下线程会直接尝试获取资源,而公平策略通过hasQueuedPredecessors节点的前节点是否也在等待获取资源,如果有前节点则放弃获取并加入阻塞队列,否则通过获取信号量,并计算差值,如果差值小于0或者CAS操作state成功时返回,从上面分析不难发现Semaphore也是支持回环的,每次调用acquire会更新信号量,相当于CyclicBarrier中将信号量重新初始化为备份过的初始值一样。接下来看release方法:

    public void release() {
        sync.releaseShared(1);
    }

    public final boolean releaseShared(int arg) {
        // 尝试释放资源
        if (tryReleaseShared(arg)) {
            // 资源释放成功则调用part方法唤醒AQS队列里面最先挂起的线程
            doReleaseShared();
            return true;
        }
        return false;
    }

    protected final boolean tryReleaseShared(int releases) {
        for (;;) {
            // 当前信号量
            int current = getState();
            // 信号量增加release
            int next = current + releases;
            // releases<0的情况
            if (next < current) // overflow
                throw new Error("Maximum permit count exceeded");
            // CAS更新信号量,保证原子性
            if (compareAndSetState(current, next))
                return true;
        }
    }

    private void doReleaseShared() {
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !h.compareAndSetWaitStatus(0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

从上面可以看到release方法是调用releaseShared方法并传参1,在releaseShared方法中调用tryReleaseShared方法来释放资源,通过CAS操作更新信号量,释放资源成功后,调用doReleaseShared方法唤醒AQS队列中最先挂起的线程,结束release。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。