Bootstrap

Android:ThreadLocal的简单理解和使用

1、背景

对于ThreadLocal,日常开发中一般有两种使用场景:

  • 每个线程需要一个独享的对象:比如Android中的Looper,后端中常用的工具类(如SimpleDateFormat)
  • 每个线程内需要保存全局变量:都知道Java服务端Controller作为接口响应入口,Service处理业务逻辑,Repository提供数据库CRUD数据接口,类似在拦截器中获取的用户信息这类共享数据,就可以放置到ThreadLocal中,就不用一层一层的通过参数传递下去。

1.1、背景及问题

一般来说,当某些数据是以线程为作用域并且不同线程具有不同的数据副本的时候,就可以考虑采用ThreadLocal

  • 比如对于Handler来说,它需要获取当前线程的Looper,很显然Looper的作用域就是线程并且不同线程具有不同的Looper,这个时候通过ThreadLocal就可以轻松实现Looper在线程中的存。
    • 如果不采用ThreadLocal,那么系统就必须提供一个全局的哈希表供Handler查找指定线程的Looper,这样一来就必须提供一个类似于LooperManager的类了,但是系统并没有这么做而是选择了ThreadLocal,这就是ThreadLocal的好处。
  • ThreadLocal另一个使用场景是复杂逻辑下的对象传递,比如监听器的传递,有些时候一个线程中的任务过于复杂,这可能表现为函数调用栈比较深以及代码入口的多样性,在这种情况下,我们又需要监听器能够贯穿整个线程的执行过程,这个时候可以怎么做呢?其实这时就可以采用ThreadLocal,采用ThreadLocal可以让监听器作为线程内的全局对象而存在,在线程内部只要通过get方法就可以获取到监听器。
    • 如果不采用ThreadLocal,那么我们能想到的可能是如下两种方法:
      • 第一种方法是将监听器通过参数的形式在函数调用栈中进行传递,
      • 第二种方法就是将监听器作为静态变量供线程访问。

上述这两种方法都是有局限性的。

  • 第一种方法的问题是当函数调用栈很深的时候,通过函数参数来传递监听器对象这几乎是不可接受的,这会让程序的设计看起来很糟糕。
  • 第二种方法是可以接受的,但是这种状态是不具有可扩充性的,比如同时有两个线程在执行,那么就需要提供两个静态的监听器对象,如果有10个线程在并发执行呢?提供10个静态的监听器对象?这显然是不可思议的,而采用ThreadLocal,每个监听器对象都在自己的线程内部存储,根本就不会有方法2的这种问题。

1.2、每个线程需要一个独享的对象

对于拿到时间戳,我们通常需要通过SimpleDateFormat类来将其转换成相应的日期格式,假设我们有如下一个工具类:

public class DateUtils {

    public static String format(long milliSeconds) {
      	SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd hh:mm:ss");
        return dateFormat.format(new Date(milliSeconds));
    }
}

现在我们通过线程池来模拟多线程环境:

public class ThreadLocalTest2 {

    private static ExecutorService threadPool = Executors.newFixedThreadPool(5);

    public static void main(String[] args) {

        for (int i = 0; i < 10; i++) {
            int finalI = i;
            threadPool.submit(() -> {
                String result = DateUtils.format(finalI * 1000);
                System.out.println(result);
            });
        }

        threadPool.shutdown();
    }

}

运行后的输出结果如下:

1970-01-01 08:00:03
1970-01-01 08:00:00
1970-01-01 08:00:02
1970-01-01 08:00:04
1970-01-01 08:00:01
1970-01-01 08:00:05
1970-01-01 08:00:08
1970-01-01 08:00:06
1970-01-01 08:00:09
1970-01-01 08:00:07

现在一切都是正常的,但是由于每次调用format方法都是创建一个新的SimpleDateFormat对象,这样是没有必要的。我们可以有如下修改:

public class DateUtils {

    private static SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd hh:mm:ss");

    public static String format(long milliSeconds) {
        return dateFormat.format(new Date(milliSeconds));
    }
}

现在再运行代码:

1970-01-01 08:00:02
1970-01-01 08:00:02
1970-01-01 08:00:02
1970-01-01 08:00:02
1970-01-01 08:00:07
1970-01-01 08:00:02
1970-01-01 08:00:09
1970-01-01 08:00:09
1970-01-01 08:00:07
1970-01-01 08:00:07

从结果来看,明显这种写法已经出问题了,数据重复,并发问题。那么该怎么去解决这个问题呢?接下来,就轮到我们今天的主人公ThreadLocal登场啦!

class DateUtils {

    private static ThreadLocal<SimpleDateFormat> threadLocal = ThreadLocal.withInitial(() -> new SimpleDateFormat("yyyy-MM-dd hh:mm:ss"));

    public static String format(long milliSeconds) {
        return threadLocal.get().format(new Date(milliSeconds));
    }
}

现在再运行:

1970-01-01 08:00:00
1970-01-01 08:00:01
1970-01-01 08:00:03
1970-01-01 08:00:05
1970-01-01 08:00:06
1970-01-01 08:00:04
1970-01-01 08:00:09
1970-01-01 08:00:02
1970-01-01 08:00:07
1970-01-01 08:00:08

这样,每个线程之间就互不干扰啦,因为每个进入format()方法的线程所使用的的SimpleDateFormat对象都是线程独享的,相互之间互不干扰的。

1.2、每个线程需要一个独享的对象

假定我们有一个UserInfo类,用来表示用户的信息:

class UserInfo {
    int id;

    public UserInfo(int id) {
        this.id = id;
    }
}

再有一个UseInfoHolder类,持有ThreadLocal对象:

class UserInfoHolder {

    static final ThreadLocal<UserInfo> holder = new ThreadLocal<>();
}

构造三个Service,分别表示处理逻辑:

class Service1 {


    public void process() {
        UserInfo userInfo = new UserInfo(1);
        UserInfoHolder.holder.set(userInfo);
        new Service2().process();
    }
}

class Service2 {

    public void process() {
        System.out.println("in Service2 : " + UserInfoHolder.holder.get().id);
        new Service3().process();
    }
}

class Service3 {

    public void process() {
        System.out.println("in Service3 : " + UserInfoHolder.holder.get().id);
    }
}

在Service1中,我们为UserInfoHolder中的ThreadLocal设置了值;在Service2、Service3中,我们可以直接通过UserInfoHolder中的ThreadLocal获取设置的UserInfo对象,从而做到共享。

最后写上main测试方法:

public class ThreadLocalTest3 {
    public static void main(String[] args) {
        new Service1().process();
    } 
}

运行结果如下:

in Service2 : 1
in Service3 : 1

2、ThreadLocal原理

我们可以用下面的图来表示Thread、ThreadLocal以及ThreadLocalMap之间的关系:
在这里插入图片描述在上图中我们可以发现,整个ThreadLocal的使用都涉及到线程中ThreadLocalMap,虽然我们在外部调用的是ThreadLocal.set(value)方法,

但本质是通过线程中的ThreadLocalMap中的set(key,value)方法,那么通过该情况我们大致也能猜出get方法也是通过ThreadLocalMap。那么接下来我们一起来看看ThreadLocal中set与get方法的具体实现与ThreadLocalMap的具体结构。

2.1、使用说明

ThreadLocal提供线程局部变量。这些变量不同于它们的正常变量,即每一个线程访问自身的局部变量时,都有它自己的,独立初始化的副本。该变量通常是与线程关联的私有静态字段,列如用于ID或事物ID。大家看了介绍后,有可能还是不了解其主要的主要作用,简单的画个图帮助大家理解。
在这里插入图片描述
从图上可以看出,通过ThreadLocal,每个线程都能获取自己线程内部的私有变量,下面我们通过具体的例子详细的介绍,来看下面的代码。

class ThreadLocalTest {
	//会出现内存泄漏的问题,下文会描述
    private static ThreadLocal<String> mThreadLocal = new ThreadLocal<>();

    public static void main(String[] args) {
        mThreadLocal.set("线程main");
        new Thread(new A()).start();
        new Thread(new B()).start();
        System.out.println(mThreadLocal.get());
    }

    static class A implements Runnable {

        @Override
        public void run() {
            mThreadLocal.set("线程A");
            System.out.println(mThreadLocal.get());
        }
    }

    static class B implements Runnable {

        @Override
        public void run() {
            mThreadLocal.set("线程B");
            System.out.println(mThreadLocal.get());
        }
    }
}

在上诉代码中,我们在主线程中设置mThreadLocal的值为"线程main",在线程A中设置为”线程A“,在线程B中设置为”线程B",运行程序打印结果如下所示:

main
线程A
线程B

从上面结果可以看出,虽然是在不同的线程中访问的同一个变量mThreadLocal,但是他们通过ThreadLocl获取到的值却是不一样的。也就验证了上面我们所画的图是正确的了,那么现在,我们已经知道了ThreadLocal的用法,那么我们现在来看看其中的内部原理。

2.2、ThreadLocal的set方法

public void set(T value) {
        Thread t = Thread.currentThread();//获取当前线程
        ThreadLocalMap map = getMap(t);//拿到线程的LocalMap
        if (map != null)
            map.set(this, value);//设值 key->当前ThreadLocal对象。value->为当前赋的值
        else
            createMap(t, value);//创建新的ThreadLocalMap并设值
    }

当调用set(T value) 方法时,方法内部会获取当前线程中的ThreadLocalMap,获取后进行判断,

  • 如果不为空,就调用ThreadLocalMapset方法(其中key为当前ThreadLocal对象,value为当前赋的值)。
  • 反之,让当前线程创建新的ThreadLocalMap并设值,其中getMap()与createMap()方法具体代码如下:
ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    
void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

ThreadLocal中所有的数据操作都与线程中的ThreadLocalMap有关,同时那我们接下来看看ThreadLocalMap相关代码

2.3、ThreadLocalMap 内部结构

在这里插入图片描述
ThreadLocalMap是ThreadLocal中的一个静态内部类,官方的注释写的很全面,这里我大概的翻译了一下,ThreadLocalMap是为了维护线程私有值创建的自定义哈希映射。其中线程的私有数据都是非常大且使用寿命长的数据

ThreadLocalMap 具体代码如下:

static class ThreadLocalMap {
		//存储的数据为Entry,且key为弱引用
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
        //table初始容量
        private static final int INITIAL_CAPACITY = 16;
      
        //table 用于存储数据
        private Entry[] table;
        
	    //负载因子,用于数组容量扩容
        private int threshold; // Default to 0
        
		//负载因子,默认情况下为当前数组长度的2/3
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }
	    //第一次放入Entry数据时,初始化数组长度,定义扩容阀值,
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];//初始化数组长度为16
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);//阀值为当前数组默认长度的2/3
        }

从代码中可以看出,虽然官方申明为ThreadLocalMap是一个哈希表,但是它与我们传统认识的HashMap等哈希表内部结构是不一样的。

ThreadLocalMap内部仅仅维护了Entry[] table,数组。其中Entry实体中对应的key为弱引用(下文会将为什么会用弱引用),在第一次放入数据时,会初始化数组长度(为16),定义数组扩容阀值(当前默认数组长度的2/3)。

2.3.1、ThreadLocalMap 的set()方法

private void set(ThreadLocal<?> key, Object value) {

		    //根据哈希值计算位置
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            
            //判断当前位置是否有数据,如果key值相同,就替换,如果不同则找空位放数据。
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {//获取下一个位置的数据
                ThreadLocal<?> k = e.get();
			//判断key值相同否,如果是直接覆盖 (第一种情况)
                if (k == key) {
                    e.value = value;
                    return;
                }
			//如果当前Entry对象对应Key值为null,则清空所有Key为null的数据(第二种情况)
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            //以上情况都不满足,直接添加(第三种情况)
            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)//如果当前数组到达阀值,那么就进行扩容。
                rehash();
        }

通过代码分析可知,ThreadLocalMap的set函数主要分为三个主要步骤:

  • 1、计算出当前ThreadLocal在table数组的位置,然后向后遍历,直到遍历到的Entry为null则停止,遍历到Entry的key与当前threadLocal实例的相等,直接更替value;
  • 2、如果遍历到Entry已过期(Entry的key为null),则调用replaceStaleEntry函数进行替换。
  • 3、在遍历结束后,未出现1和2两种情况,则直接创建新的Entry,保存到数组最后侧没有Entry的位置。
2.3.1.1、第一种情况, Key值相同

如果当前数组中,如果当前位置对应的Entry的key值与新添加的Entry的key值相同,直接进行覆盖操作。具体情况如下图所示
在这里插入图片描述
如果当前数组中。存在key值相同的情况,ThreadLocal内部操作是直接覆盖的。

2.3.1.2、第二种情况,如果当前位置对应Entry的Key值为null

在这里插入图片描述
从图中我们可以看出来。当我们添加新Entry(key=19,value =200,index = 3)时,数组中已经存在旧Entry(key =null,value = 19),

当出现这种情况是,方法内部会将新Entry的值全部赋值到旧Entry中,同时会将所有数组中key为null的Entry全部置为null(图中大黄色数据)。

在源码中,当新Entry对应位置存在数据,且key为null的情况下,会走replaceStaleEntry方法。具体代码如下:

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

	        //记录当前要清除的位置
            int slotToExpunge = staleSlot;
            
            //往前找,找到第一个过期的Entry(key为空)
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)//判断引用是否为空,如果为空,擦除的位置为第一个过期的Entry的位置
                    slotToExpunge = i;

		    //往后找,找到最后一个过期的Entry(key为空),
            for (int i = nextIndex(staleSlot, len);//这里要注意获得位置有可能为0,
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                //在往后找的时候,如果获取key值相同的。那么就重新赋值。
                if (k == key) {
                	//赋值到之前传入的staleSlot对应的位置
                    e.value = value;
                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    //如果往前找的时候,没有过期的Entry,那么就记录当前的位置(往后找相同key的位置)
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                        
                    //那么就清除slotToExpunge位置下所有key为null的数据
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

			    //如果往前找的时候,没有过期的Entry,且key =null那么就记录当前的位置(往后找key==null位置)
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // 把当前key为null的对应的数据置为null,并创建新的Entry在该位置上
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            //如果往后找,没有过期的实体, 
            //且staleSlot之前能找到第一个过期的Entry(key为空),
            //那么就清除slotToExpunge位置下所有key为null的数据
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

replaceStaleEntry函数,主要分为两次遍历,以当前过期的Entry为分割线,一次向前遍历,一次向后遍历。

主要对四种情况进行了判断,具体情况如下图表所示:
在这里插入图片描述replaceStaleEntry方法内部会清除key==null的数据,而其中具体的方法与expungeStaleEntry()方法与cleanSomeSlots()方法有关。

2.3.1.3、第三种情况,当前对应位置为null

在这里插入图片描述图上为了方便大家,理解清空上下数据的情况,我并没有重新计算位置(希望大家注意!!!)

看到这里,为了方便大家避免不必要的查阅代码,我直接将代码贴出来了。代码如下。

tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();

清除key==null的数据,判断当前数据的长度是不是到达阀值(默认没扩容前为INITIAL_CAPACITY *2/3,其中INITIAL_CAPACITY = 16),如果达到了重新计算数据的位置。关于rehash()方法,具体代码如下:

private void rehash() {
         expungeStaleEntries();

         // Use lower threshold for doubling to avoid hysteresis
         if (size >= threshold - threshold / 4)
                resize();
        }
        
 //清空所有key==null的数据
 private void expungeStaleEntries() {
         Entry[] tab = table;
         int len = tab.length;
         for (int j = 0; j < len; j++) {
             Entry e = tab[j];
             if (e != null && e.get() == null)
                 expungeStaleEntry(j);
            }
        }
 //重新计算key!=null的数据。新的数组长度为之前的两倍      
 private void resize() {
			//对原数组进行扩容,容量为之前的两倍
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;
			//重新计算位置
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }
			//重新计算阀值(负载因子)为扩容之后的数组长度的2/3
            setThreshold(newLen);
            size = count;
            table = newTab;
        }

可以看出在添加数据的时候,会进行判断是否扩容操作,如果需要扩容,会清除所有的key==null的数据,(也就是调用expungeStaleEntries()方法,同时会重新计算数据中的位置。

2.4、ThreadLocal的get()方法

 public T get() {
        Thread t = Thread.currentThread();//获取当前线程
        ThreadLocalMap map = getMap(t);//拿到线程中的Map
        if (map != null) {
            //根据key值(ThreadLocal)对象,获取存储的数据
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        //如果ThreadLocalMap为空,创建新的ThreadLocalMap 
        return setInitialValue();
    }

其实ThreadLocal的get方法其实很简单,就是获取当前线程中的ThreadLocalMap对象,如果没有则创建,如果有,则根据当前的 key(当前ThreadLocal对象),获取相应的数据。

其中内部调用了ThreadLocalMap的getEntry()方法区获取数据,我们继续查看getEntry()方法。

 private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }

getEntry()方法内部也很简单,也只是根据当前key哈希后计算的位置,去找数组中对应位置是否有数据,如果有,直接将数据放回,如果没有,则调用getEntryAfterMiss()方法,我们继续往下看 。

 private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)//如果key相同,直接返回
                    return e;
                if (k == null)//如果key==null,清除当前位置下所有key=null的数据。
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;//没有数据直接返回null
        }

从上述代码我们可以知道,如果从数组中,获取的key==null的情况下,get方法内部也会调用expungeStaleEntry()方法,去清除当前位置所有key==null的数据

也就是说现在不管是调用ThreadLocal的set()还是get()方法,都会去清除key==null的数据。

3、ThreadLocal内存泄漏的问题

在Java中判断一个对象到底是不是需要回收,都跟引用相关。在Java中引用分为了4类。

  • 1、强引用:只要引用存在,垃圾回收器永远不会回收Object obj = new Object();而这样 obj对象对后面new Object的一个强引用,只有当obj这个引用被释放之后,对象才会被释放掉。
  • 2、软引用:是用来描述,一些还有但并非必须的对象,对于软引用关联着的对象,在系统将要发生内存溢出异常之前,将会把这些对象列进回收范围之中进行第二次回收。(SoftReference)
  • 3、弱引用:也是用来描述非必须的对象,但是它的强度要比软引用更弱一些。被弱引用关联的对象只能生存到下一次垃圾收集发生之前,当垃圾收集器工作是,无论当前内存是否足够,都会回收掉被弱引用关联的对象。(WeakReference)
  • 4、虚引用:也被称为幽灵引用,它是最弱的一种关系。一个对象是否有引用的存在,完全不会对其生存时间构成影响,也无法通过一个虚引用来取得一个实例对象。

3.1、为什么使用弱引用

如果key使用强引用,那么当引用ThreadLocal的对象被回收了,但ThreadLocalMap中还持有ThreadLocal的强引用,如果没有手动删除,ThreadLocal不会被回收,导致内存泄漏。

3.2、弱引用带来的问题

从上面我们已经知道了,ThreadLocalMap使用ThreadLocal的弱引用作为key,也就是说,如果一个ThreadLocal没有外部强引用来引用它,那么系统 GC 的时候,这个ThreadLocal势必会被回收。这样一来,ThreadLocalMap中就会出现key为nullEntry,就没有办法访问这些key为nullEntry的value

如果当前线程迟迟不结束的话,这些key为nullEntry的value就会一直存在一条强引用链:Thread Ref(当前线程引用) -> Thread -> ThreadLocalMap -> Entry -> value,那么将会导致这些Entry永远无法回收,造成内存泄漏。

不过,这一点设计者也考虑到了,在get()set()remove()方法调用的时候会清除掉线程ThreadLocalMap中所有EntryKey为null的Value,并将整个Entry设置为null,这样在下次回收时就能将Entry和value回收。

4、总结

  • 1、ThreadLocal本质是操作线程中ThreadLocalMap来实现本地线程变量的存储
  • 2、ThreadLocalMap是采用数组的方式来存储数据,其中key(弱引用)指向当前ThreadLocal对象,value为设的值
  • 3、ThreadLocal为内存泄漏采取了处理措施,在调用ThreadLocalget(),set(),remove()方法的时候都会清除线程ThreadLocalMap里所有key为null的Entry
  • 4、在使用ThreadLocal的时候,我们仍然需要注意,避免使用staticThreadLocal,分配使用了ThreadLocal后,一定要根据当前线程的生命周期来判断是否需要手动的去清理ThreadLocalMap中清key==null的Entry

参考

1、Android Handler机制之ThreadLocal
2、Android进阶:ThreadLocal
3、写给Android开发者的ThreadLocal介绍

;