题目描述
设计一个找到数据流中第K大元素的类(class)。注意是排序后的第K大元素,不是第K个不同的元素。
你的 KthLargest
类需要一个同时接收整数 k
和整数数组nums
的构造器,它包含数据流中的初始元素。每次调用 KthLargest.add
,返回当前数据流中第K大的元素。
示例:
int k = 3;
int[] arr = [4,5,8,2];
KthLargest kthLargest = new KthLargest(3, arr);
kthLargest.add(3); // returns 4
kthLargest.add(5); // returns 5
kthLargest.add(10); // returns 5
kthLargest.add(9); // returns 8
kthLargest.add(4); // returns 8
说明:
可以假设 nums
的长度≥ k-1
且k
≥ 1。
解
这道题被打上的标签是堆的一种数据结构,那么就用堆的思路去解决这道题。
堆的性质有两种:
- 左右节点的元素值不能大于父节点或者不能小于父节点的元素值;
- 堆是一种完全二叉树
先了解堆
上面的数据结构是一种二叉堆,更多的堆可以是三叉堆甚至d叉堆。。。
- 最小堆:每个父节点都比左右节点的元素值要小
- 最大堆:每个父节点都比左右节点的元素值要大
移除某个节点和堆顶一样的,判断条件就是左右节点有没有比父节点的元素更小的值。
添加节点也是一样的,因为堆是一种完全二叉树,底层采用数组的形式,添加节点加在数组的末尾,然后和父节点比较,如果比父节点小的则交换,直到堆顶。如果没有比父节点小的则停止交换。
过程
题目要求是找到排序后的第K大个元素值,可以使用只有K长度的最小堆,堆顶的元素值则是最小的。
添加元素并返回堆顶元素值
kthLargest.add(5)
kthLargest.add(10)
最终代码如下:
class KthLargest {
private int k;
MinHeap<Integer> arr;
public KthLargest(int k, int[] nums) {
this.k = k;
// 只有k长度的最小堆
arr = new MinHeap<>(k);
for (int num : nums) {
add(num);
}
}
public int add(int val) {
arr.heapify(val);
return arr.peek();
}
/**
* 自定义最小堆,底层采用数组数据结构
* E 数据类型需要继承Comparable,具有数值比较的功能
*/
private class MinHeap<E extends Comparable<E>> {
private Array<E> data;
public MinHeap() {
data = new Array<>();
}
public MinHeap(int capacity) {
data = new Array<>(capacity);
}
/**
* 0
* 1 2
* 3 4 5 6
* 父节点 i
* 左节点 2 * i + 1
* 右节点 2 * i + 2
*
* @param index
* @return 返回index父节点的索引
*/
private int parent(int index) {
if (index == 0) {
throw new IllegalArgumentException("索引为0没有父节点");
}
return (index - 1) / 2;
}
private int leftChild(int index) {
return index * 2 + 1;
}
public void add(E e) {
// 使用数组数据结构添加元素置于末尾
data.add(e);
siftUp(data.getSize() - 1);
}
/**
* 最小堆的性质:左右节点的元素值不能大于父节点的元素值
*
* @param index 当前索引
*/
private void siftUp(int index) {
// 当前节点与父结点进行比较,当前节点比父节点小则进行交换
while (index > 0 && data.get(parent(index)).compareTo(data.get(index)) > 0) {
data.swap(index, parent(index));
index = parent(index);
}
}
/**
* 这里和Java中的PriorityQueue类 heapify 方法实现同样的效果
* ★★★★★
* 根据题意
* 如果待比较的元素比最小堆堆顶大,则替换
*
* @param e 待比较的元素值
*/
public void heapify(E e) {
if (peek() == null) {
data.add(e);
return;
} else if (data.getSize() < k) {
data.add(e);
siftUp(data.getSize() - 1);
} else if (peek().compareTo(e) < 0) {
data.set(e);
// 从堆顶开始
siftDown(0);
}
}
/**
* @param index
*/
private void siftDown(int index) {
// while循环,依据条件是是否存在左孩子
while (leftChild(index) < data.getSize()) {
int j = leftChild(index);
// 寻找左右孩子的最小的元素值
if (j + 1 < data.getSize() && data.get(j).compareTo(data.get(j + 1)) > 0) {
// 存在右孩子而且右孩子的元素值比左孩子的元素值更小
j++;
}
if (data.get(index).compareTo(data.get(j)) < 0) {
break;
}
// 运行到此处说明存在左右孩子的元素值比父节点更小
data.swap(index, j);
index = j;
}
}
/**
* @return 返回顶节点
*/
public E peek() {
if (data.getSize() == 0) {
return null;
}
return data.get(0);
}
public int getSize() {
return data.getSize();
}
public boolean isEmpty() {
// return getSize() == 0 ? true: false;
return data.isEmpty();
}
}
/**
* 自定义数组,数据类型为E
*/
private class Array<E> {
private E[] data;
private int size;
public Array(int capacity) {
data = (E[]) new Object[capacity];
size = 0;
}
public Array() {
// 数组长度默认为10
this(10);
}
/**
* @param index 当前索引
* @return 返回元素值
*/
public E get(int index) {
if (index < 0 || index > size - 1) {
throw new IllegalArgumentException("数组越界");
}
return data[index];
}
public void set(E e) {
set(0, e);
}
public void set(int index, E e) {
data[index] = e;
}
/**
* 添加元素
*
* @param index 索引
* @param e 元素值
*/
public void add(int index, E e) {
// 为什么index > size 而不是 index > size - 1
// 是因为在末位添加数据的时候刚好是在size上面,所以index最大为size
if (index < 0 || index > size) {
throw new IllegalArgumentException("index索引不在数组范围内");
}
// 扩容
if (size == data.length) {
resize(2 * data.length);
}
// 将index索引后面的数值往后移一位
for (int i = size; i > index; i--) {
data[i] = data[i - 1];
}
data[index] = e;
size++;
}
/**
* 末尾添加元素
*
* @param e 元素值
*/
public void add(E e) {
add(size, e);
}
/**
* 删除元素
*
* @param index 索引
* @return 返回被删除的元素
*/
public E remove(int index) {
if (index < 0 || index > size - 1) {
throw new IllegalArgumentException("index索引不在数组范围内");
}
E ret = data[index];
// index索引后的元素往前移动一位
for (int i = index; i < size - 1; i++) {
data[index] = data[index + 1];
}
size--;
data[size - 1] = null;
// 缩容,而且数组为1的时候就不能再除2 ,1/2 = 0,长度为0 的数组就不能扩容
if (size == data.length / 4 && data.length / 2 != 0) {
resize(data.length / 2);
}
return ret;
}
/**
* 输出末尾元素
*
* @return 返回被删除的元素
*/
public E remove() {
return remove(size - 1);
}
private void resize(int newCapacity) {
E[] newData = (E[]) new Object[newCapacity];
for (int i = 0; i < size; i++) {
newData[i] = data[i];
}
data = newData;
}
/**
* @return 返回数组的长度
*/
public int getSize() {
return size;
}
public boolean isEmpty() {
return size == 0 ? true : false;
}
public void swap(int index1, int index2) {
E tmp = data[index1];
data[index1] = data[index2];
data[index2] = tmp;
}
}
}