CountDownLatch
在多线程的情况下,主线程需要等待子线程执行完毕之后才能进行接下来的操作,在CountDownLatch出现之前,一般通过join来实现,但是join不够灵活,不能满足丰富场景下的需求,所以CountDownLatch类诞生了。举个例子:
public class ExampleTest {
@Test
public void main() {
ExecutorService service = Executors.newFixedThreadPool(3);
for (int i = 0; i < 3; i++) {
service.execute(runnable);
}
try {
latch.await();
System.out.println("All Thread finish");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
private CountDownLatch latch = new CountDownLatch(3);
private Runnable runnable = new Runnable() {
@Override
public void run() {
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread().getName()+" : finish work");
latch.countDown();
}
};
}
/***************************************
输出:
pool-1-thread-1 : finish work
pool-1-thread-2 : finish work
pool-1-thread-3 : finish work
All Thread finish
***************************************/
接下来看CountDownLatch的源码,从构造函数开始:
public class CountDownLatch {
...
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
...
}
只有带一个int参数的构造函数,这个Sync类是它的一个内部类,查看源码:
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
// 构造函数调用了setState,是父类的方法
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// 循环进行CAS,直到当前线程成功完成CAS是计数器值(state)减1并更新到state
for (;;) {
// 获取volatile变量的state
int c = getState();
// state为0直接返回
if (c == 0)
return false;
int nextc = c - 1;
// cas让state-1
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
先看一下AbstractQueuedSynchronizer的setState/getState方法做了些什么:
public abstract class AbstractQueuedSynchronizer
extends AbstractOwnableSynchronizer
implements java.io.Serializable {
...
private volatile int state;
protected final void setState(int newState) {
state = newState;
}
protected final int getState() {
return state;
}
...
}
从上面可以看到CountDownLatch的初始化设置了一个volatile的变量state,接下来看countDown方法做了什么:
// CountDownLatch.java
public void countDown() {
sync.releaseShared(1);
}
// AbstractQueuedSynchronizer.java
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
// 释放资源
doReleaseShared();
return true;
}
return false;
}
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!h.compareAndSetWaitStatus(0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}
可以看到releaseShared是调用了tryReleaseShared,在循环进行CAS,直到当前线程成功完成CAS是计数器值(state)减1并更新到state,CAS成功之后调用doReleaseShared释放资源。之后看await方法做了些什么:
// CountDownLatch.java
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
// AbstractQueuedSynchronizer.java
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// state不为0的时候进入等待队列
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
try {
// 阻塞
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
// 当state为0时,返回
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
return;
}
}
// 当节点获取失败或者中断的时候抛出异常
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} catch (Throwable t) {
cancelAcquire(node);
throw t;
}
}
调用的Sync的acquireSharedInterruptibly方法,当state不为0的时候,进入doAcquireSharedInterruptibly阻塞,当state为0时返回,或中断时抛出异常。再看带了超时参数的await方法:
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
return tryAcquireShared(arg) >= 0 ||
doAcquireSharedNanos(arg, nanosTimeout);
}
private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (nanosTimeout <= 0L)
return false;
final long deadline = System.nanoTime() + nanosTimeout;
final Node node = addWaiter(Node.SHARED);
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
return true;
}
}
nanosTimeout = deadline - System.nanoTime();
if (nanosTimeout <= 0L) {
cancelAcquire(node);
return false;
}
if (shouldParkAfterFailedAcquire(p, node) &&
nanosTimeout > SPIN_FOR_TIMEOUT_THRESHOLD)
LockSupport.parkNanos(this, nanosTimeout);
if (Thread.interrupted())
throw new InterruptedException();
}
} catch (Throwable t) {
cancelAcquire(node);
throw t;
}
}
与await无参方法不一样的是,doAcquireSharedNanos方法多了一个nanosTimeout参数,当nanosTimeout小于0的时候,释放资源并返回false,程序主线程将继续运行。
CyclicBarrier
从上面的分析可以看到,当CountDownLatch执行countDown到state为0的时候,就结束了,没有重置的办法,因此CyclicBarrier来了,CyclicBarrier是回环屏障的意思,它可以让一组线程全部达到一个状态后再全部同步执行。这里之所以叫回环是因为当所有等待线程执行完毕,重置CyclicBarrier的状态后可以被重用。写一个测试用例:
public class ExampleTest {
@Test
public void main() {
ExecutorService service = Executors.newFixedThreadPool(3);
for (int i = 0; i < 3; i++) {
service.execute(runnable);
}
}
private CyclicBarrier barrier = new CyclicBarrier(3, new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread().getName() + " : finish work");
}
});
private Runnable runnable = new Runnable() {
@Override
public void run() {
try {
System.out.println(Thread.currentThread().getName() + " : start work");
barrier.await();
System.out.println(Thread.currentThread().getName() + " : barrier out 1");
barrier.await();
System.out.println(Thread.currentThread().getName() + " : barrier out 2");
} catch (BrokenBarrierException | InterruptedException e) {
e.printStackTrace();
}
}
};
}
/***************************************
输出:
pool-1-thread-1 : start work
pool-1-thread-2 : start work
pool-1-thread-3 : start work
pool-1-thread-3 : finish work
pool-1-thread-3 : barrier out 1
pool-1-thread-2 : barrier out 1
pool-1-thread-1 : barrier out 1
pool-1-thread-1 : finish work
pool-1-thread-1 : barrier out 2
pool-1-thread-2 : barrier out 2
pool-1-thread-3 : barrier out 2
***************************************/
测试用例新建了一个CyclicBarrier对象,传递参数为计数器初始值和当计数器为0时执行的runnable。一开始计数器的值为3,当第一个线程调用await方法时,计数器减1,此时计数器不为0,线程阻塞,直到3个线程全部执行await,计数器为0,最后一个进入await的线程执行CyclicBarrier中的runnable,执行完毕后结束阻塞,并唤醒其他线程,执行完barrier out 1的任务之后再次阻塞在await方法,这是单个CountDownLatch无法完成的。分析源码实现,首先是构造函数:
public class CyclicBarrier {
private final ReentrantLock lock = new ReentrantLock();
private Generation generation = new Generation();
private final Condition trip = lock.newCondition();
private final Runnable barrierCommand;
private final int parties;
private int count;
private static class Generation {
boolean broken; // initially false
}
public CyclicBarrier(int parties) {
this(parties, null);
}
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties;
this.barrierCommand = barrierAction;
}
}
初始化了parties、count和barrierCommand,接着看await方法:
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
public int await(long timeout, TimeUnit unit)
throws InterruptedException,
BrokenBarrierException,
TimeoutException {
return dowait(true, unit.toNanos(timeout));
}
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
// 获取ReentrantLock
final ReentrantLock lock = this.lock;
// 上锁
lock.lock();
try {
// 默认为false
final Generation g = generation;
if (g.broken)
throw new BrokenBarrierException();
// 中断
if (Thread.interrupted()) {
// 结束回环并抛出异常
breakBarrier();
throw new InterruptedException();
}
// count自减1
int index = --count;
if (index == 0) { // tripped
// 执行barrierCommand
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
if (command != null)
command.run();
ranAction = true;
// 更新状态并唤醒所有处于锁定状态的线程
nextGeneration();
return 0;
} finally {
// 如果没有正常返回,则结束回环
if (!ranAction)
breakBarrier();
}
}
// 循环直到count=0,屏障破坏,中断,或者超时
for (;;) {
try {
if (!timed)
// 阻塞,直到收到通知
trip.await();
else if (nanos > 0L)
// 阻塞,直到超时
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
// 发生中断
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
// 解锁
lock.unlock();
}
}
private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}
从上面可以看到,当进入await的时候,会通过ReentrantLock对代码段进行一个上锁,操作count自减1,当减为0的时候,执行barrierCommand并调用trip.signalAll()来唤醒所有阻塞中的线程,并将count重新初始化为parties,否则在后续中通过调用trip.await()或者trip.awaitNanos(nanos)进入阻塞状态并释放锁,直到收到通知信号加入锁的竞争中,获取到锁之后在finally中释放锁,其他线程依次如此,最终所有线程往下继续运行。
Semaphore
Semaphore也是java的一个同步器,与CountDownLatch、CyclicBarrier不同的是,它不需要在初始化的时候指定同步线程的个数,而是在需要同步的地方调用acquire方法时,指定需要同步的线程数。
public class ExampleTest {
private Semaphore semaphore = new Semaphore(0);
@Test
public void main() throws InterruptedException {
ExecutorService service = Executors.newFixedThreadPool(3);
for (int i = 0; i < 3; i++) {
service.execute(runnable);
}
System.out.println(Thread.currentThread().getName() + " : acquire");
semaphore.acquire(3);
System.out.println(Thread.currentThread().getName() + " : release");
}
private Runnable runnable = new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread().getName() + " : start work");
semaphore.release();
}
};
}
/***************************************
输出:
main : acquire
pool-1-thread-1 : start work
pool-1-thread-2 : start work
pool-1-thread-3 : start work
main : release
***************************************/
上述代码创建了一个Semaphore实例,构造函数传参为0,说明当前信号量的计数器的值为0,然后在main中向线程池添加了3个线程任务,在子线程中调用release方法,在main的最后调用acquire方法,传入参数为线程数3,之后进入阻塞状态,等待信号量的计数变为3。接下来看源码,从构造函数开始:
public class Semaphore implements java.io.Serializable {
private final Sync sync;
public Semaphore(int permits) {
sync = new NonfairSync(permits);
}
public Semaphore(int permits, boolean fair) {
sync = fair ? new FairSync(permits) : new NonfairSync(permits);
}
}
从上面可以看到Semaphore提供了两种构造方法,其中permits用于初始化信号量,fair用于确定sync的实例类型,默认是非公平,Sync是内部类继承于AbstractQueuedSynchronizer ,FairSync和NonfairSync是Sync的子类。
abstract static class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 1192457210091910933L;
Sync(int permits) {
setState(permits);
}
final int getPermits() {
return getState();
}
final int nonfairTryAcquireShared(int acquires) {
for (;;) {
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
protected final boolean tryReleaseShared(int releases) {
for (;;) {
int current = getState();
int next = current + releases;
if (next < current) // overflow
throw new Error("Maximum permit count exceeded");
if (compareAndSetState(current, next))
return true;
}
}
final void reducePermits(int reductions) {
for (;;) {
int current = getState();
int next = current - reductions;
if (next > current) // underflow
throw new Error("Permit count underflow");
if (compareAndSetState(current, next))
return;
}
}
final int drainPermits() {
for (;;) {
int current = getState();
if (current == 0 || compareAndSetState(current, 0))
return current;
}
}
}
/**
* NonFair version
*/
static final class NonfairSync extends Sync {
private static final long serialVersionUID = -2694183684443567898L;
NonfairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
return nonfairTryAcquireShared(acquires);
}
}
/**
* Fair version
*/
static final class FairSync extends Sync {
private static final long serialVersionUID = 2014338818796000944L;
FairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
for (;;) {
if (hasQueuedPredecessors())
return -1;
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
}
AbstractQueuedSynchronizer的内容之前已经分析过了,先从Semaphore的acquire方法入手,
public void acquire(int permits) throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireSharedInterruptibly(permits);
}
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// 调用sync子类方法尝试获取,默认使用非公平策略
if (tryAcquireShared(arg) < 0)
// 如果获取失败则添加到阻塞队列,然后再次尝试,继续失败则调用part方法挂起当前线程
doAcquireSharedInterruptibly(arg);
}
// 非公平策略
protected int tryAcquireShared(int acquires) {
return nonfairTryAcquireShared(acquires);
}
final int nonfairTryAcquireShared(int acquires) {
for (;;) {
// 当前信号量
int available = getState();
// 剩余值
int remaining = available - acquires;
// 如果当前剩余值小于0或者CAS设置成功则返回
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
// 公平策略
protected int tryAcquireShared(int acquires) {
for (;;) {
// 查看当前线程的前驱节点是否也在等待获取资源
// 如果是则放弃获取并加入AQS阻塞队列,否则就去获取资源
if (hasQueuedPredecessors())
return -1;
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
public final boolean hasQueuedPredecessors() {
Node t = tail; // Read fields in reverse initialization order
Node h = head;
Node s;
return h != t &&
((s = h.next) == null || s.thread != Thread.currentThread());
}
acquire会调用tryAcquireShared,该方法在构造函数中存在公平与非公平两种策略,其中非公平策略下线程会直接尝试获取资源,而公平策略通过hasQueuedPredecessors节点的前节点是否也在等待获取资源,如果有前节点则放弃获取并加入阻塞队列,否则通过获取信号量,并计算差值,如果差值小于0或者CAS操作state成功时返回,从上面分析不难发现Semaphore也是支持回环的,每次调用acquire会更新信号量,相当于CyclicBarrier中将信号量重新初始化为备份过的初始值一样。接下来看release方法:
public void release() {
sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
// 尝试释放资源
if (tryReleaseShared(arg)) {
// 资源释放成功则调用part方法唤醒AQS队列里面最先挂起的线程
doReleaseShared();
return true;
}
return false;
}
protected final boolean tryReleaseShared(int releases) {
for (;;) {
// 当前信号量
int current = getState();
// 信号量增加release
int next = current + releases;
// releases<0的情况
if (next < current) // overflow
throw new Error("Maximum permit count exceeded");
// CAS更新信号量,保证原子性
if (compareAndSetState(current, next))
return true;
}
}
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!h.compareAndSetWaitStatus(0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}
从上面可以看到release方法是调用releaseShared方法并传参1,在releaseShared方法中调用tryReleaseShared方法来释放资源,通过CAS操作更新信号量,释放资源成功后,调用doReleaseShared方法唤醒AQS队列中最先挂起的线程,结束release。