怎么搞:
state表示锁的状态,可以被获取,state>0表示锁正在被当前线程或者其他线程占用,获取锁采用cas将state设置为1,如果锁被当前线程占用,重入++state,如果获取不到锁,就将当前线程加入到等待队列中,在释放锁的时候,如果state=0,从队列中第一个等待的线程取获取锁。
实现如下:
public class MyLock implements Lock {
//同步器
public final Sync sync;
//搞一个构造方法
public MyLock() {
this.sync = new Sync();
}
static class Sync extends AbstractQueuedSynchronizer {
/*
锁状态定义
*/
public volatile int state = 0;
/*
state内存偏移地址
*/
private static final long stateOffset;
/*
引入魔法类,CAS操作
*/
private static final sun.misc.Unsafe UNSAFE;
/*
当前占用锁的线程
*/
private volatile Thread ownerThread = null;
/*
队列头
*/
public volatile Node head;
/*
队列尾
*/
public volatile Node tail;
/*
队列头偏移量
*/
private static final long headOffset;
/*
队列尾偏移量
*/
private static final long tailOffset;
public Sync() {
head = new Node(null);
tail = new Node(null);
head.next = tail;
}
static {
try {
Field filed = Unsafe.class.getDeclaredField("theUnsafe");
filed.setAccessible(true);
UNSAFE = (Unsafe) filed.get(null);
Class<?> o = Sync.class;
stateOffset = UNSAFE.objectFieldOffset(o.getDeclaredField("state"));
headOffset = UNSAFE.objectFieldOffset(o.getDeclaredField("head"));
tailOffset = UNSAFE.objectFieldOffset(o.getDeclaredField("tail"));
} catch (Exception e) {
throw new Error(e);
}
}
//尝试获取锁
@Override
protected boolean tryAcquire(int arg) {
//当前线程
Thread thread = Thread.currentThread();
//如果state是0,处于空闲,cas成功直接返回true
if (state == 0) {
if (compareAndSetStateV2(0, 1)) {
ownerThread = thread;
return true;
}
}
//如果当前线程占用锁了,++state重入
else if (thread == ownerThread) {
//此处不需要cas,当前线程锁占用,state不会被其他线程修改,不存在线程安全问题
++state;
return true;
}
//其他线程占用返回false
return false;
}
//只有内存中的值与预期值i相同的时候,更新值为arg,原子操作
private boolean compareAndSetStateV2(int i, int arg) {
return UNSAFE.compareAndSwapInt(this, stateOffset, i, arg);
}
//尝试释放锁
@Override
protected boolean tryRelease(int arg) {
//释放锁,通知队列中等待的线程获取锁
if (state == 0) {
ownerThread = null;
Node next = head.next;
if (null != next) {
next.nodeState = 1;
LockSupport.unpark(next.thread);
}
}
return true;
}
public void lock(int arg) {
//获取锁不成功,当前线程加入等待队列,等待被叫醒
if (!tryAcquire(arg)) {
addWaitQue();
}
}
//快速入队
private void addWaitQue() {
//创建节点
Node node = new Node(Thread.currentThread());
addWaitTail(node);
for (; ; ) {
Node headNext = head.next;
if (node.thread == headNext.thread && headNext.nodeState == 1) {
if (tryAcquire(1)) {
ownerThread = Thread.currentThread();
headNext.nodeState = 0;
head = headNext;
return;
}
}
LockSupport.park();
}
}
//添加队尾
private void addWaitTail(Node node) {
for (; ; ) {
Node last = tail;
//如果尾节点中没有保存线程,保存当前线程到尾节点中
if (null == last.thread) {
if (tail.casTailThread(null, Thread.currentThread())) {
break;
}
}
Node next = last.next;
//尾节点中有线程,该节点添加到尾节点之后,并设置该节点为新尾节点
if (null == next) {
if (tail.casNext(null, node)) {
tail = node;
break;
}
}
}
}
private void casTail(Node last, Node next) {
UNSAFE.compareAndSwapObject(this, tailOffset, last, next);
}
public void unlock() {
--state;
tryRelease(1);
}
}
/**
* 内部节点类
*/
static class Node {
/**
* 节点保存的线程
*/
public Thread thread;
/**
* 下一个节点
*/
public volatile Node next;
/**
* 节点的状态
*/
public volatile int nodeState;
/**
* 引入UNSAFE对象,目的是使用其cas方法
*/
private static final sun.misc.Unsafe UNSAFE;
/**
* 下一个节点的内存偏移地址
*/
private static final long nextOffset;
/**
* 保存的线程的内存偏移地址
*/
private static final long threadOffsset;
public Node(Thread thread) {
this.thread = thread;
}
static {
try {
Field f = Unsafe.class.getDeclaredField("theUnsafe");
f.setAccessible(true);
UNSAFE = (Unsafe) f.get(null);
Class<?> k = Node.class;
nextOffset = UNSAFE.objectFieldOffset
(k.getDeclaredField("next"));
threadOffsset = UNSAFE.objectFieldOffset
(k.getDeclaredField("thread"));
} catch (Exception e) {
throw new Error(e);
}
}
public boolean casNext(Object o, Node node) {
return UNSAFE.compareAndSwapObject(this, nextOffset, o, node);
}
public boolean casTailThread(Object o, Thread currentThread) {
return UNSAFE.compareAndSwapObject(this, threadOffsset, o, currentThread);
}
}
@Override
public void lock() {
sync.lock(1);
}
@Override
public void lockInterruptibly() throws InterruptedException {
}
@Override
public boolean tryLock() {
return false;
}
@Override
public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
return false;
}
@Override
public void unlock() {
sync.unlock();
}
@Override
public Condition newCondition() {
return null;
}
}