ThreadLocal源码分析(JDK 1.8)

  " 要保证线程安全,也并非一定要进行阻塞或非阻塞同步,同步与线程安全两者没有必然的联系。同步只是保障存在共享数据争用时正确性的手段,如果能让一个方法本来就不涉及共享数据,那它自然就不需要任何同步措施去保证其正确性,因此会有一些代码天生就是线程安全的 " —— 《深入理解JVM虚拟机》。
  如果需要让一个变量只让一个线程独享,可以使用 ThreadLocal 实现线程本地存储的功能,使得每个线程都保存一份该变量的副本,每个线程都只对该副本进行操作,所以不会存在并发安全问题。

  使用方式如下:

ThreadLocal<String> threadLocal = new ThreadLocal<>();
threadLocal.set(str);
threadLocal.get();

  每个 Thread 对象都持有一个 ThreadLocalMap 对象,线程和变量副本的值的对应关系存储在 ThreadLocalMap 对象中的,ThreadLocalMap 是 ThreadLocal 的一个静态内部类,代码如下所示:

 static class ThreadLocalMap {
        // 每个 ThreadLocal 和绑定的value都会封装成 Entry 对象
        static class Entry extends WeakReference<ThreadLocal<?>> {     
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        private static final int INITIAL_CAPACITY = 16;
        // 存放线程中多个ThreadLocal和对应值的对应关系
        private Entry[] table;

        private int size = 0;

        private int threshold; 
}
  1. ThreadLocal.set()
      ThreadLocalMap将 ThreadLocal和对应的值封装成 Entry对象存放在 table数组中。
    // 将 ThreadLocal 和 value 的对应关系存储在 Thread对象的threadLocals属性中
    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }
    // 返回线程持有的 ThreadLocalMap 对象
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    void createMap(Thread t, T firstValue) {
        t.threadLocals  = new ThreadLocalMap(this, firstValue);
    }

   ThreadLocalMap 对象将ThreadLocal和绑定值的对应关系封装成 Entry 对象,存放在该对象的 table数组中,由此可以引出几个问题。
   (1)table数组何时初始化?
   (2)如何解决哈希冲突?
   (3)达到阀值如何扩容?
  
   在调用 ThreadLocal.set()方法时,如果线程对应的 threadLocals属性没有初始化,则会调用 createMap()方法初始化线程threadLocals的属性,代码如下所示:

        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            // 默认容量是 16 
            table = new Entry[INITIAL_CAPACITY];
            // 根据ThreadLocal的hashCode和数组容量计算需要Entry存放的位置
            // 因为容量是2的n次方,所以实际上是取模  
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

   如果已经该属性已经初始化,则将对应关系添加到数组中。

private void set(ThreadLocal<?> key, Object value) {

            Entry[] tab = table;
            int len = tab.length;
            //1. 计算存放的下标位置
            int i = key.threadLocalHashCode & (len-1);
            //2. 遍历数组,向后寻找合适的位置
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                 //2.1 获取该节点的key : ThreadLocal
                ThreadLocal<?> k = e.get();
                //2.1 如果该节点的key和新的key相等,则覆盖旧值
                if (k == key) {
                    e.value = value;
                    return;
                }
                //2.3 如果k为null,则添加新值,并去除脏值
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            //3. 如果 i 出为null,则创建 Entry放在该处
            tab[i] = new Entry(key, value);
            int sz = ++size;
            //4. 清除脏值并且扩容
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

  当两个ThreadLocal实例的 hashCode恰好相同时,则会出现哈希冲突,ThreadLocalMap是通过线性探测法解决哈希冲突的。线性探测法就是查找散列表中离冲突单元最近的空闲单元,并且把新的键插入这个空闲单元。如上述代码的第二步中,通过遍历数组决定存放位置,如果出现了哈希冲突,则接着向后遍历寻找合适的位置,如果到达数组的末尾,将会从下标0开始,呈环形重新遍历。
  当数组的元素数量已经超过阀值时,将会引起扩容,代码如下:

       private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            //1. 二倍扩容
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;
            //2. 遍历旧数组,将元素放到新的位置
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    //2.1 如果该位置的 Entry对象对应的key为null,说明该ThreadLocal可能已经被垃圾回收出现了脏值
                    if (k == null) {
                        e.value = null; // Help the GC
                    //2.2 计算元素新位置,并使用线性探测法解决可能出现的冲突,然后将元素放置到新位置中    
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }
            //3. 设置新的阀值,更新table指向和元素数量size
            setThreshold(newLen);
            size = count;
            table = newTab;
        }
  1. ThreadLocal.get()
      由于每个线程都持有一个 ThreadLocalMap 对象,所以当调用ThreadLocal.get()方法时,会获取当前线程持有的 ThreadLocalMap 对象,然后以当前 ThreadLocal 为key,从该对象中获取对应的 Entry对象,然后返回对应的值。
    public T get() {
        //1. 首先获取当前线程的 threadLocals 属性
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        //2. 然后使用当前的 ThreadLocal对象作为key从 ThreadLocalMap对象中寻找
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        //3. 如果该线程的 threadLocals 属性还没有初始化,则初始化该属性   
        return setInitialValue();
    }
    private T setInitialValue() {
        T value = null;
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }
        ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

  通过 key 在 ThreadLocalMap中寻找 Entry 对象的过程如下所示:

        private Entry getEntry(ThreadLocal<?> key) {
            //1. 计算该key对应的下标
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            //2. 如果对应下标的 Entry不为空,且key相同,则返回该 Entry对象
            if (e != null && e.get() == key)
                return e;
            //3. 如果key不同,说明可能由于哈希冲突导致该key对应的 Entry对象放到了其他位置 
            else
                return getEntryAfterMiss(key, i, e);
        }
        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;
            //如果对应位置的 Entry为空,说明没有没有对应的值,直接返回null
            while (e != null) {
                ThreadLocal<?> k = e.get();
                //1. 如果找到对应的key,返回 entry 对象
                if (k == key)
                    return e;
                //2. 如果entry不为空,但是 key为空,则清除脏数据   
                if (k == null)
                    expungeStaleEntry(i);
                //3. 环形i++    
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }
  1. ThreadLocal.remove()
      该方法是将线程中该ThreadLocal绑定的值移除。
     public void remove() {
         //1. 获得当前线程绑定的ThreadLocalMap
         ThreadLocalMap m = getMap(Thread.currentThread());
         //2. 移除 ThreadLocal对应的节点
         if (m != null)
             m.remove(this);
     }
     
     private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            //1. 获得 key 对应的下标
            int i = key.threadLocalHashCode & (len-1);
            //2. 从i 开始遍历(哈希冲突导致存放位置修改)
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                //2.1 如果寻找到了对应的key,则将key置为null,然后开始清除脏值 
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }
  1. ThreadLocal内存泄露
      之前一直说脏值,那么脏值是如何产生的呢?我们都知道 ThreadLocal和值的对应关系会封装成 Entry对象,而该Entry对象继承了WeakReference类,当新建 Entry对象时,会将调用 super(k)方法,表明当前 ThreadLocal 对象是一个弱引用的对象,被弱引用关联的对象只能生存到下一次垃圾收集发生为止,所以当我们调用clear()方法或者手动将该 ThreadLocal 对象置为null时该ThreadLocal对象将会被回收,因此会存在对应 Entry 对象的key为空的情况。虽然该entry对象的key为空,但是还存在一条 thread-> value的引用链导致value无法被回收,而我们无法使用任何方式访问对应的value,导致内存泄露。
        static class Entry extends WeakReference<ThreadLocal<?>> {
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

  4.1 expungeStaleEntry()
  为了防止内存泄露,ThreadLocal 提供了一个方法用来清除脏值,该方法会遍历table 数组直到遇到 entry为null停止,在遍历过程中如果遇到了 Entry的key为空的情况,则有可能遇到脏数据,手动将对应的value置为空,如果当前位置的 Entry没有脏数据,就 rehash 重新计算存放位置。

        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            // 将该脏值对应的下标所在位置的value置为null
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;
            
            Entry e;
            int i;
            // 环形遍历去除脏值并且做 rehash操作
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                //1. 如果存在脏值,将value值为Null
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                //2. rehash操作    
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;

                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            // 返回 staleSlot下标 之后第一个为空的entry对象对应的数组下标
            return i;
        }

  4.2 cleanSomeSlots()
  在添加新元素时,会调用该方法尝试清除一些脏值。扫描过程是通过 n 控制扫描的长度。在扫描过程中,如果没有遇到脏数据会扫描 log2len次,如果遇到了脏数据会使 n = len,延长扫描的长度,所以该方法可能会造成添加元素的复杂度变为O(n)。

        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                //遇到了脏数据
                if (e != null && e.get() == null) {
                    //延长扫描长度
                    n = len;
                    removed = true;
                    //进行清除
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

  4.3 replaceStaleEntry()
  在添加元素时,如果在寻找新元素的存放位置的过程中扫描了脏值,则会调用该方法去除脏数据,并且添加新值。

        // key,value是新添加的键值对; staleSlot是出现脏数据的下标
        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            //1. 向前循环遍历table直到遇到 null,slotToExpunge 为出现脏值的下标,如果没出现脏值则值为 staleSlot
            int slotToExpunge = staleSlot;
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            //2. 向后遍历table直到遇到entry为null
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                //2.1 如果找到了新元素所放位置,也就是存在覆盖的旧key 
                if (k == key) {
                    //覆盖旧值,并替换table数组中 i 和 staleSlot位置的entry对象
                    e.value = value;
                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;
                    // 如果这两个值相等,则表名向前没有扫描到脏数据,则从i开始扫描进行清除脏值
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }
                //2.2 如果 k 为null且是 staleSlot之后的第一个脏值,则更新slotToExpunge 的值
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            //3 如果遇到了null还是没有遇到存放位置,则将新元素存放到null的位置
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            //4 如果这两个值不相等,所以存在脏数据,则从slotToExpunge下标开始清除
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

  该方法主要的作用就是确定新键值对的位置和清除旧值时开始扫描的下标:
  (1)在添加元素时,计算新元素的下标之后,如果出现哈希冲突,则会向后环形遍历确定插入位置,如果遍历过程中出现了脏值(下标为staleSlot),则调用该方法。
  (2)首先确定扫描脏值的开始下标 slotToExpunge,该值就是第一步出现脏值的下标,然后向前遍历,如果遇到脏值,则更新slotToExpunge的值,直到遇到Null则遍历结束。(这里我觉得向前遍历的开始坐标可以使用key.threadLocalHashCode & (len-1)计算,因为该表达式的值到staleSlot下标中间的元素是不会出现脏值和null的)

第1.2步

  (3)然后向后遍历寻找可以覆盖的键值对(key相等)直到遇到 null。
  无论哪种情况目的都是为了清除脏值时减少扫描数组的长度,因为无论哪种情况staleSlot下标位置都会放入新元素构成的Entry对象,不是原来的脏值。如果扫描到了覆盖key,创建新Entry对象并和staleSlot位置的Entry对象交换;如果扫描不到,则在第四步创建新的Entry放到该staleSlot处,如下图所示
  a. 在扫描过程中如果遇到了可以覆盖的key,即 ThreadLocal对象相同,则覆盖旧值。然后交换 staleSlotSlotKey位置的值,此时staleSlot位置就是新添加的元素,脏值对应的entry会被放在在SlotKey位置上,如果第二步向前扫描没有扫描到脏值,说明staleSlot位置之前没有脏值,而staleSlot和SlotKey之间没有脏值,所以清除脏值的下标从当前SlotKey开始即可。
  b. 如果扫描过程中发现当前遍历的Key是脏值,而且是staleSlot之后的第一个脏值,而且staleSlot之前没有脏值,则更新slotToExpunge的值为该key,清除时从该key开始即可。比如下图中的SlotKeystaleSlot之后的第一个脏值,而第一步的向前扫描时没有扫描到脏值,所以可以保证新元素插入后SlotKey之前是没有脏值的,所以从该位置进行清除即可。

第三步

  (4)如果第三步找不到,说明当前的下标位置是个null值且该下标之前没有可以覆盖的key,则将创建新键值对 Entry对象放到该位置,然后从 slotToExpunge位置开始清除。

第四步
  1. InheritableThreadLocal
      ThreadLocal不支持继承性,也就是说,同一个ThreadLocal变量在父线程中被设置值后,在子线程中是获取不到的。
public class Main {
    public static final ThreadLocal<String> threadLocal = new ThreadLocal<>();
    public static final ThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();

    public static void main(String[] args) {
        threadLocal.set("Test Inheritable");
        inheritableThreadLocal.set("Test Inheritable");
        new Thread(()->{
            System.out.println("ThreadLocal : " + threadLocal.get());
            System.out.println("InheritableThreadLocal : " + inheritableThreadLocal.get());
        }).start();
    }
}

ThreadLocal : null
InheritableThreadLocal : Test Inheritable

  该类支持子线程获取父线程设置到ThreadLocal的值,该类的代码如下所示,该类继承了ThreadLocal且重写了三个方法。

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }
    //注意:该类初始化的是线程的inheritableThreadLocals 属性而非 threadLocals
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

  接下来我们看下InheritableThreadLocal覆盖的方法何时调用,该功能是如何实现的。

    //1. 创建线程时调用的构造函数
    public Thread(Runnable target) {
        init(null, target, "Thread-" + nextThreadNum(), 0);
    }
     //2. 线程的初始化方法
     private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
        ......
        //1. 获取该线程对应的父线程
        Thread parent = currentThread();
        //2. 如果父线程的 inheritableThreadLocals  属性不为空,则复制父线程的属性
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
        ......
    }

    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }
        //1. 将父线程中持有的ThreadLocalMap 对象复制一份到当前线程的inheritableThreadLocals 属性中
        private ThreadLocalMap(ThreadLocalMap parentMap) {
            //1. 获取父线程所有的 Entry对象
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            //2. 初始化本线程的数组
            table = new Entry[len];
            //3. 遍历数组,进行复制
            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }

  注意,子线程创建时只是复制父线程的 ThreadLocalMap中的属性到当前线程中,创建完成之后,父线程对ThreadLocalMap中的属性的修改对子线程是不可见的。
  "那么在什么情况下需要子线程可以获取父线程的threadlocal变量呢?情况还是蛮多的,比如子线程需要使用存放在threadlocal变量中的用户登录信息,再比如一些中间件需要把统一的id追踪的整个调用链路记录下来。其实子线程使用父线程中的threadlocal方法有多种方式,比如创建线程时传入父线程中的变量,并将其复制到子线程中,或者在父线程中构造一个map作为参数传递给子线程,但是这些都改变了我们的使用习惯,所以在这些情况下InheritableThreadLocal就显得比较有用" ——《java并发编程之美》

©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容