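1. Implementing a lock with CAS
In the abstract, a lock exists to keep a shared resource from being accessed concurrently. Following that idea, a lock can be implemented in Java as follows:
- Define a lock variable, state.
- When several threads access the same shared resource at the same time, CAS guarantees that only one of them succeeds in changing state, i.e. acquires the lock. The threads that fail keep spinning and retrying.
- When the owner is done with the shared resource, it restores state so that another thread can acquire the lock.
Define a lock interface: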
public interface Lock {
    void lock();
    void unlock();
}
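Following the principles above, the implementation is:
import java.util.concurrent.atomic.AtomicInteger;
public class SpinLock implements Lock {
    // 0 = free, 1 = held
    private final AtomicInteger state = new AtomicInteger();
    @Override
    public void lock() {
        boolean flag;
        do {
            // spin until the CAS from 0 to 1 succeeds
            flag = this.state.compareAndSet(0, 1);
        } while (!flag);
    }
    @Override
    public void unlock() {
        // restore state so another thread can acquire the lock
        state.compareAndSet(1, 0);
    }
}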
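Test
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
public class Main {
    static int value = 0;
    public static void main(String[] args) throws InterruptedException {
        SpinLock spinLock = new SpinLock();
        final CyclicBarrier cyclicBarrier = new CyclicBarrier(10);
        // wait for all workers to finish instead of sleeping for a fixed time
        final CountDownLatch countDownLatch = new CountDownLatch(10);
        for (int i = 0; i < 10; i++) {
            new Thread(new Runnable() {
                public void run() {
                    try {
                        // release all threads at once to maximise contention
                        cyclicBarrier.await();
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    spinLock.lock();
                    for (int j = 0; j < 100; j++) {
                        value++;
                    }
                    spinLock.unlock();
                    countDownLatch.countDown();
                }
            }).start();
        }
        countDownLatch.await();
        System.out.println("value: " + value);
    }
}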
Result:
value: 1000
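SpinLock is not reentrant. Suppose the thread that already holds it calls lock() again. compareAndSet(0, 1) fails because state is already 1, so the thread spins forever waiting for itself. Calling lock() twice would simply hang, so the sketch below performs the same two CAS steps directly on an AtomicInteger. The class name NonReentrantDemo is made up for illustration.

```java
import java.util.concurrent.atomic.AtomicInteger;

class NonReentrantDemo {
    // Same lock word as SpinLock: 0 = free, 1 = held.
    static final AtomicInteger state = new AtomicInteger();

    /** Returns {first CAS result, second CAS result}, both made by the same thread. */
    static boolean[] lockTwice() {
        boolean first = state.compareAndSet(0, 1);  // what the first lock() does
        boolean second = state.compareAndSet(0, 1); // a nested lock() would retry this forever
        state.set(0);                               // release so the demo leaves state clean
        return new boolean[] {first, second};
    }

    public static void main(String[] args) {
        boolean[] r = lockTwice();
        // prints "first CAS: true, second CAS: false"
        System.out.println("first CAS: " + r[0] + ", second CAS: " + r[1]);
    }
}
```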
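2. Making the lock reentrant
When the thread that already holds the lock acquires it again, we increment state by 1, so that state records how many times the lock is currently held. That is all reentrancy requires.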
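Here is a minimal sketch of this idea on a plain spin lock, before any refactoring. It makes two assumptions. The owner is tracked in an AtomicReference<Thread>, whereas the article uses AbstractOwnableSynchronizer for this below. The hold count lives in a plain field that only the owner touches. The class name ReentrantSpinLockSketch is hypothetical, and Thread.onSpinWait() requires Java 9+.

```java
import java.util.concurrent.atomic.AtomicReference;

class ReentrantSpinLockSketch {
    // The owning thread, or null when the lock is free.
    private final AtomicReference<Thread> owner = new AtomicReference<>();
    // How many times the owner currently holds the lock; only the owner reads or writes it.
    private int holds;

    void lock() {
        Thread current = Thread.currentThread();
        if (owner.get() == current) { // re-entry by the owner: just bump the count
            holds++;
            return;
        }
        while (!owner.compareAndSet(null, current)) {
            Thread.onSpinWait(); // spin-wait hint (Java 9+)
        }
        holds = 1;
    }

    void unlock() {
        if (owner.get() != Thread.currentThread()) {
            throw new IllegalMonitorStateException();
        }
        if (--holds == 0) {
            owner.set(null); // volatile write frees the lock and publishes holds
        }
    }

    /** Hold count of the calling thread (0 if it does not own the lock). */
    int getHoldCount() {
        return owner.get() == Thread.currentThread() ? holds : 0;
    }
}
```

The count does not need to be atomic. Only the owning thread changes it, and the volatile write in owner.set(null) publishes it to the next owner.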
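To make the later discussion of ReentrantLock easier, we refactor the code. We define an abstract class, CustomAbstractQueuedSynchronizer, that extends AbstractOwnableSynchronizer. AbstractOwnableSynchronizer is an abstract class provided by the JDK (in java.util.concurrent.locks) for setting and getting the thread that currently owns the lock; its source is shown below. For convenience, state becomes a volatile int field that is manipulated through Unsafe.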
// For reference only: this is the JDK's java.util.concurrent.locks.AbstractOwnableSynchronizer,
// so there is no need to copy it into your own project.
public abstract class AbstractOwnableSynchronizer
        implements java.io.Serializable {
    private static final long serialVersionUID = 3737899427754241961L;
    protected AbstractOwnableSynchronizer() { }
    private transient Thread exclusiveOwnerThread;
    protected final void setExclusiveOwnerThread(Thread thread) {
        exclusiveOwnerThread = thread;
    }
    protected final Thread getExclusiveOwnerThread() {
        return exclusiveOwnerThread;
    }
}
import java.lang.reflect.Field;
import java.util.concurrent.locks.AbstractOwnableSynchronizer;
import sun.misc.Unsafe;
public abstract class CustomAbstractQueuedSynchronizer extends AbstractOwnableSynchronizer {
    /**
     * The synchronization state.
     */
    private volatile int state;
    private static final Unsafe unsafe;
    // memory offset of the state field, used by compareAndSwapInt
    private static final long stateOffset;
    static {
        try {
            // Unsafe is not meant for application code, so read its private singleton reflectively
            Field field =
                    Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            unsafe = (Unsafe) field.get(null);
            stateOffset = unsafe.objectFieldOffset
                    (CustomAbstractQueuedSynchronizer.class.getDeclaredField("state"));
        } catch (Exception ex) {
            throw new Error(ex);
        }
    }
    protected final int getState() {
        return state;
    }
    protected final void setState(int newState) {
        state = newState;
    }
    // atomically sets state to update if it currently equals expect
    protected final boolean compareAndSetState(int expect, int update) {
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }
}
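sun.misc.Unsafe is an internal, unsupported API, which is why the code above has to fetch it reflectively. On Java 9 and later, java.lang.invoke.VarHandle offers the same CAS on a volatile int field through a public API. Below is a sketch of the state handling rewritten that way. StateHolder is an illustrative name and is not part of the article's code.

```java
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;

class StateHolder {
    private volatile int state;

    // Handle to the 'state' field, resolved once when the class is initialised.
    private static final VarHandle STATE;
    static {
        try {
            STATE = MethodHandles.lookup()
                    .findVarHandle(StateHolder.class, "state", int.class);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    int getState() {
        return state;
    }

    void setState(int newState) {
        state = newState;
    }

    // Atomically sets state to update if it currently equals expect.
    boolean compareAndSetState(int expect, int update) {
        return STATE.compareAndSet(this, expect, update);
    }
}
```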
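The reentrant lock is implemented as follows.
The logic is simple:
- When a thread acquires the lock, it calls setExclusiveOwnerThread to record itself as the owner.
- When the CAS fails, we check whether the current thread is the owner acquiring the lock again; if so, state is incremented by 1.
- On release, state is decremented by 1, and once it reaches 0 the owner is cleared.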
public class SpinReentrantLock implements Lock {
    private Sync sync;
    public SpinReentrantLock() {
        sync = new SimpleNonfairSync();
    }
    abstract static class Sync extends CustomAbstractQueuedSynchronizer {
        protected abstract void lock();
        protected abstract void unlock();
    }
    static final class SimpleNonfairSync extends Sync {
        @Override
        protected void lock() {
            boolean flag;
            do {
                Thread current = Thread.currentThread();
                if (flag = compareAndSetState(0, 1)) {
                    // first acquisition: record the owner
                    //System.out.println(current.getName() + " acquired the lock");
                    setExclusiveOwnerThread(current);
                } else if (current == getExclusiveOwnerThread()) {
                    // re-entry by the owner: increment the hold count
                    int c = getState();
                    int nextc = c + 1;
                    if (nextc < 0) {
                        // overflow
                        throw new Error("Maximum lock count exceeded");
                    }
                    //System.out.println(current.getName() + " re-entered, state: " + nextc);
                    setState(nextc);
                    flag = true;
                }
                // otherwise another thread holds the lock: spin and retry
            } while (!flag);
        }
        @Override
        protected void unlock() {
            int c = getState() - 1;
            if (Thread.currentThread() != getExclusiveOwnerThread())
                throw new IllegalMonitorStateException();
            if (c == 0) {
                // fully released: clear the owner before publishing state = 0
                setExclusiveOwnerThread(null);
            }
            // System.out.println(Thread.currentThread().getName() + " released, state: " + c);
            setState(c);
        }
    }
    @Override
    public void lock() {
        sync.lock();
    }
    @Override
    public void unlock() {
        sync.unlock();
    }
}
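For comparison, java.util.concurrent.locks.ReentrantLock counts re-entries the same way and exposes the count through getHoldCount(). The short trace below shows it. HoldCountDemo is an illustrative name.

```java
import java.util.concurrent.locks.ReentrantLock;

class HoldCountDemo {
    /** Records the hold count after each step: lock, lock, unlock, unlock. */
    static int[] trace() {
        ReentrantLock lock = new ReentrantLock();
        int[] counts = new int[4];
        lock.lock();
        counts[0] = lock.getHoldCount(); // first acquisition
        lock.lock();
        counts[1] = lock.getHoldCount(); // re-entry increments the count
        lock.unlock();
        counts[2] = lock.getHoldCount(); // still held once
        lock.unlock();
        counts[3] = lock.getHoldCount(); // fully released
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(java.util.Arrays.toString(trace())); // [1, 2, 1, 0]
    }
}
```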
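3. Adding a wait queue
Under high contention, the many failed CAS attempts can make SpinReentrantLock inefficient, and spinning wastes CPU time. So when a thread fails to acquire the lock, we put it in a queue and park it. When the lock is released, a parked thread is woken up.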
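Parking and waking are done with java.util.concurrent.locks.LockSupport. Each thread has at most one permit. unpark(t) makes it available, and park() consumes it or blocks until it becomes available, so an unpark() that happens before the park() is not lost. park() may also return spuriously, so callers re-check their condition in a loop. The small demonstration below shows both behaviours. ParkDemo is an illustrative name.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.LockSupport;

class ParkDemo {
    /** Parks a worker until the calling thread sets a flag and unparks it. */
    static boolean wakeWorker() throws InterruptedException {
        AtomicBoolean released = new AtomicBoolean(false);
        AtomicBoolean workerSawRelease = new AtomicBoolean(false);
        Thread worker = new Thread(() -> {
            // park() may return spuriously, so re-check the condition in a loop
            while (!released.get()) {
                LockSupport.park();
            }
            workerSawRelease.set(true);
        });
        worker.start();
        released.set(true);         // publish the condition first...
        LockSupport.unpark(worker); // ...then give the worker its permit
        worker.join();
        return workerSawRelease.get();
    }

    /** unpark() before park(): the saved permit makes park() return at once. */
    static void permitIsSaved() {
        LockSupport.unpark(Thread.currentThread());
        LockSupport.park(); // returns immediately, consuming the permit
    }
}
```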
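We extend the abstract class CustomAbstractQueuedSynchronizer in three ways:
- A thread-safe queue, threadQueue (a ConcurrentLinkedQueue), holds the parked threads.
- A head field records the head of the queue.
- An acquire method uses the template method pattern: the actual acquisition logic, tryAcquire, is left to subclasses. A thread that fails to acquire the lock is parked with LockSupport.park(this). Once the thread at the head of the queue acquires the lock, it removes itself from the queue and head is updated.

The full code is: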
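import java.lang.reflect.Field;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.locks.AbstractOwnableSynchronizer;
import java.util.concurrent.locks.LockSupport;
import sun.misc.Unsafe;
public abstract class CustomAbstractQueuedSynchronizer extends AbstractOwnableSynchronizer {
    /**
     * The synchronization state.
     */
    private volatile int state;
    private static final Unsafe unsafe;
    private static final long stateOffset;
    // thread at the head of the queue (bookkeeping; acquire() checks the queue itself)
    private transient volatile Thread head;
    // parked threads waiting for the lock, in arrival order
    protected Queue<Thread> threadQueue = new ConcurrentLinkedQueue<>();
    static {
        try {
            Field field =
                    Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            unsafe = (Unsafe) field.get(null);
            stateOffset = unsafe.objectFieldOffset
                    (CustomAbstractQueuedSynchronizer.class.getDeclaredField("state"));
        } catch (Exception ex) {
            throw new Error(ex);
        }
    }
    protected final int getState() {
        return state;
    }
    protected final void setState(int newState) {
        state = newState;
    }
    // atomically sets state to update if it currently equals expect
    protected final boolean compareAndSetState(int expect, int update) {
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }
    public Thread getHead() {
        return head;
    }
    public void setHead(Thread head) {
        this.head = head;
    }
    /**
     * Acquisition logic, implemented by subclasses.
     * @param arg the number of acquisitions (1 for a plain lock)
     * @return true if the current thread now holds the lock
     */
    protected boolean tryAcquire(int arg) {
        throw new UnsupportedOperationException();
    }
    /**
     * Returns true if a thread other than the current one is first in the queue,
     * i.e. a fair lock must not let the current thread jump ahead of it.
     */
    public final boolean hasQueuedPredecessors() {
        Thread first = threadQueue.peek();
        return first != null && first != Thread.currentThread();
    }
    /**
     * Release logic, implemented by subclasses.
     * @param arg the number of releases (1 for a plain unlock)
     * @return true if the lock is now completely free
     */
    protected boolean tryRelease(int arg) {
        throw new UnsupportedOperationException();
    }
    /**
     * Acquires the lock: try once, otherwise enqueue the thread, park it,
     * and retry each time it is woken up.
     * @param arg the number of acquisitions
     */
    public final void acquire(int arg) {
        Thread current = Thread.currentThread();
        // tryAcquire failed: put the thread in the queue
        if (!tryAcquire(arg) && threadQueue.add(current)) {
            if (getHead() == null) {
                setHead(threadQueue.peek());
            }
            // leave the loop only once the lock has been acquired
            for (; ; ) {
                // Compare with the live queue head, not the cached head field: a release
                // racing with a new enqueue can leave that field stale or null, and a
                // waiter trusting it could park with nobody left to wake it.
                if (current == threadQueue.peek() && tryAcquire(arg)) {
                    // dequeue ourselves
                    threadQueue.poll();
                    // the next waiter (if any) becomes the head
                    setHead(threadQueue.peek());
                    return;
                }
                // System.out.println("Parking thread: " + current.getName() + " threadQueue: " + Arrays.toString(threadQueue.toArray()));
                // could not acquire: park until unpark() (or a spurious wake-up)
                LockSupport.park(this);
            }
        }
    }
}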
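Sync's unlock method works as follows:
- Override tryRelease so that it returns true, meaning the lock has been fully released, when state drops to 0.
- If the lock was released, get the head of the queue with threadQueue.peek() and wake that thread with LockSupport.unpark(poll).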
abstract static class Sync extends CustomAbstractQueuedSynchronizer {
    protected abstract void lock();
    protected void unlock() {
        // only a full release (state back to 0) wakes a waiter
        if (tryRelease(1)) {
            Thread poll = threadQueue.peek();
            if (poll != null) {
                //System.out.println(Thread.currentThread().getName() + " unparking: " + poll.getName() + " size: " + threadQueue.size());
                LockSupport.unpark(poll);
            } else {
                setHead(null);
            }
        }
    }
    @Override
    protected boolean tryRelease(int arg) {
        int c = getState() - 1;
        // only the owner may release the lock
        if (Thread.currentThread() != getExclusiveOwnerThread()) {
            throw new IllegalMonitorStateException();
        }
        boolean free = false;
        if (c == 0) {
            free = true;
            setExclusiveOwnerThread(null);
        }
        // System.out.println(Thread.currentThread().getName() + " released, state: " + c);
        // the volatile write publishes the release to the next owner
        setState(c);
        return free;
    }
}
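The compact, self-contained sketch below gathers the enqueue / park / unpark protocol in one place. It simplifies under stated assumptions:
- QueuedLockSketch is a hypothetical name.
- It is non-reentrant and non-fair.
- It does not check that unlock() is called by the owner.
- It uses an AtomicInteger instead of Unsafe.

A waiter decides whether it may retry by reading the live queue head with peek(), so there is no separate head field to keep in sync.

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;

class QueuedLockSketch {
    private final AtomicInteger state = new AtomicInteger(); // 0 = free, 1 = held
    private final Queue<Thread> waiters = new ConcurrentLinkedQueue<>();

    void lock() {
        if (state.compareAndSet(0, 1)) {
            return; // fast path: the lock was free
        }
        Thread current = Thread.currentThread();
        waiters.add(current);
        for (;;) {
            // Only the thread at the front of the queue competes for the lock.
            if (waiters.peek() == current && state.compareAndSet(0, 1)) {
                waiters.poll(); // remove ourselves; the next waiter becomes the head
                return;
            }
            LockSupport.park(this); // may return spuriously; the loop re-checks
        }
    }

    void unlock() {
        state.set(0); // release first, then wake the head so it can retry
        Thread head = waiters.peek();
        if (head != null) {
            LockSupport.unpark(head);
        }
    }
}
```

The order in unlock() matters. state is cleared before the queue is inspected. A waiter that enqueues too late to be seen by peek() therefore finds the lock free when it retries its CAS. If a barging thread grabs it first, that thread's unlock() wakes the head.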
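The NonfairSync class is shown below.
As mentioned above, acquire follows the template method pattern and delegates the acquisition logic to tryAcquire. This tryAcquire is non-fair: a thread that arrives while the lock happens to be free can take it with a CAS even if other threads are already waiting in the queue.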
static final class NonfairSync extends Sync {
    @Override
    protected void lock() {
        Thread current = Thread.currentThread();
        // barge: take the lock immediately if it happens to be free
        if (compareAndSetState(0, 1)) {
            // System.out.println(current.getName() + " acquired the lock");
            setExclusiveOwnerThread(current);
        } else {
            acquire(1);
        }
    }
    @Override
    protected boolean tryAcquire(int arg) {
        return nonfairTryAcquire(arg);
    }
    final boolean nonfairTryAcquire(int acquires) {
        final Thread current = Thread.currentThread();
        int c = getState();
        if (c == 0) {
            // lock is free: try to take it, ignoring any queued threads
            if (compareAndSetState(0, acquires)) {
                // System.out.println(current.getName() + " acquired the lock");
                setExclusiveOwnerThread(current);
                return true;
            }
        } else if (current == getExclusiveOwnerThread()) {
            // re-entry by the owner
            int nextc = c + acquires;
            if (nextc < 0) // overflow
                throw new Error("Maximum lock count exceeded");
            // System.out.println(current.getName() + " re-entered, state: " + nextc);
            setState(nextc);
            return true;
        }
        return false;
    }
}
Test
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
public class Main {
    static int value = 0;
    public static void main(String[] args) throws InterruptedException {
        // false selects NonfairSync, the lock described in this section
        SpinReentrantLock spinReentrantLock = new SpinReentrantLock(false);
        final CyclicBarrier cyclicBarrier = new CyclicBarrier(1000);
        final CountDownLatch countDownLatch = new CountDownLatch(1000);
        long start = System.currentTimeMillis();
        for (int i = 0; i < 1000; i++) {
            new Thread(new Runnable() {
                public void run() {
                    try {
                        // release all threads at once to maximise contention
                        cyclicBarrier.await();
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    spinReentrantLock.lock();
                    // System.out.println(Thread.currentThread().getName() + " acquired the lock");
                    for (int j = 0; j < 1000; j++) {
                        value++;
                    }
                    spinReentrantLock.unlock();
                    countDownLatch.countDown();
                }
            }, "thread:" + i).start();
        }
        // wait for all workers to finish
        countDownLatch.await();
        long end = System.currentTimeMillis();
        System.out.println("Elapsed time (ms): " + (end - start));
        System.out.println("value: " + value);
    }
}
Result:
Elapsed time (ms): 70
value: 1000000
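4. Fair lock
With a fair lock, threads already waiting in the queue go first, and newly arriving threads have to queue up behind them.
Compared with the non-fair lock, the differences are small. lock() goes straight to acquire(1) without the fast-path CAS. Before attempting the CAS, tryAcquire calls hasQueuedPredecessors to check whether another thread is already waiting ahead of the current one. The code is: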
static final class FairSync extends Sync {
    @Override
    protected void lock() {
        acquire(1);
    }
    protected final boolean tryAcquire(int acquires) {
        final Thread current = Thread.currentThread();
        int c = getState();
        if (c == 0) {
            if (!hasQueuedPredecessors() &&
                    compareAndSetState(0, acquires)) {
                setExclusiveOwnerThread(current);
                return true;
            }
        } else if (current == getExclusiveOwnerThread()) {
            int nextc = c + acquires;
            if (nextc < 0)
                throw new Error("Maximum lock count exceeded");
            setState(nextc);
            return true;
        }
        return false;
    }
}
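The fairness check has to answer one question: is some other thread ahead of me in the queue? The JDK's AbstractQueuedSynchronizer.hasQueuedPredecessors() answers the same question, namely whether any thread has been waiting longer than the current one. A plain emptiness test is not enough, because the thread at the head of the queue would itself see a non-empty queue. Below is a sketch of the predicate over a ConcurrentLinkedQueue<Thread>. The Predecessors class is hypothetical.

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

class Predecessors {
    /** True if a thread other than the caller is first in line. */
    static boolean hasQueuedPredecessors(Queue<Thread> queue) {
        Thread first = queue.peek();
        return first != null && first != Thread.currentThread();
    }

    public static void main(String[] args) {
        Queue<Thread> queue = new ConcurrentLinkedQueue<>();
        System.out.println(hasQueuedPredecessors(queue)); // false: nobody is waiting
        queue.add(Thread.currentThread());
        System.out.println(hasQueuedPredecessors(queue)); // false: the caller is first
    }
}
```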
Test
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
public class Main {
    static int value = 0;
    public static void main(String[] args) throws InterruptedException {
        // true selects FairSync
        SpinReentrantLock spinReentrantLock = new SpinReentrantLock(true);
        final CyclicBarrier cyclicBarrier = new CyclicBarrier(10);
        final CountDownLatch countDownLatch = new CountDownLatch(10);
        long start = System.currentTimeMillis();
        for (int i = 0; i < 10; i++) {
            new Thread(new Runnable() {
                public void run() {
                    try {
                        cyclicBarrier.await();
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    spinReentrantLock.lock();
                    // System.out.println(Thread.currentThread().getName() + " acquired the lock");
                    for (int j = 0; j < 1000; j++) {
                        value++;
                    }
                    spinReentrantLock.unlock();
                    countDownLatch.countDown();
                }
            }, "thread:" + i).start();
        }
        countDownLatch.await();
        long end = System.currentTimeMillis();
        System.out.println("Elapsed time (ms): " + (end - start));
        System.out.println("value: " + value);
    }
}
Result: the threads acquire the lock in the order in which they were enqueued.
thread:0 acquired the lock
Parking thread: thread:6 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main]]
Parking thread: thread:5 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main]]
Parking thread: thread:1 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
Parking thread: thread:2 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
Parking thread: thread:3 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
Parking thread: thread:4 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main]]
Parking thread: thread:7 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main], Thread[thread:7,5,main]]
Parking thread: thread:9 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main], Thread[thread:7,5,main], Thread[thread:8,5,main]]
Parking thread: thread:8 threadQueue: [Thread[thread:0,5,main], Thread[thread:9,5,main], Thread[thread:1,5,main], Thread[thread:2,5,main], Thread[thread:3,5,main], Thread[thread:4,5,main], Thread[thread:5,5,main], Thread[thread:6,5,main], Thread[thread:7,5,main], Thread[thread:8,5,main]]
thread:9 acquired the lock
thread:1 acquired the lock
thread:2 acquired the lock
thread:3 acquired the lock
thread:4 acquired the lock
thread:5 acquired the lock
thread:6 acquired the lock
thread:7 acquired the lock
thread:8 acquired the lock
Elapsed time (ms): 3
value: 10000
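The JDK's fair ReentrantLock (new ReentrantLock(true)) gives the same first-come, first-served order. The sketch below checks this deterministically:
- It holds the lock in the calling thread.
- It starts the waiters one at a time, waiting until each shows up in hasQueuedThread(t) before starting the next, so the queue order is fixed.
- It then releases the lock and records the order in which the waiters acquire it.

FairOrderDemo is an illustrative name.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.locks.ReentrantLock;

class FairOrderDemo {
    /** Returns the order in which n waiters acquired a fair lock. */
    static List<Integer> acquisitionOrder(int n) throws InterruptedException {
        ReentrantLock lock = new ReentrantLock(true); // fair
        List<Integer> order = Collections.synchronizedList(new ArrayList<>());
        List<Thread> threads = new ArrayList<>();
        lock.lock(); // hold the lock so that every waiter has to queue
        try {
            for (int i = 0; i < n; i++) {
                final int id = i;
                Thread t = new Thread(() -> {
                    lock.lock();
                    try {
                        order.add(id);
                    } finally {
                        lock.unlock();
                    }
                });
                threads.add(t);
                t.start();
                // wait until this waiter is queued before starting the next one
                while (!lock.hasQueuedThread(t)) {
                    Thread.onSpinWait();
                }
            }
        } finally {
            lock.unlock(); // the waiters now acquire the lock one after another
        }
        for (Thread t : threads) {
            t.join();
        }
        return order;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(acquisitionOrder(5)); // [0, 1, 2, 3, 4]
    }
}
```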
5. Summary
Finally, here is the complete implementation of SpinReentrantLock.
import java.util.concurrent.locks.LockSupport;
public class SpinReentrantLock implements Lock {
    private Sync sync;
    // non-fair by default
    public SpinReentrantLock() {
        sync = new NonfairSync();
    }
    public SpinReentrantLock(boolean fair) {
        if (fair) {
            sync = new FairSync();
        } else {
            sync = new NonfairSync();
        }
    }
    // Fair: lock() always goes through acquire(), and tryAcquire() yields to queued threads.
    static final class FairSync extends Sync {
        @Override
        protected void lock() {
            acquire(1);
        }
        @Override
        protected final boolean tryAcquire(int acquires) {
            final Thread current = Thread.currentThread();
            int c = getState();
            if (c == 0) {
                if (!hasQueuedPredecessors() &&
                        compareAndSetState(0, acquires)) {
                    setExclusiveOwnerThread(current);
                    return true;
                }
            } else if (current == getExclusiveOwnerThread()) {
                int nextc = c + acquires;
                if (nextc < 0)
                    throw new Error("Maximum lock count exceeded");
                setState(nextc);
                return true;
            }
            return false;
        }
    }
    abstract static class Sync extends CustomAbstractQueuedSynchronizer {
        protected abstract void lock();
        protected void unlock() {
            // only a full release (state back to 0) wakes a waiter
            if (tryRelease(1)) {
                Thread poll = threadQueue.peek();
                if (poll != null) {
                    LockSupport.unpark(poll);
                } else {
                    setHead(null);
                }
            }
        }
        @Override
        protected boolean tryRelease(int arg) {
            int c = getState() - 1;
            if (Thread.currentThread() != getExclusiveOwnerThread()) {
                throw new IllegalMonitorStateException();
            }
            boolean free = false;
            if (c == 0) {
                free = true;
                setExclusiveOwnerThread(null);
            }
            // System.out.println(Thread.currentThread().getName() + " released, state: " + c);
            setState(c);
            return free;
        }
    }
    // Non-fair: try a CAS first and fall back to the queue only if it fails.
    static final class NonfairSync extends Sync {
        @Override
        protected void lock() {
            Thread current = Thread.currentThread();
            if (compareAndSetState(0, 1)) {
                setExclusiveOwnerThread(current);
            } else {
                acquire(1);
            }
        }
        @Override
        protected boolean tryAcquire(int arg) {
            return nonfairTryAcquire(arg);
        }
        final boolean nonfairTryAcquire(int acquires) {
            final Thread current = Thread.currentThread();
            int c = getState();
            if (c == 0) {
                if (compareAndSetState(0, acquires)) {
                    setExclusiveOwnerThread(current);
                    return true;
                }
            } else if (current == getExclusiveOwnerThread()) {
                int nextc = c + acquires;
                if (nextc < 0) // overflow
                    throw new Error("Maximum lock count exceeded");
                setState(nextc);
                return true;
            }
            return false;
        }
    }
    // Spin-only variant from section 2 (no queue), kept for comparison.
    static final class SimpleNonfairSync extends Sync {
        @Override
        protected void lock() {
            boolean flag;
            do {
                Thread current = Thread.currentThread();
                if (flag = compareAndSetState(0, 1)) {
                    // System.out.println(current.getName() + " acquired the lock");
                    setExclusiveOwnerThread(current);
                } else if (current == getExclusiveOwnerThread()) {
                    int c = getState();
                    int nextc = c + 1;
                    if (nextc < 0) {
                        // overflow
                        throw new Error("Maximum lock count exceeded");
                    }
                    // System.out.println(current.getName() + " re-entered, state: " + nextc);
                    setState(nextc);
                    flag = true;
                }
            } while (!flag);
        }
        @Override
        protected void unlock() {
            int c = getState() - 1;
            if (Thread.currentThread() != getExclusiveOwnerThread())
                throw new IllegalMonitorStateException();
            if (c == 0) {
                setExclusiveOwnerThread(null);
            }
            setState(c);
        }
    }
    @Override
    public void lock() {
        sync.lock();
    }
    @Override
    public void unlock() {
        sync.unlock();
    }
}
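One usage note applies to SpinReentrantLock just as much as to JDK locks. The tests above call unlock() on the normal path only, so an exception inside the critical section would leave the lock held. The usual idiom acquires before try and releases in finally. It is shown here against the JDK's java.util.concurrent.locks.Lock, since the article's own Lock interface has the same lock()/unlock() shape. LockIdiom.withLock is a hypothetical helper.

```java
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Supplier;

class LockIdiom {
    /** Runs action while holding lock; the lock is released even if action throws. */
    static <T> T withLock(Lock lock, Supplier<T> action) {
        lock.lock(); // acquire outside try: if lock() itself fails, there is nothing to unlock
        try {
            return action.get();
        } finally {
            lock.unlock();
        }
    }

    public static void main(String[] args) {
        Lock lock = new ReentrantLock();
        int answer = withLock(lock, () -> 6 * 7);
        System.out.println(answer); // 42
    }
}
```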
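The lock implemented above is still fairly basic. For example, it does not yet support responding to interrupts or waiting with a timeout. These features are not hard to add, so we will not go into them here.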
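For reference, the JDK's ReentrantLock exposes both features. tryLock(long, TimeUnit) gives up after a timeout, and lockInterruptibly() throws InterruptedException when the waiting thread is interrupted. The small check below exercises both. TimeoutDemo is an illustrative name. Adding these features to SpinReentrantLock would presumably mean parking with LockSupport.parkNanos() against a deadline, checking Thread.interrupted() after each wake-up, and removing the thread from threadQueue when it gives up.

```java
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;

class TimeoutDemo {
    /** Another thread tries for 50 ms while the caller holds the lock. */
    static boolean tryLockWhileHeld() throws InterruptedException {
        ReentrantLock lock = new ReentrantLock();
        AtomicBoolean acquired = new AtomicBoolean(true);
        lock.lock();
        try {
            Thread t = new Thread(() -> {
                try {
                    boolean ok = lock.tryLock(50, TimeUnit.MILLISECONDS);
                    acquired.set(ok);
                    if (ok) {
                        lock.unlock();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            t.start();
            t.join(); // returns once the 50 ms timeout has expired
        } finally {
            lock.unlock();
        }
        return acquired.get(); // false: the waiter timed out
    }

    /** A thread blocked in lockInterruptibly() is interrupted while waiting. */
    static boolean interruptWaiter() throws InterruptedException {
        ReentrantLock lock = new ReentrantLock();
        AtomicBoolean gotInterrupted = new AtomicBoolean(false);
        lock.lock();
        try {
            Thread t = new Thread(() -> {
                try {
                    lock.lockInterruptibly();
                    lock.unlock(); // not reached while the caller holds the lock
                } catch (InterruptedException e) {
                    gotInterrupted.set(true);
                }
            });
            t.start();
            t.interrupt();
            t.join();
        } finally {
            lock.unlock();
        }
        return gotInterrupted.get(); // true
    }
}
```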
In the next section we look at AQS, the foundation of the JDK's concurrency utilities.