LRU 是 Least Recently Used 的简写,字面意思则是最近最少使用。
通常用于缓存的淘汰策略实现,由于缓存的内存非常宝贵,所以需要根据某种规则来剔除数据保证内存不被撑满。
如常用的 Redis 就有以下几种策略:
image.png
有一道面试题,大概需求是:
- 实现一个 LRU 缓存,当缓存数据达到 N 之后需要淘汰掉最近最少使用的数据。
- N 小时之内没有被访问的数据也需要淘汰掉。
package CacheTest;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Set;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Function:
*
* 1.在做 key 生成 hashcode 时是用的 HashMap 的 hash 函数
* 2.在做 put get 时,如果存在 key 相等时候为了简单没有去比较 equal 和 hashcode
* 3.限制大小, map的最大size是1024, 超过1024后,就淘汰掉最久没有访问的kv 键值对, 当淘汰时,需要调用一个callback lruCallback(K key, V value)
* 是利用每次 put 都将值写入一个内部队列,这样只需要判断队列里的第一个即可。
* 4.具备超时功能, 当键值对1小时内没有被访问, 就被淘汰掉, 当淘汰时, 需要调用一个callback timeoutCallback(K key, V value);
* 超时同理,单独开启一个守护进程来处理,取的是队列里的第一个 因为第一个是最早放进去的。
*
* 但是像 HashMap 里的扩容,链表在超过阈值之类的没有考虑进来。
*
*/
public class LRUAbstractMap extends java.util.AbstractMap {
private final static Logger LOGGER = LoggerFactory.getLogger(LRUAbstractMap.class);
/**
* 检查是否超期线程
*/
private ExecutorService checkTimePool ;
/**
* map 最大size
*/
private final static int MAX_SIZE = 1024 ;
private final static ArrayBlockingQueue<Node> QUEUE = new ArrayBlockingQueue<>(MAX_SIZE) ;
/**
* 默认大小
*/
private final static int DEFAULT_ARRAY_SIZE =1024 ;
/**
* 数组长度
*/
private int arraySize ;
/**
* 数组
*/
private Object[] arrays ;
/**
* 判断是否停止 flag
*/
private volatile boolean flag = true ;
/**
* 超时时间
*/
private final static Long EXPIRE_TIME = 60 * 60 * 1000L ;
/**
* 整个 Map 的大小
*/
private volatile AtomicInteger size ;
public LRUAbstractMap() {
arraySize = DEFAULT_ARRAY_SIZE;
arrays = new Object[arraySize] ;
//开启一个线程检查最先放入队列的值是否超期
executeCheckTime();
}
/**
* 开启一个线程检查最先放入队列的值是否超期 设置为守护线程
*/
private void executeCheckTime() {
ThreadFactory namedThreadFactory = new ThreadFactoryBuilder()
.setNameFormat("check-thread-%d")
.setDaemon(true)
.build();
checkTimePool = new ThreadPoolExecutor(1, 1, 0L, TimeUnit.MILLISECONDS,
new ArrayBlockingQueue<>(1),namedThreadFactory,new ThreadPoolExecutor.AbortPolicy());
checkTimePool.execute(new CheckTimeThread()) ;
}
@Override
public Set<Entry> entrySet() {
return super.keySet();
}
@Override
public Object put(Object key, Object value) {
int hash = hash(key);
int index = hash % arraySize ;
Node currentNode = (Node) arrays[index] ;
if (currentNode == null){
arrays[index] = new Node(null,null, key, value);
//写入队列
QUEUE.offer((Node) arrays[index]) ;
sizeUp();
}else {
Node cNode = currentNode ;
Node nNode = cNode ;
//存在就覆盖
if (nNode.key == key){
cNode.val = value ;
}
while (nNode.next != null){
//key 存在 就覆盖 简单判断
if (nNode.key == key){
nNode.val = value ;
break ;
}else {
//不存在就新增链表
sizeUp();
Node node = new Node(nNode,null,key,value) ;
//写入队列
QUEUE.offer(currentNode) ;
cNode.next = node ;
}
nNode = nNode.next ;
}
}
return null ;
}
@Override
public Object get(Object key) {
int hash = hash(key) ;
int index = hash % arraySize ;
Node currentNode = (Node) arrays[index] ;
if (currentNode == null){
return null ;
}
if (currentNode.next == null){
//更新时间
currentNode.setUpdateTime(System.currentTimeMillis());
//没有冲突
return currentNode ;
}
Node nNode = currentNode ;
while (nNode.next != null){
if (nNode.key == key){
//更新时间
currentNode.setUpdateTime(System.currentTimeMillis());
return nNode ;
}
nNode = nNode.next ;
}
return super.get(key);
}
@Override
public Object remove(Object key) {
int hash = hash(key) ;
int index = hash % arraySize ;
Node currentNode = (Node) arrays[index] ;
if (currentNode == null){
return null ;
}
if (currentNode.key == key){
sizeDown();
arrays[index] = null ;
//移除队列
QUEUE.poll();
return currentNode ;
}
Node nNode = currentNode ;
while (nNode.next != null){
if (nNode.key == key){
sizeDown();
//在链表中找到了 把上一个节点的 next 指向当前节点的下一个节点
nNode.pre.next = nNode.next ;
nNode = null ;
//移除队列
QUEUE.poll();
return nNode;
}
nNode = nNode.next ;
}
return super.remove(key);
}
/**
* 增加size
*/
private void sizeUp(){
//在put值时候认为里边已经有数据了
flag = true ;
if (size == null){
size = new AtomicInteger() ;
}
int size = this.size.incrementAndGet();
if (size >= MAX_SIZE) {
//找到队列头的数据
Node node = QUEUE.poll() ;
if (node == null){
throw new RuntimeException("data error") ;
}
//移除该 key
Object key = node.key ;
remove(key) ;
lruCallback() ;
}
}
/**
* 数量减小
*/
private void sizeDown(){
if (QUEUE.size() == 0){
flag = false ;
}
this.size.decrementAndGet() ;
}
@Override
public int size() {
return size.get() ;
}
/**
* 链表
*/
private class Node{
private Node next ;
private Node pre ;
private Object key ;
private Object val ;
private Long updateTime ;
public Node(Node pre,Node next, Object key, Object val) {
this.pre = pre ;
this.next = next;
this.key = key;
this.val = val;
this.updateTime = System.currentTimeMillis() ;
}
public void setUpdateTime(Long updateTime) {
this.updateTime = updateTime;
}
public Long getUpdateTime() {
return updateTime;
}
@Override
public String toString() {
return "Node{" +
"key=" + key +
", val=" + val +
'}';
}
}
/**
* copy HashMap 的 hash 实现
* @param key
* @return
*/
public int hash(Object key) {
int h;
return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16);
}
private void lruCallback(){
LOGGER.debug("lruCallback");
}
private class CheckTimeThread implements Runnable{
@Override
public void run() {
while (flag){
try {
Node node = QUEUE.poll();
if (node == null){
continue ;
}
Long updateTime = node.getUpdateTime() ;
if ((updateTime - System.currentTimeMillis()) >= EXPIRE_TIME){
remove(node.key) ;
}
} catch (Exception e) {
LOGGER.error("InterruptedException");
}
}
}
}
}
代码看着比较多,其实实现的思路还是比较简单:
- 采用了与 HashMap 一样的保存数据方式,只是自己手动实现了一个简易版。
- 内部采用了一个队列来保存每次写入的数据。
- 写入的时候判断缓存是否大于了阈值 N,如果满足则根据队列的 FIFO 特性将队列头的数据删除。因为队列头的数据肯定是最先放进去的。
- 再开启了一个守护线程用于判断最先放进去的数据是否超期(因为就算超期也是最先放进去的数据最有可能满足超期条件
以上代码大体功能满足了,但是有一个致命问题。
就是最近最少使用没有满足,删除的数据都是最先放入的数据。
实现二
因此如何来实现一个完整的 LRU 缓存呢,这次不考虑过期时间的问题。
- 要记录最近最少使用,那至少需要一个有序的集合来保证写入的顺序。
- 在使用了数据之后能够更新它的顺序。
基于以上两点很容易想到一个常用的数据结构:链表。
1.每次写入数据时将数据放入链表头结点。
- 使用数据时候将数据移动到头结点。
- 缓存数量超过阈值时移除链表尾部数据。
public class LRUCache {
HashMap<Integer, Node> map =new HashMap<>();
int maximum;
Node head, tail;
public LRUCache(int capacity) {
this.maximum=capacity;
head=new Node(0,0);
tail=new Node(0,0);
head.next=tail;
tail.pre=head;
}
public void addToHead(Node node){
node.next=head.next;
node.pre=head;
head.next.pre=node;
head.next=node;
map.put(node.key,node);
}
public void delete(Node node){
node.pre.next=node.next;
node.next.pre=node.pre;
map.remove(node.key);
}
public int get(int key) {
Node node = map.get(key);
if(node!=null){
delete(node);
addToHead(node);
return node.value;
}else {
return -1;
}
}
public void set(int key, int value) {
Node node=map.get(key);
if(node!=null){
node.value=value;
delete(node);
addToHead(node);
}else{
node =new Node(key,value);
if(map.size()<maximum){
addToHead(node);
}else{
delete(tail.pre);
addToHead(node);
}
}
}
}
class Node{
int key;
int value;
Node pre;
Node next;
public Node(int key, int value){
this.key=key;
this.value=value;
}
}
实现3
其实如果对 Java 的集合比较熟悉的话,会发现上文的结构和 LinkedHashMap 非常类似。
public class LRULinkedMap<K,V> {
/**
* 最大缓存大小
*/
private int cacheSize;
private LinkedHashMap<K,V> cacheMap ;
public LRULinkedMap(int cacheSize) {
this.cacheSize = cacheSize;
cacheMap = new LinkedHashMap(16,0.75F,true){
@Override
protected boolean removeEldestEntry(Map.Entry eldest) {
if (cacheSize + 1 == cacheMap.size()){
return true ;
}else {
return false ;
}
}
};
}
public void put(K key,V value){
cacheMap.put(key,value) ;
}
public V get(K key){
return cacheMap.get(key) ;
}
public Collection<Map.Entry<K, V>> getAll() {
return new ArrayList<Map.Entry<K, V>>(cacheMap.entrySet());
}
}
LinkedHashMap 内部也有维护一个双向队列,在初始化时也会给定一个缓存大小的阈值。初始化时自定义是否需要删除最近不常使用的数据,如果是则会按照实现二中的方式管理数据。
其实主要代码就是重写了 LinkedHashMap 的 removeEldestEntry 方法:
protected boolean removeEldestEntry(Map.Entry<K,V> eldest) {
return false;
}
它默认是返回 false,也就是不会管有没有超过阈值。
所以我们自定义大于了阈值时返回 true,这样 LinkedHashMap 就会帮我们删除最近最少使用的数据。