最近一段时间在看一本书《Java并发编程的艺术》,在P164讲到了关于ConcurrentLinkedQueue的源码分析,但是这部分源码非常复杂,于是我又顺手看了一下IDEA的Java源码,发现在Java8中,该部分的源码已经被更新过了,正好读一读顺带做个笔记。
基本介绍
ConcurrentLinkedQueue是一个列表实现,包括一个head和tail引用,该类的初始化过程中,头尾引用都被初始化成一个空的Node,下面我们可以看到相关代码:
public class ConcurrentLinkedQueue<E> extends AbstractQueue<E>
implements Queue<E>, java.io.Serializable {
private static class Node<E> {
volatile E item;
volatile Node<E> next;
}
private transient volatile Node<E> head;
private transient volatile Node<E> tail;
public ConcurrentLinkedQueue() {
head = tail = new Node<E>(null);
}
}
入队流程
单线程下的入队流程为:
- 将新节点加入到tail引用的next中
- 将新节点赋值给tail引用
但是在多线程环境中,需要保障其他线程入队和出队不受影响,ConcurrentLinkedQueue由CAS算法实现了无锁入队,下面是加入节点的关键代码:
public boolean offer(E e) {
checkNotNull(e);
final Node<E> newNode = new Node<E>(e);
// 循环开始,p和t都指向tail,q指向tail的next
for (Node<E> t = tail, p = t;;) {
Node<E> q = p.next;
if (q == null) {
// q为null代表目前tail后面没有其他线程插入的节点,即p确实是最后的节点
if (p.casNext(null, newNode)) {
// 这里casNext函数的作用是当p的next节点为null时,用newNode更新p的next节点,更新成功返回true
// 如果casNext更新成功,证明newNode已经成功插入到队尾
if (p != t)
// 这一步判断表明,t即tail已经不是真正的队尾引用,这是减少cas操作的一步优化
// 这里casTail函数的作用是当tail与t相等时,用newNode更新tail,在这里CAS失败也没有关系
casTail(t, newNode);
return true;
}
// 如果casNext更新失败,则重新将p的next赋值给q
}
else if (p == q)
// 当p==q只有一种情况,即p==p.next,在这种情况下就表明当前节点已经离队,因为在出队操作之后,ConcurrentLinkedQueue会将出队节点的next设为它本身
// 在遇到当前节点已经是出队节点的情况下,表明当前节点已经在head之前,因此根据如下逻辑进行更新当前节点:1、如果tail已经更新,那么将当前节点设为tail;2、否则,将当前节点设为head,因为不能保证tail指向的节点是否已经离队
p = (t != (t = tail)) ? t : head;
else
// 当tail更新且p不在tail时,用tail更新p,否则用q更新p
p = (p != t && t != (t = tail)) ? t : q;
}
}
如果觉得上述方法过于复杂,我们可以用一种更简单的方案来进行结果相同的操作:
public boolean offer(E e) {
checkNotNull(e);
final Node<E> newNode = new Node<E>(e);
for (; ; ) {
Node<E> t = tail;
if (t.casNext(null, newNode)) {
// 参照单线程的入队流程,casNext成功表明newNode已经成功插入到了队列里
// 如果casTail失败了也没有关系,失败了证明有其他的线程在进行casTail,至少有一根线程可以成功
casTail(t, newNode);
return true;
}
}
}
而在JDK源码中,加入了一步优化,这步优化是:在插入一个新节点时,不着急将tail指向这个新节点,而是在插入第二个新节点的时候,才对tail进行cas操作。
这样做会导致两个问题:
- tail并不在保持原有的一定指向队尾的性质;
- 从tail开始需要进过几步查找next才能寻找到真正的队尾;
但是这样做有一个好处:减少了至少一半的cas操作,虽然增加了普通的赋值操作,但是在多线程情况下cas操作的耗时要远远大于一般赋值操作的耗时,因此这部分优化可以增大该容器类的并发量。而剩下部分的判断都是为了在进行这一步优化的情况下,保证程序的正确性所做的。
出队流程
单线程情况下的出队流程为:
- 如果head==tail,证明队列为空,返回null
- 将队首元素的值取出,作为返回值
- 将head指向head.next
如果按照这种思路,我们可以直接写出一个简单写法的无锁出队方案:
public E poll() {
for (; ; ) {
Node<E> h = head;
if (h.next == null) {
return null;
} else {
if (casHead(h, h.next)) {
if (h.next != null)
return h.next.item;
}
}
}
}
我们再来看JDK源码中的poll函数实现,在这个poll函数中,使用了和offer函数中类似的优化方式,在出队的时候并不着急更新head的值,而是缓慢更新,然后用一部分操作来保证出队的正确性:
public E poll() {
restartFromHead:
for (; ; ) {
for (Node<E> h = head, p = h, q; ; ) {
E item = p.item;
if (item != null && p.casItem(item, null)) {
if (p != h)
updateHead(h, ((q = p.next) != null) ? q : p);
return item;
} else if ((q = p.next) == null) {
updateHead(h, p);
return null;
} else if (p == q)
continue restartFromHead;
else
p = q;
}
}
}
性能测试
这里不光是性能测试,同样有针对上述两种简单的无锁入队和出队的正确性测试。我分别开了2根入队线程和2根出队线程,每根入队线程循环入队1000W的数据,下面展示了测试结果(因为我的电脑是4核i5,比较弱鸡,如果线程开多了那么大量的时间都在线程切换上,测试结果就不准确了):
使用JDK源码
Test Started: 11:15 25:839
Get thread finished, Total: 10809006
Get thread finished, Total: 9190994
Test Finished: 11:15 31:487
Total Time Cost: 5s 648ms
使用自定义的offer函数
Test Started: 11:17 36:963
Get thread finished, Total: 9335745
Get thread finished, Total: 10664255
Test Finished: 11:17 41:627
Total Time Cost: 4s 664ms
使用自定义的poll函数
Test Started: 11:18 17:412
Get thread finished, Total: 9714954
Get thread finished, Total: 10285046
Test Finished: 11:18 21:669
Total Time Cost: 4s 257ms
同时使用自定义的offer和poll函数
Test Started: 11:18 51:663
Get thread finished, Total: 10219132
Get thread finished, Total: 9780868
Test Finished: 11:18 56:602
Total Time Cost: 4s 939ms
有点尴尬的是好像优化过的源码是跑的最慢的,应该和我只有2根读写线程有关,争抢的情况比较少,争抢情况越严重,线程越多,源码的速度应该是更快的。如果谁有更好的机器可以拿代码试一下,下面是我的测试代码:
public class TestQueue {
private static int TOTAL_COUNT = 10000000;
private static int TOTAL_WRITE = 2;
private static int TOTAL_READ = 2;
private static SimpleDateFormat DATE_FORMAT = new SimpleDateFormat("HH:mm ss:SSS");
public static void main(String[] args) {
AtomicInteger flag = new AtomicInteger(0);
ConcurrentHashMap<Integer, AtomicInteger> total = new ConcurrentHashMap<>(TOTAL_COUNT);
for (int i = 0; i != TOTAL_COUNT; i++) {
total.put(i, new AtomicInteger(0));
}
CustomQueue<Integer> customQueue = new CustomQueue<>();
ExecutorService executor = Executors.newCachedThreadPool();
Date startTime = new Date();
System.out.println("Test Started: " + DATE_FORMAT.format(startTime));
for (int i = 0; i != TOTAL_WRITE; i++) {
executor.execute(new Runnable() {
@Override
public void run() {
for (int i = 0; i != TOTAL_COUNT; i++) {
customQueue.add(i);
}
}
});
}
for (int i = 0; i != TOTAL_READ; i++) {
executor.execute(new Runnable() {
@Override
public void run() {
int sum = 0;
while (flag.get() != TOTAL_WRITE * TOTAL_COUNT) {
Integer num = customQueue.poll();
if (num != null) {
sum++;
flag.incrementAndGet();
total.get(num).incrementAndGet();
}
}
System.out.println("Get thread finished, Total: " + sum);
}
});
}
executor.shutdown();
try {
executor.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS);
Date endTime = new Date();
long totalTime = endTime.getTime() - startTime.getTime();
for (int i = 0; i != TOTAL_COUNT; i++) {
if (total.get(i).get() != TOTAL_WRITE) {
System.out.println("Test Failed: " + i + " " + total.get(i));
break;
}
}
System.out.println("Test Finished: " + DATE_FORMAT.format(endTime));
System.out.printf("Total Time Cost: %ds %dms", totalTime / 1000, totalTime % 1000);
} catch (InterruptedException e) {
System.out.println("Failure: " + flag.get());
e.printStackTrace();
}
}
}