java并发包之AbstractQueuedSynchronizer源码分析(三)

此文分析AQS的最后部分,CountDownLatch、CyclicBarrier、Semaphore

CountDownLatch

CountDownLatch是AQS共享模式的使用。中文名为“栅栏”

例子

public class CountDown {
   //对栅栏进行初始化,设置一个参数为2,代表需要有两个任务都到达此栅栏才会继续向下执行。
    public static CountDownLatch cdl=new CountDownLatch(2);
    public static void main(String[] args) {
        Thread t1=new Thread(new Task());
        Thread t2 = new Thread(new Task1());
        t1.start();
        t2.start();
        try {
            //从此处阻塞,等待2个任务都完成后再从此处返回继续向下执行。
            cdl.await();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println("ok");
    }
}
class Task implements Runnable{
    @Override
    public void run() {
        try {
            System.out.println("Task is doing something");
            //这句话就是此任务到达这个栅栏。
            CountDown.cdl.countDown();
            System.out.println("Task over");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
class Task1 implements Runnable{
    @Override
    public void run() {
        try {
            System.out.println("Task1 is doing something");
            CountDown.cdl.countDown();
            System.out.println("Task1 over");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

运行结果

Task is doing something
Task1 is doing something
Task over
Task1 over
ok

所以简单来说CountDownLatch其实就是首先初始化一个参数N,主线程调用cdl.await()阻塞,等待其他任务进行cdl.countDown(),也就是对这个参数N进行减1操作,当把N减为0时,从主线程的cdl.await()处返回。
所以这个运行结果是两个线程都进行cdl.countDown()之后再从await处返回输出ok。


来自javadoop.com的图片

源码分析

//构造方法,传入一个不小于0的整数。
public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
              this.sync = new Sync(count);
 }
//内部封装一个Sync 类继承自 AQS。
private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;
        //将AQS里的state赋值为count。
        Sync(int count) {
            setState(count);
}

从这里可以看出,在进行初始化的时候,首先给AQS里的state赋一个值,也就是参数count,每一个调用await()的线程都会阻塞挂起,等待其他线程进行countDown(),将state-1,当有线程将state减为0的时候,这个线程同时唤醒所有调用await()方法的线程。

await()

调用await()方法的线程会加入到阻塞队列中,等待countDown()方法将state减为0时才能被唤醒返回。
await可以被多个线程调用,所有调用了await的线程都会加入到阻塞队列中。

public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        //如果被中断,抛出InterruptedException()异常。
        if (Thread.interrupted())
            throw new InterruptedException();
        //判断此时state是否为0。
        if (tryAcquireShared(arg) < 0)
            //把当前线程加入到阻塞队列中。等待被唤醒。
            doAcquireSharedInterruptibly(arg);
    }
protected int tryAcquireShared(int acquires) {
            //state==0时返回1,否则返回-1。
            return (getState() == 0) ? 1 : -1;
 }
private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        //把此线程加入到阻塞队列中。
        final Node node = addWaiter(Node.SHARED);
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    //判断此时state是否为0。
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        return;
                    }
                }
                //将其前驱节点的waitStatus置为-1并从此挂起。
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } catch (Throwable t) {
            cancelAcquire(node);
            throw t;
        }
    }

countDown()

public void countDown() {
        sync.releaseShared(1);
    }
public final boolean releaseShared(int arg) {
        //当state被减为0时,tryReleaseShared(arg)才会返回true,执行doReleaseShared()。
        if (tryReleaseShared(arg)) {
            //唤醒阻塞队列中的 线程。
            doReleaseShared();
            return true;
        }
        return false;
    }
//进行对state-1的操作。
protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c - 1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
//唤醒阻塞队列中的线程
private void doReleaseShared() {
        for (;;) {
            Node h = head;
            //判断阻塞队列是否为null。
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    //这里CAS失败的原因:
                  //因为要唤醒阻塞队列中的所有线程,所以会依次进入此方法,所以可能会出现多个线程进行此CAS,也就是只能有一个会成功。
                    if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    //唤醒阻塞队列的线程
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                          //这里CAS失败是因为此时刚好有一个节点入队,将这个waitStatus设为-1。
                         !h.compareAndSetWaitStatus(0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            //这里有两种情况:
            //1。h==head,因为会依次唤醒进入此方法,当头节点还没有被刚刚唤醒的线程占有时,break直接退出。从这里退出不代表不继续唤醒后继节点的线程,因为唤醒的节点还会继续进入此方法。
           //2。h!=head,头节点被刚刚唤醒的线程占有,重新进入下一轮循环。
            if (h == head)                   // loop if head changed
                break;
        }
    }
//唤醒后继续回到这里。从parkAndCheckInterrupt()中返回。
private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                     //因为此时的state已经被减为0。
                    //所以这时r=1。进入if语句体中。
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        //继续依次唤醒阻塞队列中的其他节点。
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    //从此处返回。继续for循环。
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } catch (Throwable t) {
            cancelAcquire(node);
            throw t;
        }
    }
//因为可以有多个线程调用await,所以要依次唤醒。
private void setHeadAndPropagate(Node node, int propagate) {
        Node h = head; // Record old head for check below
        //把第一个唤醒的节点设为头节点,然后再去唤醒其后继节点,以此类推。
        setHead(node);
        if (propagate > 0 || h == null || h.waitStatus < 0 ||
            (h = head) == null || h.waitStatus < 0) {
            Node s = node.next;
            if (s == null || s.isShared())
                //继续进入到doReleaseShared(),但此时的头节点已经成为了之前唤醒的节点。
                doReleaseShared();
        }
    }

假设有a1,a2两个线程进行await()进入到阻塞队列中,那么进行countDown()把state减为0的那个线程会先唤醒a1,然后a1再去唤醒a2。

CyclicBarrier

中文名称为“可重复使用的栅栏”,它不是使用countDownLatch的共享模式,而是使用了AQS的Condition。


来自javadoop的图片

CyclicBarrier类中的属性与构造方法

public class CyclicBarrier {
    //因为CyclicBarrier 是可以重复使用的,所以每次从开始使用到穿过栅栏当做"一代",或者"一个周期"
    private static class Generation {
        Generation() {}                 // prevent access constructor creation
        boolean broken;                 // initially false
    }

    /** The lock for guarding barrier entry */
    //基于Condition来实现,需要lock。
    private final ReentrantLock lock = new ReentrantLock();
    /** Condition to wait on until tripped */
    //和Condition一样,先初始化一个条件队列。
    private final Condition trip = lock.newCondition();
    /** The number of parties */
    //参与的线程数。
    private final int parties;
    /** The command to run when tripped */
    //代表所有线程到达此栅栏时,先执行这个任务。
    //可以设置为null。
    private final Runnable barrierCommand;
    /** The current generation */
    //当前的“代”或者“周期”
    private Generation generation = new Generation();
     //这个值初始为 parties,每到达此栅栏一个线程就对count-1。
    private int count;
    //构造方法。
    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }
}

await()

public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }
private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        final ReentrantLock lock = this.lock;
        //和condition一样,需要先获取锁。
        lock.lock();
        try {
            final Generation g = generation;
            //检查栅栏是否被打破,如果为true,抛出异常。
            if (g.broken)
                throw new BrokenBarrierException();
            //判断线程是否中断,中断则打破此栅栏,抛出异常。
            if (Thread.interrupted()) {
                breakBarrier();
                throw new InterruptedException();
            }
            //调用await()的线程对count-1。
            int index = --count;
            //如果count减为0,代表所有的线程已经到达此栅栏,唤醒所有阻塞的线程。
            if (index == 0) {  // tripped
                boolean ranAction = false;
                try {
                    // 如果在初始化的时候,指定了通过栅栏前需要执行的操作,在这里会得到执行
                    final Runnable command = barrierCommand;
                    if (command != null)
                        command.run();
                    //若可以把ranAction为true,说明执行command.run();时没有发生异常。
                    ranAction = true;
                    //唤醒所有等待的线程,开启新的一“代”。
                    nextGeneration();
                    return 0;
                } finally {
                    //进入这个if条件体中说明在 command.run();时,发生异常,打破此栅栏。
                    if (!ranAction)
                        breakBarrier();
                }
            }
            //进入这个for循环中说明此线程不是最后一个调用await()方法的线程。即还没有把count减为0。
            //所以需要把此线程从这里加入到条件队列中,并将其挂起。
            // loop until tripped, broken, interrupted, or timed out
            for (;;) {
                try {
                   // 判断此await()方法是否带有超时机制。
                    if (!timed)
                        //把此线程加入到条件队列中并挂起等待唤醒。
                        trip.await();
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    //捕获到异常进入这里说明在trip.await()期间线程被中断。
                    if (g == generation && ! g.broken) {
                        //打破此栅栏。
                        breakBarrier();
                        //抛出异常。
                        throw ie;
                    } else {
                        // We're about to finish waiting even if we had not
                        // been interrupted, so this interrupt is deemed to
                        // "belong" to subsequent execution.
                        //进入这里说明g!=generation,说明新的一代已经产生,即最后一个线程await()执行完成。
                        //所以没有必要再抛出异常,记录下这个信息即可。
                        Thread.currentThread().interrupt();
                    }
                }
                //唤醒返回后判断此栅栏是否被打破。
                if (g.broken)
                    throw new BrokenBarrierException();
                //正常情况下,最后一个到达栅栏的线程会进行nextGeneration();把所有在条件队列中的线程加入到阻塞队列,然后开启新的“代”,最后释放锁。
                //其他线程从trip.await()中获取锁并返回后,到达此处,即g!=generation,从此处return退出.
                if (g != generation)
                    return index;
                //唤醒后发现超时,打破栅栏,抛出异常。
                if (timed && nanos <= 0L) {
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();
        }
    }
//唤醒所有在条件队列中的线程,开启新的一“代”。
private void nextGeneration() {
        // signal completion of last generation
        trip.signalAll();
        // set up next generation
        count = parties;
        generation = new Generation();
    }
//打破此栅栏,设置broken为true,唤醒所有在条件队列中的线程。
private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }

打破栅栏的条件:
1.中断,被阻塞挂起的线程在等待中被中断,会breakBarrier(),抛出异常。
2.超时,打破栅栏,抛出异常。
3.barrierCommand.run()时发生异常,打破栅栏。

Semaphore

Semaphore也是AQS共享模式的使用。
创建 Semaphore 实例的时候,需要一个参数 permits,这个是设置给 AQS 的 state 的,然后每个线程调用 acquire 的时候,执行 state = state - 1,release 的时候执行 state = state + 1,所以在acquire 的时候,如果 state = 0,说明没有信号了,需要等待其他线程 release。

构造方法

public Semaphore(int permits) {
    sync = new NonfairSync(permits);
}
public Semaphore(int permits, boolean fair) {
    sync = fair ? new FairSync(permits) : new NonfairSync(permits);
}

与ReentrantLock类似,存在公平与非公平策略。

acquire()

public void acquire() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        //判断此时state是否小于0,小于0则进入此条件体,即当前已经没有多余信号量。
        if (tryAcquireShared(arg) < 0)
             //加入到阻塞队列中挂起,等待release()。
            doAcquireSharedInterruptibly(arg);
    }

公平与非公平tryAcquireShared(arg)。

//公平
protected int tryAcquireShared(int acquires) {
            for (;;) {
                //区别就是公平策略里先判断阻塞队列中是否有其他线程比我先排队。
                //非公平策略无需判断。
                if (hasQueuedPredecessors())
                    return -1;
                int available = getState();
                int remaining = available - acquires;
                //这里需要注意因为是 ||运算符,所以把state减到小于0时直接return remaining,不用进行CAS改变state的值。
                //也就是说state最小为0。
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }
//非公平
protected int tryAcquireShared(int acquires) {
    return nonfairTryAcquireShared(acquires);
}
final int nonfairTryAcquireShared(int acquires) {
            for (;;) {
                int available = getState();
                int remaining = available - acquires;
                if (remaining < 0 ||
                    compareAndSetState(available, remaining))
                    return remaining;
            }
        }

doAcquireSharedInterruptibly(arg),把当前线程加入到阻塞队列中,等待唤醒。

private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        //加入到阻塞队列中。
        final Node node = addWaiter(Node.SHARED);
        try {
            for (;;) {
                final Node p = node.predecessor();
                if (p == head) {
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    //挂起。
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } catch (Throwable t) {
            cancelAcquire(node);
            throw t;
        }
    }

release()

public void release() {
        sync.releaseShared(1);
    }
public final boolean releaseShared(int arg) {
        //进行state=state+1操作。
        if (tryReleaseShared(arg)) {
            //唤醒在阻塞队列中的线程。
            doReleaseShared();
            return true;
        }
        return false;
    }
protected final boolean tryReleaseShared(int releases) {
            //通过自旋做state+1操作。
            for (;;) {
                int current = getState();
                int next = current + releases;
                if (next < current) // overflow
                    throw new Error("Maximum permit count exceeded");
                if (compareAndSetState(current, next))
                    return true;
            }
        }
private void doReleaseShared() {
        //唤醒等待的线程。
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !h.compareAndSetWaitStatus(0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

其实与Condition的套路差不多。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容