自己实现一个LRU cache

LRU 是 Least Recently Used 的简写,字面意思则是最近最少使用。

通常用于缓存的淘汰策略实现,由于缓存的内存非常宝贵,所以需要根据某种规则来剔除数据保证内存不被撑满。

如常用的 Redis 就有以下几种策略:


image.png

有一道面试题,大概需求是:

  • 实现一个 LRU 缓存,当缓存数据达到 N 之后需要淘汰掉最近最少使用的数据。
  • N 小时之内没有被访问的数据也需要淘汰掉。
package CacheTest;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Set;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Function:
 *
 * 1.在做 key 生成 hashcode 时是用的 HashMap 的 hash 函数
 * 2.在做 put get 时,如果存在 key 相等时候为了简单没有去比较 equal 和 hashcode
 * 3.限制大小, map的最大size是1024, 超过1024后,就淘汰掉最久没有访问的kv 键值对, 当淘汰时,需要调用一个callback   lruCallback(K key, V value)
 * 是利用每次 put 都将值写入一个内部队列,这样只需要判断队列里的第一个即可。
 * 4.具备超时功能, 当键值对1小时内没有被访问, 就被淘汰掉, 当淘汰时, 需要调用一个callback   timeoutCallback(K key, V value);
 * 超时同理,单独开启一个守护进程来处理,取的是队列里的第一个 因为第一个是最早放进去的。
 *
 * 但是像 HashMap 里的扩容,链表在超过阈值之类的没有考虑进来。
 *
 */
public class LRUAbstractMap extends java.util.AbstractMap {

    private final static Logger LOGGER = LoggerFactory.getLogger(LRUAbstractMap.class);

    /**
     * 检查是否超期线程
     */
    private ExecutorService checkTimePool ;

    /**
     * map 最大size
     */
    private final static int MAX_SIZE = 1024 ;

    private final static ArrayBlockingQueue<Node> QUEUE = new ArrayBlockingQueue<>(MAX_SIZE) ;

    /**
     * 默认大小
     */
    private final static int DEFAULT_ARRAY_SIZE =1024 ;


    /**
     * 数组长度
     */
    private int arraySize ;

    /**
     * 数组
     */
    private Object[] arrays ;


    /**
     * 判断是否停止 flag
     */
    private volatile boolean flag = true ;


    /**
     * 超时时间
     */
    private final static Long EXPIRE_TIME = 60 * 60 * 1000L ;

    /**
     * 整个 Map 的大小
     */
    private volatile AtomicInteger size  ;


    public LRUAbstractMap() {


        arraySize = DEFAULT_ARRAY_SIZE;
        arrays = new Object[arraySize] ;

        //开启一个线程检查最先放入队列的值是否超期
        executeCheckTime();
    }

    /**
     * 开启一个线程检查最先放入队列的值是否超期 设置为守护线程
     */
    private void executeCheckTime() {
        ThreadFactory namedThreadFactory = new ThreadFactoryBuilder()
                .setNameFormat("check-thread-%d")
                .setDaemon(true)
                .build();
        checkTimePool = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS,
                new ArrayBlockingQueue<>(1),namedThreadFactory,new ThreadPoolExecutor.AbortPolicy());
        checkTimePool.execute(new CheckTimeThread()) ;

    }

    @Override
    public Set<Entry> entrySet() {
        return super.keySet();
    }

    @Override
    public Object put(Object key, Object value) {
        int hash = hash(key);
        int index = hash % arraySize ;
        Node currentNode = (Node) arrays[index] ;

        if (currentNode == null){
            arrays[index] = new Node(null,null, key, value);

            //写入队列
            QUEUE.offer((Node) arrays[index]) ;

            sizeUp();
        }else {
            Node cNode = currentNode ;
            Node nNode = cNode ;

            //存在就覆盖
            if (nNode.key == key){
                cNode.val = value ;
            }

            while (nNode.next != null){
                //key 存在 就覆盖 简单判断
                if (nNode.key == key){
                    nNode.val = value ;
                    break ;
                }else {
                    //不存在就新增链表
                    sizeUp();
                    Node node = new Node(nNode,null,key,value) ;

                    //写入队列
                    QUEUE.offer(currentNode) ;

                    cNode.next = node ;
                }

                nNode = nNode.next ;
            }

        }

        return null ;
    }


    @Override
    public Object get(Object key) {

        int hash = hash(key) ;
        int index = hash % arraySize ;
        Node currentNode = (Node) arrays[index] ;

        if (currentNode == null){
            return null ;
        }
        if (currentNode.next == null){

            //更新时间
            currentNode.setUpdateTime(System.currentTimeMillis());

            //没有冲突
            return currentNode ;

        }

        Node nNode = currentNode ;
        while (nNode.next != null){

            if (nNode.key == key){

                //更新时间
                currentNode.setUpdateTime(System.currentTimeMillis());

                return nNode ;
            }

            nNode = nNode.next ;
        }

        return super.get(key);
    }


    @Override
    public Object remove(Object key) {

        int hash = hash(key) ;
        int index = hash % arraySize ;
        Node currentNode = (Node) arrays[index] ;

        if (currentNode == null){
            return null ;
        }

        if (currentNode.key == key){
            sizeDown();
            arrays[index] = null ;

            //移除队列
            QUEUE.poll();
            return currentNode ;
        }

        Node nNode = currentNode ;
        while (nNode.next != null){

            if (nNode.key == key){
                sizeDown();
                //在链表中找到了 把上一个节点的 next 指向当前节点的下一个节点
                nNode.pre.next = nNode.next ;
                nNode = null ;

                //移除队列
                QUEUE.poll();

                return nNode;
            }

            nNode = nNode.next ;
        }

        return super.remove(key);
    }

    /**
     * 增加size
     */
    private void sizeUp(){

        //在put值时候认为里边已经有数据了
        flag = true ;

        if (size == null){
            size = new AtomicInteger() ;
        }
        int size = this.size.incrementAndGet();
        if (size >= MAX_SIZE) {
            //找到队列头的数据
            Node node = QUEUE.poll() ;
            if (node == null){
                throw new RuntimeException("data error") ;
            }

            //移除该 key
            Object key = node.key ;
            remove(key) ;
            lruCallback() ;
        }

    }

    /**
     * 数量减小
     */
    private void sizeDown(){

        if (QUEUE.size() == 0){
            flag = false ;
        }

        this.size.decrementAndGet() ;
    }

    @Override
    public int size() {
        return size.get() ;
    }

    /**
     * 链表
     */
    private class Node{
        private Node next ;
        private Node pre ;
        private Object key ;
        private Object val ;
        private Long updateTime ;

        public Node(Node pre,Node next, Object key, Object val) {
            this.pre = pre ;
            this.next = next;
            this.key = key;
            this.val = val;
            this.updateTime = System.currentTimeMillis() ;
        }

        public void setUpdateTime(Long updateTime) {
            this.updateTime = updateTime;
        }

        public Long getUpdateTime() {
            return updateTime;
        }

        @Override
        public String toString() {
            return "Node{" +
                    "key=" + key +
                    ", val=" + val +
                    '}';
        }
    }


    /**
     * copy HashMap 的 hash 实现
     * @param key
     * @return
     */
    public int hash(Object key) {
        int h;
        return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16);
    }

    private void lruCallback(){
        LOGGER.debug("lruCallback");
    }


    private class CheckTimeThread implements Runnable{

        @Override
        public void run() {
            while (flag){
                try {
                    Node node = QUEUE.poll();
                    if (node == null){
                        continue ;
                    }
                    Long updateTime = node.getUpdateTime() ;

                    if ((updateTime - System.currentTimeMillis()) >= EXPIRE_TIME){
                        remove(node.key) ;
                    }
                } catch (Exception e) {
                    LOGGER.error("InterruptedException");
                }
            }
        }
    }

}

代码看着比较多,其实实现的思路还是比较简单:

  • 采用了与 HashMap 一样的保存数据方式,只是自己手动实现了一个简易版。
  • 内部采用了一个队列来保存每次写入的数据。
  • 写入的时候判断缓存是否大于了阈值 N,如果满足则根据队列的 FIFO 特性将队列头的数据删除。因为队列头的数据肯定是最先放进去的。
  • 再开启了一个守护线程用于判断最先放进去的数据是否超期(因为就算超期也是最先放进去的数据最有可能满足超期条件

以上代码大体功能满足了,但是有一个致命问题。

就是最近最少使用没有满足,删除的数据都是最先放入的数据。

实现二

因此如何来实现一个完整的 LRU 缓存呢,这次不考虑过期时间的问题。

  • 要记录最近最少使用,那至少需要一个有序的集合来保证写入的顺序。
  • 在使用了数据之后能够更新它的顺序。

基于以上两点很容易想到一个常用的数据结构:链表。

1.每次写入数据时将数据放入链表头结点。

  1. 使用数据时候将数据移动到头结点。
  2. 缓存数量超过阈值时移除链表尾部数据。
public class LRUCache {
    
    HashMap<Integer, Node> map =new HashMap<>();
    int maximum;
    Node head, tail;
    public LRUCache(int capacity) {
        this.maximum=capacity;
        head=new Node(0,0);
        tail=new Node(0,0);
        head.next=tail;
        tail.pre=head;
    }
    public void addToHead(Node node){
        node.next=head.next;
        node.pre=head;
        head.next.pre=node;
        head.next=node;
        map.put(node.key,node);
    }
    public void delete(Node node){
        node.pre.next=node.next;
        node.next.pre=node.pre;
        map.remove(node.key);
    }
    public int get(int key) {
        Node node = map.get(key);
        if(node!=null){
            delete(node);
            addToHead(node);
            return node.value;
        }else {
            return -1;
        }
        
    }
    
    public void set(int key, int value) {
        Node node=map.get(key);
        if(node!=null){
            node.value=value;
            delete(node);
            addToHead(node);
        }else{
            node =new Node(key,value);
            if(map.size()<maximum){
                addToHead(node);
            }else{
                delete(tail.pre);
                addToHead(node);
            }
        }
    }
}
class Node{
    int key;
    int value;
    Node pre;
    Node next;
    public Node(int key, int value){
        this.key=key;
        this.value=value;
    }
}

实现3

其实如果对 Java 的集合比较熟悉的话,会发现上文的结构和 LinkedHashMap 非常类似。

public class LRULinkedMap<K,V> {


    /**
     * 最大缓存大小
     */
    private int cacheSize;

    private LinkedHashMap<K,V> cacheMap ;


    public LRULinkedMap(int cacheSize) {
        this.cacheSize = cacheSize;

        cacheMap = new LinkedHashMap(16,0.75F,true){
            @Override
            protected boolean removeEldestEntry(Map.Entry eldest) {
                if (cacheSize + 1 == cacheMap.size()){
                    return true ;
                }else {
                    return false ;
                }
            }
        };
    }

    public void put(K key,V value){
        cacheMap.put(key,value) ;
    }

    public V get(K key){
        return cacheMap.get(key) ;
    }


    public Collection<Map.Entry<K, V>> getAll() {
        return new ArrayList<Map.Entry<K, V>>(cacheMap.entrySet());
    }
}

LinkedHashMap 内部也有维护一个双向队列,在初始化时也会给定一个缓存大小的阈值。初始化时自定义是否需要删除最近不常使用的数据,如果是则会按照实现二中的方式管理数据。

其实主要代码就是重写了 LinkedHashMap 的 removeEldestEntry 方法:

protected boolean removeEldestEntry(Map.Entry<K,V> eldest) {
    return false;
}

它默认是返回 false,也就是不会管有没有超过阈值。

所以我们自定义大于了阈值时返回 true,这样 LinkedHashMap 就会帮我们删除最近最少使用的数据。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容