分析一下ConcurrentHashMap的成员变量:
//默认ConcurrentHashMap的大小
static final int DEFAULT_INITIAL_CAPACITY = 16;
//默认加载因子
static final float DEFAULT_LOAD_FACTOR = 0.75f;
//默认支持的最大并发数
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
//ConcurrentHashMap的最大容量
static final int MAXIMUM_CAPACITY = 1 << 30;
//segment分段锁的初始容量也是最小容量(即segment中HashEntry的初始容量)
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
//最大segment数(segments数组的最大长度)
static final int MAX_SEGMENTS = 1 << 16; // slightly conservative
//重试加锁次数
static final int RETRIES_BEFORE_LOCK = 2;
//分段锁的掩码,用来计算key所在的segment在segments的数组下标
final int segmentMask;
//分段锁偏移量,用来查找segment在内存中的位置
final int segmentShift;
//segment数组
final Segment<K,V>[] segments;
总结:ConcurrentHashMap中包含一个segment数组。
分析一下Segment数组中segment对象:
static final class Segment<K,V> extends ReentrantLock implements Serializable {
private static final long serialVersionUID = 2249069246763182397L;
//自旋最大次数
static final int MAX_SCAN_RETRIES =
Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
//HashEntry数组
transient volatile HashEntry<K,V>[] table;
//包含的元素总数
transient int count;
//修改的次数
transient int modCount;
//元素的阀值
transient int threshold;
//加载因子
final float loadFactor;
}
总结:segment中包含一个HashEntry数组。
分析一下HashEntry数组中的HashEntry对象:
static final class HashEntry<K,V> {
final int hash;
final K key;
volatile V value;
volatile HashEntry<K,V> next;
HashEntry(int hash, K key, V value, HashEntry<K,V> next) {
this.hash = hash;
this.key = key;
this.value = value;
this.next = next;
}
}
整体结构如下图所示:
内部结构看完了,看一下ConcurrentHashMap是如何初始化的:
//假如new ConcurrentHashMap<String,Object>()创建
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
//保证最大并发不超过MAX_SEGMENTS(1 << 16)
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
// Find power-of-two sizes best matching arguments
int sshift = 0;
int ssize = 1;
//循环判断保证ssize是2的幂(即Segment数组的长度)
//循环完sshift = 4,ssize = 16
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;//ssize = size << 1
}
//segmentShift最后为16
this.segmentShift = 32 - sshift;
//segmentMask最后为15
this.segmentMask = ssize - 1;
//ConcurrentHashMap初始容量不超过MAXIMUM_CAPACITY(1 << 30)
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
//根据ConcurrentHashMap总容量initialCapacity除以Segments[]数组的长度得到单个分段锁segment中HashEntry[]的大小
int c = initialCapacity / ssize;
//保证分段锁segment的总容量c不小于初始的容量
if (c * ssize < initialCapacity)
++c;
//cap为Segments[]数组中分段锁segment的HashEntry[]的大小,保证为2的幂
int cap = MIN_SEGMENT_TABLE_CAPACITY;
while (cap < c)
cap <<= 1;
// create segments and segments[0]
//记住这里★
//创建一个s0,然后初始化到Segments[0]中
Segment<K,V> s0 =
new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
(HashEntry<K,V>[])new HashEntry[cap]);
//创建Segments[]数组
Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
this.segments = ss;
}
总结:这里就是根据用户创建new ConcurrentHashMap(....)时,传递的值或者默认值进行ConcurrentHashMap的初始化。创建一个Segments[]数组,最大数组长度是16,然后再初始化Segments[0]位置的值。
初始化完,就是往map中放入数据,看一下是如何放的:
public V put(K key, V value) {
Segment<K,V> s;
//value不能为空
if (value == null)
throw new NullPointerException();
//计算key的HASH值
int hash = hash(key);
//无符号右移segmentShift位(默认16),然后 & segmentMask(默认15)得到segment在内存中的地址
int j = (hash >>> segmentShift) & segmentMask;
if ((s = (Segment<K,V>)UNSAFE.getObject
(segments, (j << SSHIFT) + SBASE)) == null) //如果获取到的segment为null
s = ensureSegment(j);//初始化segment
//放值
return s.put(key, hash, value, false);
}
总结:根据hash值获取分段锁segment的内存地址,如果获取到的segment为null,则初始化。否则就是放值。
看一下初始化segment的方法s = ensureSegment(j):
private Segment<K,V> ensureSegment(int k) {
//拿到Segments[]数组
final Segment<K,V>[] ss = this.segments;
//获取k所在的segment在内存中的偏移量
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
//获取k所在的segmen,判断segmen是否为null
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
//初始化一个segment等于Segments[0]
//Segments[0]在初始化ConcurrentHashMap时,已经初始化了一个segment放到Segments[0],用★标识的地方。
Segment<K,V> proto = ss[0]; // use segment 0 as prototype
//然后就是获取Segments[0]中HashEntry数组的数据
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
//初始化一个HashEntry数组,大小和Segments[0]中的HashEntry一样。
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
//再次获取k所在的segment(防止其他线程已经初始化好)
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // recheck
//如果还是null,创建一个segment并通过cas设置到对应的位置
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) {
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}
总结:根据k的hash值,获取segment,如何获取不到则就初始化一个和Segment[0]一样大小的segment。并通过CAS操作,初始化到Segments[]中。
获取到key所在的segment之后,就是调用s.put(key, hash, value, false)方法:
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
//尝试获取segment的锁
//失败就通过自旋去获取锁,超过自旋最大次数时,就将操作放入到Lock锁的队列中
HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);
V oldValue;
//走到这里的线程一定获取到锁,没获取到的都放到了Lock的队列中
try {
//拿到segment中的HashEntry数组
HashEntry<K,V>[] tab = table;
//得到key所在HashEntry数组的下标
int index = (tab.length - 1) & hash;
//获取key所在的HashEntry数组某个位置的头结点(HashEntry是一个数组加链表结构)
HashEntry<K,V> first = entryAt(tab, index);
for (HashEntry<K,V> e = first;;) {
if (e != null) {//HashEntry头节点不为null
K k;
//找到传入的值
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
//是否替换value,默认替换(有个putIfAbsent(key,value)就是不替换)
if (!onlyIfAbsent) {
e.value = value;
++modCount;
}
break;
}
e = e.next;
} else {//如果没有找到或HashEntry头节点为null
if (node != null)//判断node是否已经初始化(自旋获取锁做的这些操作,这个node要么为null要么就是新建的HashEntry)
node.setNext(first);//头插法
else
node = new HashEntry<K,V>(hash, key, value, first);//头插法
int c = count + 1;//修改segment的元素总数
if (c > threshold && tab.length < MAXIMUM_CAPACITY)//如果超过segment的阀值并且segment没有超过最大容量,rehash
rehash(node);
else
setEntryAt(tab, index, node);//更新HashEntry数组
++modCount;
count = c;
oldValue = null;
break;
}
}
}
finally {
unlock();//释放锁
}
return oldValue;
}
总结:首先是尝试获取segment的锁,获取到向下执行,获取不到就通过自旋操作去获取锁(下面说自旋操作scanAndLockForPut(key, hash, value))。拿到锁之后,找到k所在的HashEntry数组的下标,然后获取头节点。向下遍历头结点,查找到就更新(默认),没查找到就新建一个HashEntry,通过头插法放入HashEntry数组,最后更新HashEntry。
put方法首先要获取segment的锁,获取失败就去通过自旋的方式再次尝试获取锁scanAndLockForPut(key, hash, value):
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
//获取k所在的segment中的HashEntry的头节点(segment中放得是HashEntry数组,HashEntry又是个链表结构)
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
int retries = -1; // negative while locating node
while (!tryLock()) {//尝试获取k所在segment的锁。成功就直接返回、失败进入while循环进行自旋尝试获取锁
HashEntry<K,V> f; // to recheck first below
if (retries < 0) {
if (e == null) {//所在HashEntry链表不存在,则根据传过来的key-value创建一个HashEntry
if (node == null) // speculatively create node
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
} else if (key.equals(e.key))//找到要放得值,则设置segment重试次数为0
retries = 0;
else //从头节点往下寻找key对应的HashEntry
e = e.next;
} else if (++retries > MAX_SCAN_RETRIES) {//超过最大重试次数就将当前操作放入到Lock的队列中
lock();
break;
} else if ((retries & 1) == 0 &&
(f = entryForHash(this, hash)) != first) {//如果retries为偶数,就重新获取HashEntry链表的头结点
e = first = f; // re-traverse if entry changed
retries = -1;
}
}
return node;
}
总结:拿到k所在的segment的HashEntry的头节点(想想segment中的数据结构),首先尝试获取segment的锁。
1.1、获取失败
1.1.1、
1.1.1.1、如果头节点是为null,则将传进来的key-value新建一个HashEntry,同时设置retries为0,从而再次去尝试获取segment的锁
1.1.1.2、如果头节点不为null,并且头节点的key == 传进来的key,设置retries为0,从而再次去尝试获取segment的锁
1.1.1.3、如果头节点不为null,并且头节点的key != 传进来的key 则获取头节点的下一个节点,再次尝试获取segment的锁
1.1.2、如果retries > MAX_SCAN_RETRIES 则调用reentrantlock的lock方法,将当前操作放入到lock队列中.跳出while循环
1.1.3、如果retries为偶数,则再次获取对应segment的头节点,判断是否有变化,有就重新获取头结点
1.2、获取成功
1.2.1、直接返回null
如果插入数据,就会使当前segment达到阀值,则判断是否符合rehash的条件,符合就进行rehash(node):
private void rehash(HashEntry<K,V> node) {//node为待新加入的节点
//获取当前segment的HashEntry数组
HashEntry<K,V>[] oldTable = table;
//获取原HashEntry数组的长度
int oldCapacity = oldTable.length;
//新HashEntry数组的长度 = 原HashEntry数组长度*2
int newCapacity = oldCapacity << 1;
//重新计算新HashEntry数组的阀值
threshold = (int)(newCapacity * loadFactor);
//创建新HashEntry数组
HashEntry<K,V>[] newTable =
(HashEntry<K,V>[]) new HashEntry[newCapacity];
//新HashEntry数组的掩码(用来计算元素在新数组中的下标位置)
int sizeMask = newCapacity - 1;
//遍历原HashEntry数组中的元素(segment中是一个HashEntry数组,结构是数组加链表)
for (int i = 0; i < oldCapacity ; i++) {
//获取原HashEntry[i]位置的头节点HashEntry
HashEntry<K,V> e = oldTable[i];
//判断当前节点是否为空
if (e != null) {
//获取当前节点的下一个节点
HashEntry<K,V> next = e.next;
//重新计算当前节点在新HashEntry数组中的下标
int idx = e.hash & sizeMask;
//如果下一个节点为空,则原HashEntry数组这个位置就一个元素,直接放到新HashEntry数组就行
if (next == null) // Single node on list
newTable[idx] = e;
//下个节点不为空
else { // Reuse consecutive sequence at same slot
//这里就是找到某个节点,从这个节点往下,都会在新数组的某个坐标下,形成新的链表
//原:
//HashEntry:[0] [1]...
//里面的值: 1
// 2
// 3
// 4
// 5
// 6
//新:
//HashEntry:[0] [1]...
//里面的值: 4
// 5
// 6
//这里最后lastRun就是4 ,idx就是新数组的下标0
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
for (HashEntry<K,V> last = next;
last != null;
last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
//然后把4节点放到新数组中,这样4,5,6就都过去了
newTable[lastIdx] = lastRun;
// Clone remaining nodes
//剩下的就是解决1,2,3了,遍历原HashEntry数组,直到等于4为止
//这些节点采用头插法放到新HashEntry数组中
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];//头插法
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
//将待新加入的元素放到新数组(头插法)
int nodeIndex = node.hash & sizeMask; // add the new node
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
//segment的数组指到新数组
table = newTable;
}
总结:原HashEntry数组扩大一倍,然后遍历原HashEntry数组,找到某个节点lastRun,从这个节点开始往下,都会放到新HashEntry数组的某个槽下面。接下来就是遍历剩下的数据,然后采用头插法,放到新数组中。最后就是将待新加人的节点,放到新数组中。
放数据搞定,接下来就是拿数据:
public V get(Object key) {
Segment<K,V> s; // manually integrate access methods to reduce overhead
HashEntry<K,V>[] tab;
//key的hash值
int h = hash(key);
//计算出k所在的segment所在的内存地址
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
//通过CAS操作获取segment,判断都不为null
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
//再通过CAS操作,获取的key所在的HashEntry[]数组的下标,获取到就返回,没有就返回null
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}
总结:get操作比较简单,就是一路CAS操作,遍历拿值。
放数据,拿数据都搞定,然后就是删数据:
public V remove(Object key) {
int hash = hash(key);
//key的hash值,获取segment
Segment<K,V> s = segmentForHash(hash);
return s == null ? null : s.remove(key, hash, null);
}
final V remove(Object key, int hash, Object value) {
//尝试获取segment锁
if (!tryLock())
//失败就自旋
scanAndLock(key, hash);
//执行到这里一定获取了segment锁
V oldValue = null;
try {
//获取segment的HashEntry[]数组
//获取key所在的下标,然后获取头结点
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
HashEntry<K,V> e = entryAt(tab, index);
HashEntry<K,V> pred = null;
while (e != null) {
K k;
HashEntry<K,V> next = e.next;
//判断当前节点e是不是要删除的数据
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {//key相同
V v = e.value;
//如果没有传value,或者value相同
if (value == null || value == v || value.equals(v)) {
if (pred == null)//要删除节点的前一个节点为null
setEntryAt(tab, index, next);//直接赋值
else
pred.setNext(next);//直接将当前节点的前一个节点的next设置成当前节点的下一个节点。
++modCount;
--count;
oldValue = v;
}
break;
}
//前节点不是要找的节点,遍历下一个
pred = e;
e = next;
}
}
finally {
unlock();
}
return oldValue;
}
总结:找到key所在HashEntry[]数组的下标,然后遍历链表,找到节点。将当前节点的上一个节点的next设置成当前节点的下一个节点。
最后看一下是如何获取ConcurrentHashMap的size的:
public int size() {
// Try a few times to get accurate count. On failure due to
// continuous async changes in table, resort to locking.
final Segment<K,V>[] segments = this.segments;
int size;
boolean overflow; // true if size overflows 32 bits
long sum; // sum of modCounts
long last = 0L; // previous sum
int retries = -1; // first iteration isn't retry
try {
for (;;) {
//判断retries是否等于RETRIES_BEFORE_LOCK(值为2)
//也就是默认有两次的机会,是不加锁来求size的
if (retries++ == RETRIES_BEFORE_LOCK) {
for (int j = 0; j < segments.length; ++j)
ensureSegment(j).lock(); // force creation
}
sum = 0L;
size = 0;
overflow = false;
//遍历Segments[]数组获取里面的每一个segment,然后对modCount进行求和
//这个for嵌套在for(;;)中,默认会执行两次,如果两次值相同,就返回
//如果两次值不同,就进入到上面的if中,进行加锁。之后在进行求和
for (int j = 0; j < segments.length; ++j) {
Segment<K,V> seg = segmentAt(segments, j);
if (seg != null) {
sum += seg.modCount;
int c = seg.count;
if (c < 0 || (size += c) < 0)
overflow = true;
}
}
if (sum == last)
break;
last = sum;
}
}
finally {
if (retries > RETRIES_BEFORE_LOCK) {
for (int j = 0; j < segments.length; ++j)
segmentAt(segments, j).unlock();
}
}
return overflow ? Integer.MAX_VALUE : size;
}