数据结构(14)二叉堆

Heap.java


public interface Heap<E> {
    int size(); // 元素的数量
    boolean isEmpty();  // 是否为空
    void clear();   // 清空
    void add(E element);     // 添加元素
    E get();    // 获得堆顶元素
    E remove(); // 删除堆顶元素
    E replace(E element); // 删除堆顶元素的同时插入一个新元素
}

AbstractHeap.java

@SuppressWarnings("unchecked")
public abstract class AbstractHeap<E> implements Heap<E> {
    protected int size;
    protected Comparator<E> comparator;
    
    public AbstractHeap(Comparator<E> comparator) {
        this.comparator = comparator;
    }
    
    public AbstractHeap() {
        this(null);
    }
    
    @Override
    public int size() {
        return size;
    }

    @Override
    public boolean isEmpty() {
        return size == 0;
    }
    
    protected int compare(E e1, E e2) {
        return comparator != null ? comparator.compare(e1, e2) 
                : ((Comparable<E>)e1).compareTo(e2);
    }
}

BinaryHeap.java

/**
 * 二叉堆(最大堆)
 *
 * @param <E>
 */
@SuppressWarnings("unchecked")
public class BinaryHeap<E> extends AbstractHeap<E> implements BinaryTreeInfo {
    private E[] elements;
    private static final int DEFAULT_CAPACITY = 10;
    
    public BinaryHeap(E[] elements, Comparator<E> comparator)  {
        super(comparator);
        
        if (elements == null || elements.length == 0) {
            this.elements = (E[]) new Object[DEFAULT_CAPACITY];
        } else {
            size = elements.length;
            int capacity = Math.max(elements.length, DEFAULT_CAPACITY);
            this.elements = (E[]) new Object[capacity];
            for (int i = 0; i < elements.length; i++) {
                this.elements[i] = elements[i];
            }
            heapify();
        }
    }
    
    public BinaryHeap(E[] elements)  {
        this(elements, null);
    }
    
    public BinaryHeap(Comparator<E> comparator) {
        this(null, comparator);
    }
    
    public BinaryHeap() {
        this(null, null);
    }

    @Override
    public void clear() {
        for (int i = 0; i < size; i++) {
            elements[i] = null;
        }
        size = 0;
    }

    @Override
    public void add(E element) {
        elementNotNullCheck(element);
        ensureCapacity(size + 1);
        elements[size++] = element;
        siftUp(size - 1);
    }

    @Override
    public E get() {
        emptyCheck();
        return elements[0];
    }

    @Override
    public E remove() {
        emptyCheck();
        
        int lastIndex = --size;
        E root = elements[0];
        elements[0] = elements[lastIndex];
        elements[lastIndex] = null;
        
        siftDown(0);
        return root;
    }

    @Override
    public E replace(E element) {
        elementNotNullCheck(element);
        
        E root = null;
        if (size == 0) {
            elements[0] = element;
            size++;
        } else {
            root = elements[0];
            elements[0] = element;
            siftDown(0);
        }
        return root;
    }
    
    /**
     * 批量建堆
     */
    private void heapify() {
        // 自上而下的上滤
//      for (int i = 1; i < size; i++) {
//          siftUp(i);
//      }
        
        // 自下而上的下滤
        for (int i = (size >> 1) - 1; i >= 0; i--) {
            siftDown(i);
        }
    }
    
    /**
     * 让index位置的元素下滤
     * @param index
     */
    private void siftDown(int index) {
        E element = elements[index];
        int half = size >> 1;
        // 第一个叶子节点的索引 == 非叶子节点的数量
        // index < 第一个叶子节点的索引
        // 必须保证index位置是非叶子节点
        while (index < half) { 
            // index的节点有2种情况
            // 1.只有左子节点
            // 2.同时有左右子节点
            
            // 默认为左子节点跟它进行比较
            int childIndex = (index << 1) + 1;
            E child = elements[childIndex];
            
            // 右子节点
            int rightIndex = childIndex + 1;
            
            // 选出左右子节点最大的那个
            if (rightIndex < size && compare(elements[rightIndex], child) > 0) {
                child = elements[childIndex = rightIndex];
            }
            
            if (compare(element, child) >= 0) break;

            // 将子节点存放到index位置
            elements[index] = child;
            // 重新设置index
            index = childIndex;
        }
        elements[index] = element;
    }
    
    /**
     * 让index位置的元素上滤
     * @param index
     */
    private void siftUp(int index) {
//      E e = elements[index];
//      while (index > 0) {
//          int pindex = (index - 1) >> 1;
//          E p = elements[pindex];
//          if (compare(e, p) <= 0) return;
//          
//          // 交换index、pindex位置的内容
//          E tmp = elements[index];
//          elements[index] = elements[pindex];
//          elements[pindex] = tmp;
//          
//          // 重新赋值index
//          index = pindex;
//      }
        E element = elements[index];
        while (index > 0) {
            int parentIndex = (index - 1) >> 1;
            E parent = elements[parentIndex];
            if (compare(element, parent) <= 0) break;
            
            // 将父元素存储在index位置
            elements[index] = parent;
            
            // 重新赋值index
            index = parentIndex;
        }
        elements[index] = element;
    }
    
    private void ensureCapacity(int capacity) {
        int oldCapacity = elements.length;
        if (oldCapacity >= capacity) return;
        
        // 新容量为旧容量的1.5倍
        int newCapacity = oldCapacity + (oldCapacity >> 1);
        E[] newElements = (E[]) new Object[newCapacity];
        for (int i = 0; i < size; i++) {
            newElements[i] = elements[i];
        }
        elements = newElements;
    }
    
    private void emptyCheck() {
        if (size == 0) {
            throw new IndexOutOfBoundsException("Heap is empty");
        }
    }
    
    private void elementNotNullCheck(E element) {
        if (element == null) {
            throw new IllegalArgumentException("element must not be null");
        }
    }

    @Override
    public Object root() {
        return 0;
    }

    @Override
    public Object left(Object node) {
        int index = ((int)node << 1) + 1;
        return index >= size ? null : index;
    }

    @Override
    public Object right(Object node) {
        int index = ((int)node << 1) + 2;
        return index >= size ? null : index;
    }

    @Override
    public Object string(Object node) {
        return elements[(int)node];
    }
}

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

友情链接更多精彩内容