ThreadLocal源码分析

1. 线性探测法

ThreadLocal的作用,简而言之,就是在多线程环境下,有些数据会被共享,ThreadLocal可以实现将共享数据的访问限制在当前线程中。
每个Thread中都有一个ThreadLocal.ThreadLocalMap的字段,ThreadLocalMap是基于线性探测法的散列表,它的键是ThreadLocal类型,值是Object类型,也就是说,一个Thread对应多个ThreadLocal,即一个线程中可创建多个ThreadLocal对象。
下面对线性探测法实现的散列表进行介绍:
(1)数据结构

public class LinerProbingHashST {
    private Entry table[];
    static class Entry{
        Object key;
        Object value;
        public Entry(Object key, Object value) {
            this.key = key;
            this.value = value;
        }
    }
    // 增、删、改、查、扩容下面会进行介绍
}

(2)增

添加条目的轨迹示例:



示例说明:

  • 读示例时逐行来看
  • 键、值都为黑的是新添加的条目
  • 键为黑、值为灰的是探测经过的轨迹
  • 键、值都为灰的表示未被访问到
  • 键为黑、值为红色的表示值被替换
  • 为了便于描述,后面会将一组连续的条目叫做键簇,如示例中的A~L就是一个键簇

基于线性探测法的散列表中,条目的个数要小于散列表的大小(即要小于数组table的大小),当条目个数达到某个值时会对散列表进行扩容(即对数组进行扩容),并重新计算哈希码,确定索引,将已有条目转移到新数组中。
(3)删
以(2)中已添加的条目为例,若要删除键为H的条目(H的散列值为4),会从索引为4(即键为A的条目)处开始向后探测,找到H后,将H处的条目删除,并继续将该键簇中H之后的条目向前移动(这里只有键为L的条目会向前移动)。
(4)查
以(2)中已添加的条目为例,若要查找键为P的条目,P的散列值为14,因此会从索引14(即键为R的条目)处开始向后探测,注意,探测到索引15处后会折回到0处继续探测,直到找到P或遇到null为止。

2. ThreadLocal源码分析

2.1 索引相关
    // 后面会通过threadLocalHashCode & (INITIAL_CAPACITY - 1)来计算索引
    // 注意nextHashCode方法只会调用一次,用来初始化threadLocalHashCode 
    private final int threadLocalHashCode = nextHashCode();
  
    // static类型,所有ThreadLocal都用这一个AtomicInteger 
    // AtomicInteger可实现原子性累加
    private static AtomicInteger nextHashCode = new AtomicInteger();

    // 0x61c88647与斐波那契数列有关,实践证明,
    // 通过0x61c88647的配合计算出的索引分布很均匀
    private static final int HASH_INCREMENT = 0x61c88647;

    // 计算下一个哈希码 
    private static int nextHashCode() {
        // HASH_INCREMENT是增量,getAndAdd中会用AtomicInteger.value+HASH_INCREMENT
        // 作为新值,并返回AtomicInteger.value(不过考虑了多线程,用的CAS+不断重试)
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }
2.2 ThreadLocal.ThreadLocalMap

ThreadLocalMap是基于线性探测法的哈希表,ThreadLocal的get、set等方法都会调用ThreadLocalMap的相关方法。ThreadLocalMap是ThreadLocal的核心,下面进行介绍:

2.2.1 ThreadLocal.ThreadLocalMap.Entry
    // Entry继承自WeakReference
    // Entry.k是弱引用,即Entry.k会在合适的时机被GC自动回收,若Entry.get返回null,
    // 表示Entry.k已被回收,ThreadLocalMap的相关方法中会以此作为依据进行相应的处理
    static class Entry extends WeakReference<ThreadLocal<?>> {
        Object value;
        // 注意k是ThreadLocal类型
        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }       
2.2.2 ThreadLocalMap中的字段
    // 初始容量
    private static final int INITIAL_CAPACITY = 16;

    // 底层数组,大小必须是2的幂(原因与HashMap中的一样,不再介绍) 
    private Entry[] table;

    // 数组中的元素个数
    private int size = 0;

    // 阈值 
    private int threshold;
2.2.3 ThreadLocalMap中的辅助方法

(1)setThreshold

    // 根据len设置阈值 
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }

(2)nextIndex和prevIndex

    // 取后一个索引
    // 该方法存在的目的就是使得当索引为i+1为len时,折回到0
    private static int nextIndex(int i, int len) {
        return ((i + 1 < len) ? i + 1 : 0);
    }

    // 取前一个索引
    // 该方法存在的目的就是使得当索引为i-1为-1时,折回到len-1 
    private static int prevIndex(int i, int len) {
        return ((i - 1 >= 0) ? i - 1 : len - 1);
    }

(3)expungeStaleEntry

    // 与线性探测法的删除方法类似
    private int expungeStaleEntry(int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;

        // 将过期条目的值、条目本身置为null(键已经为null了)
        tab[staleSlot].value = null;
        tab[staleSlot] = null;
        size--;
    
        Entry e;
        int i;
        // 对键簇中索引在staleSlot之后的元素进行移动
        for (i = nextIndex(staleSlot, len);
             (e = tab[i]) != null;
             i = nextIndex(i, len)) {
            ThreadLocal<?> k = e.get();
            if (k == null) { // 该条目的k被回收,将该条目也删除
                e.value = null;
                tab[i] = null;
                size--;
            } else {
                int h = k.threadLocalHashCode & (len - 1);
                if (h != i) { // 不相等时才进行移动
                    tab[i] = null; // 将原来的条目置为null(已经暂存在e中了)        
                    // 向后探测,寻找为null的位置来存放该条目                        
                    while (tab[h] != null) 
                        h = nextIndex(h, len);
                    // 跳出while说明找到了合适的位置,将e存放到tab[h]处
                    tab[h] = e;
                }
            }
        }
        return i; // 注意i是上面构造完后,最后一个条目的下一个位置的索引(该位置为null)
    }

(4)cleanSomeSlots

    // 尝试性清除一些过期条目
    // 若不清除会存在"垃圾",若线性清除又消耗性能,
    // cleanSomeSlots算是一种权衡,每次只清除一部分
    private boolean cleanSomeSlots(int i, int n) {
        boolean removed = false;
        Entry[] tab = table;
        int len = tab.length;
        do {
            i = nextIndex(i, len);
            Entry e = tab[i];
            // e.get()返回null说明该e的k已经被回收了
            if (e != null && e.get() == null) {
                n = len; // 重置n
                removed = true;
                // 删除过期条目并更新i
                i = expungeStaleEntry(i);
            }
        } while ( (n >>>= 1) != 0);
        return removed; // 返回值会作为是否rehash的一个条件
    }

2.2.4 构造器
    // 注意ThreadLocalMap的键是ThreadLocal类型
    ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        table = new Entry[INITIAL_CAPACITY]; // 初始化table
        // 计算firstKey对应的索引
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        // 将第一个条目存入table[i]处
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        // 设置阈值
        setThreshold(INITIAL_CAPACITY);
    }
2.2.5 添加相关

(1)set

    // 与线性探测法的添加方法类似 
    private void set(ThreadLocal<?> key, Object value) {

        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);
        
        // 遍历检测key对应的条目是否已存在或是否存在过期键
        for (Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();

            // key对应的条目存在
            if (k == key) {
                e.value = value;
                return;
            }

            // 找到过期条目
            if (k == null) {
                // 用key对应的条目替换过期条目
                // 注意 i 是k对应的过期条目的索引
                replaceStaleEntry(key, value, i);
                return;
            }
        }
        // 到这里说明key对应的条目不存在且没有过期条目
        // 新建条目并保存到键簇末尾
        tab[i] = new Entry(key, value);
        int sz = ++size;
        // 因为是添加条目,所以这里需判断是否进行rehash
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }

(2)replaceStaleEntry

        // 用key对应的条目替换当前键簇中的过期条目 
        // 方法中会不断更新slotToExpunge的值,目的很简单,
        // 就是想找到整个键簇中(注意是整个)的第一个过期条目
        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            int slotToExpunge = staleSlot;
            // 向staleSlot之前探测,寻找该键簇中,staleSlot之前是否存在过期条目
            // 注意:若staleSlot之前存在多个,slotToExpunge记录的是最前面过期条目的索引
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                // staleSlot之前存在过期条目
                if (e.get() == null)
                    // 记录过期条目的索引
                    slotToExpunge = i;

            // 向staleSlot之后探测,若存在key对应的条目,会进行替换,
            // 期间也会更新slotToExpunge,记录键簇中第一个过期条目的位置
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                // 存在key对应的条目
                if (k == key) {
                    // 更新value
                    e.value = value;
                    // 将tab[staleSlot]和tab[i]互换位置
                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;
                    // 该条件成立说明staleSlot之前没有过期条目
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i; // 更新slotToExpunge
                    // 清理过期条目
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }
                // k对应的条目条目过期 且 staleSlot之前没有过期条目
                if (k == null && slotToExpunge == staleSlot)
                    // 更新slotToExpunge(仅更新一次,后面slotToExpunge == staleSlot必定为false)
                    slotToExpunge = i; 
            }
            // 到这里说明键簇中不存在key对应的条目,新建一个,并存到tab[staleSlot]处
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);
            // 到这里若slotToExpunge与staleSlot相等,说明整个键簇中除了
            // staleSlot处(这里要存放key对应的条目)压根没有过期条目
            if (slotToExpunge != staleSlot)
                // 清理过期条目
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

假设ThreadLocalMap中的键为String类型,下面利用1.(2)中已添加的条目,举几个例子,对replaceStaleEntry中的所有情况进行介绍:
第1类:条目存在时的一些例子

  • 例1

    假设ThreadLocalMap中条目的初始状态如下:
    以添加(L,13)为例(注意L的散列值是6),替换后条目的状态如下(此时staleSlot为6、slotToExpunge为8):
  • 例2

    假设ThreadLocalMap中条目的初始状态如下:
    以添加(L,13)为例(注意L的散列值是6),替换后条目的状态如下(此时staleSlot为6、slotToExpunge为5):
  • 例3

    假设ThreadLocalMap中条目的初始状态如下:
    以添加(L,13)为例(注意L的散列值是6),替换后条目的状态如下(此时staleSlot为6、slotToExpunge为7):

第2类:条目不存在时的一些例子

  • 例4

    假设ThreadLocalMap中条目的初始状态如下:
    以添加(Z,13)为例(假设Z的散列值是6),替换后条目的状态如下(此时staleSlot为6、slotToExpunge为6):
  • 例5

    假设ThreadLocalMap中条目的初始状态如下:
    以添加(Z,13)为例(假设Z的散列值是6),替换后条目的状态如下(此时staleSlot为6、slotToExpunge为5):
  • 例6

    假设ThreadLocalMap中条目的初始状态如下:
    以添加(Z,13)为例(假设Z的散列值是6),替换后条目的状态如下(此时staleSlot为6、slotToExpunge为7):

    (3)rehash

    private void rehash() {
        // 删除所有过期条目
        expungeStaleEntries();

        // 条目个数超过某个值就进行扩容
        if (size >= threshold - threshold / 4)
            resize();
    }

(4)expungeStaleEntries

    // cleanSomeSlots是删除一部分过期条目,
    // 而expungeStaleEntries是删除所有过期条目 
    private void expungeStaleEntries() {
        Entry[] tab = table;
        int len = tab.length;
        // 遍历数组
        for (int j = 0; j < len; j++) {
            Entry e = tab[j];
            // 删除过期条目
            if (e != null && e.get() == null)
                expungeStaleEntry(j);
        }
    }

(5)resize

    // 将数组大小扩大为原来的2倍
    private void resize() {
        Entry[] oldTab = table;
        int oldLen = oldTab.length;
        int newLen = oldLen * 2; // 扩大2倍
        // 创建新数组
        Entry[] newTab = new Entry[newLen];
        int count = 0;
        
        // 将条目从oldTab转移到newTab
        for (int j = 0; j < oldLen; ++j) {
            Entry e = oldTab[j];
            if (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == null) { // 条目过期
                    e.value = null; // Help the GC
                } else {
                    // 计算新索引
                    int h = k.threadLocalHashCode & (newLen - 1);
                    // 向后探测,寻找第一个为null的位置
                    while (newTab[h] != null)
                        h = nextIndex(h, newLen);
                    newTab[h] = e; // 找到合适位置,进行存放
                    count++; // 记录条目个数
                }
            }
        }
        // 设置新阈值
        setThreshold(newLen);
        // 更新size
        size = count;
        table = newTab;
    }
2.2.6 查询相关

(1)getEntry

    // 获取key对应的条目 
    private Entry getEntry(ThreadLocal<?> key) {
        // 计算key对应的索引
        int i = key.threadLocalHashCode & (table.length - 1);
        // 尝试获取
        Entry e = table[i];
        // 因为Entry的k是弱引用,所以虽然e不为null,但e的键可能已经被回收,
        // 所以这里第二个条件为false可能是键不相等,也可能e的键已经被回收,e.get()返回null
        if (e != null && e.get() == key)
            return e;
        else 
            // 处理 e != null && e.get() != key的情况
            return getEntryAfterMiss(key, i, e);
    }

(2)getEntryAfterMiss

    // 与线性探测法的查找方法类似,不过考虑了k被回收的情况 
    private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
        Entry[] tab = table;
        int len = tab.length;
        
        // 向后探测,直到遇到null
        while (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == key) // 键相等,直接返回
                return e;
            if (k == null) // k被回收
                expungeStaleEntry(i);
            else
                // 获取下一个索引
                i = nextIndex(i, len);
            // 取下一个条目   
            e = tab[i];
        }
        return null;
    }
2.2.7 移除
    // 删除key对应的条目 
    private void remove(ThreadLocal<?> key) {
        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);
        for (Entry e = tab[i];
             e != null;
             e = tab[i = nextIndex(i, len)]) {
            if (e.get() == key) { // 定位到对应的键
                e.clear(); // 将e.k置为null
                expungeStaleEntry(i); // 清除该条目
                return;
            }
        }
    }
2.3 构造器
    public ThreadLocal() {
    }
2.4 ThreadLocal中的辅助方法

(1)getMap

    // 获取Thread.threadLocals 
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

(2)createMap

    // 创建ThreadLocalMap对象
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
2.5 查询相关

(1)get

    public T get() {
        // 获取当前线程对象
        Thread t = Thread.currentThread();
        // 获取当前线程对象对应的ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            // 获取键为this(即当前调用get方法的TreadLocal)的条目
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                // 获取条目的值
                T result = (T)e.value;
                return result;
            }
        }
        // map为null 或 e为null,进行初始化
        return setInitialValue();
    }

(2)setInitialValue

    private T setInitialValue() {
        // 获取初始值(initialValue是空实现,交由子类重写)
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            // 调用ThreadLocalMap.set添加新条目
            map.set(this, value);
        else
            // 创建ThreadLocalMap对象并添加键为t、值为value的新条目
            createMap(t, value);
        return value;
    }
2.6 添加条目
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            // 调用ThreadLocalMap.set,添加或更新键为this、值为value的条目
            map.set(this, value);
        else
            // 创建ThreadLocalMap对象并添加键为t、值为value的新条目
            createMap(t, value);
    }
2.7 移除条目
    // 移除条目 
    public void remove() {
        ThreadLocalMap m = getMap(Thread.currentThread());
        if (m != null)
            // 调用ThreadLocalMap.remove,移除键为this的条目
            m.remove(this);
    }

3. MyBatis中的ThreadLocal案例

这里仅演示ThreadLocal的使用形式,不介绍具体的业务逻辑:

3.1 ErrorContext 中的ThreadLocal
public class ErrorContext {

  // ThreadLocal变量,用来将一个ErrorContext与当前线程绑定
  private static final ThreadLocal<ErrorContext> LOCAL = new ThreadLocal<>();
  
  public static ErrorContext instance() {
    // 通过ThreadLocal.get获取与当前线程绑定的ErrorContext
    ErrorContext context = LOCAL.get();
    if (context == null) {
      context = new ErrorContext();
      // 通过ThreadLocal.set将context与当前线程绑定
      LOCAL.set(context);
    }
    return context;
  }
  
  // 其余(略)
  
} 
3.2 SqlSessionManager中的ThreadLocal
public class SqlSessionManager implements SqlSessionFactory, SqlSession {

  // ThreadLocal变量,用来将一个SqlSession与当前线程绑定
  private final ThreadLocal<SqlSession> localSqlSession = new ThreadLocal<>();

  public void startManagedSession() {
    // 通过ThreadLocal.set将SqlSession与当前线程绑定
    this.localSqlSession.set(openSession());
  }
  
  public Connection getConnection() {
    // 通过ThreadLocal.get获取与当前线程绑定的SqlSession
    final SqlSession sqlSession = localSqlSession.get();
    if (sqlSession == null) {
      throw new SqlSessionException("Error:  Cannot get connection.  No managed session is started.");
    }
    return sqlSession.getConnection();
  }
  
  // 其余(略)
  
}  

4. ThreadLocal的内存泄露问题

假设创建了一个名为tl的ThreadLocal对象,当tl为null时,就无法调用ThreadLocal.remove方法将ThreadLocalMap中键为tl的条目进行删除,从而发生了内存泄漏。可以类比HashMap,若HaspMap对象中键为null,通过HashMap对象仍然可以操作键为null的条目,而ThreadLocal这种组织形式就做不到。
解决方案是ThreadLocalMap中的键为弱引用,因此键会被自动回收,任何线程在调用ThreadLocal的set/get/remove方法时,都会对过期条目(通过键是否为null来判断键对应的条目是否过期)进行清理,从而减少了内存泄漏。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容