红黑树

每次看算法书,遇到红黑树,我总是不由自主地跳过。一来觉得工作中好像用不到,二来也是因为懒,不想深入去纠结。今天,终于下定决心好好地了解一下红黑树的设计思想。

BinarySearchTree(二叉查找树)

首先,我们还是从最简单的二叉查找树说起。简单来说,就是每个节点左子树中的元素都比它小、右子树中的元素都比它大(反过来同理)。那么二叉查找树的问题在哪里呢?如果你将1-10000顺序插入一棵二叉查找树,大概你就会发现问题了:每个新元素都会挂在最右侧,整棵树退化成一条深度为10000的链表。这种不平衡会造成查找深度的迅速增加,使查找效率从O(log n)退化为O(n)。

/**
 * An unbalanced binary search tree: for every node, elements in its left
 * subtree compare smaller and elements in its right subtree compare larger.
 * Inserting an element that is already present is a no-op.
 *
 * <p>No rebalancing is performed, so inserting keys in sorted order makes the
 * tree degenerate into a linked list with linear search depth.
 *
 * @param <T> element type, ordered by its natural ordering
 */
public class BinarySearchTree<T extends Comparable<? super T>> {

    /** Root node, or {@code null} when the tree is empty. */
    private BinaryNode<T> root;

    /** Creates an empty tree. */
    public BinarySearchTree() {
        this.root = null;
    }

    /** Removes every element from the tree. */
    public void makeEmpty() {
        root = null;
    }

    /**
     * @return {@code true} if the tree holds no elements
     */
    public boolean isEmpty() {
        return root == null;
    }

    /**
     * Tests membership.
     *
     * @param x the item to search for
     * @return {@code true} if an element comparing equal to {@code x} is stored
     */
    public boolean contain(T x) {
        return contain(x, root);
    }

    /**
     * @return the smallest element in the tree
     * @throws IllegalArgumentException if the tree is empty (type kept for
     *         compatibility with existing callers)
     */
    public T findMin() {
        if (isEmpty()) {
            throw new IllegalArgumentException("tree is empty");
        }
        return findMin(root).element;
    }

    /**
     * @return the largest element in the tree
     * @throws IllegalArgumentException if the tree is empty (type kept for
     *         compatibility with existing callers)
     */
    public T findMax() {
        if (isEmpty()) {
            throw new IllegalArgumentException("tree is empty");
        }
        return findMax(root).element;
    }

    /**
     * Inserts {@code x}; does nothing if an equal element is already present.
     *
     * @param x the item to insert
     */
    public void insert(T x) {
        root = insert(x, root);
    }

    /**
     * Removes the element equal to {@code x}; does nothing if it is absent.
     *
     * @param x the item to remove
     */
    public void remove(T x) {
        root = remove(x, root);
    }

    /**
     * Internal method to test whether an item occurs in a subtree.
     *
     * @param x the item to search for
     * @param t the node that roots the subtree (may be {@code null})
     * @return {@code true} if the subtree contains an element equal to {@code x}
     */
    private boolean contain(T x, BinaryNode<T> t) {
        // Iterative descent: same result as the recursive form, but a degenerate
        // (sorted-insert) tree cannot overflow the call stack here.
        BinaryNode<T> current = t;
        while (current != null) {
            int compareResult = x.compareTo(current.element);
            if (compareResult < 0) {
                current = current.left;
            } else if (compareResult > 0) {
                current = current.right;
            } else {
                return true;
            }
        }
        return false;
    }

    /**
     * Internal method to find the node holding the smallest item in a subtree.
     *
     * @param t the node that roots the subtree (may be {@code null})
     * @return the leftmost node, or {@code null} if {@code t} is {@code null}
     */
    private BinaryNode<T> findMin(BinaryNode<T> t) {
        if (t != null) {
            while (t.left != null) {
                t = t.left;
            }
        }
        return t;
    }

    /**
     * Internal method to find the node holding the largest item in a subtree.
     *
     * @param t the node that roots the subtree (may be {@code null})
     * @return the rightmost node, or {@code null} if {@code t} is {@code null}
     */
    private BinaryNode<T> findMax(BinaryNode<T> t) {
        if (t != null) {
            while (t.right != null) {
                t = t.right;
            }
        }
        return t;
    }

    /**
     * Internal method to insert into a subtree.
     *
     * @param x the item to insert
     * @param t the node that roots the subtree (may be {@code null})
     * @return the new root of the subtree
     */
    private BinaryNode<T> insert(T x, BinaryNode<T> t) {
        if (t == null) {
            return new BinaryNode<T>(x, null, null);
        }

        int compareResult = x.compareTo(t.element);

        if (compareResult < 0) {
            t.left = insert(x, t.left);
        } else if (compareResult > 0) {
            t.right = insert(x, t.right);
        }
        // compareResult == 0: duplicate, leave the tree unchanged.

        return t;
    }

    /**
     * Internal method to remove from a subtree.
     *
     * @param x the item to remove
     * @param t the node that roots the subtree (may be {@code null})
     * @return the new root of the subtree
     */
    private BinaryNode<T> remove(T x, BinaryNode<T> t) {
        if (t == null) {
            return t; // item not found
        }

        int compareResult = x.compareTo(t.element);

        if (compareResult < 0) {
            t.left = remove(x, t.left);
        } else if (compareResult > 0) {
            t.right = remove(x, t.right);
        } else if (t.left != null && t.right != null) {
            // Two children: copy in the in-order successor, then delete it from
            // the right subtree (it has no left child, so that removal is simple).
            t.element = findMin(t.right).element;
            t.right = remove(t.element, t.right);
        } else {
            // Zero or one child: splice the node out.
            t = t.left != null ? t.left : t.right;
        }
        return t;
    }

    /**
     * Node of the binary search tree.
     *
     * @param <T> element type
     */
    private static class BinaryNode<T> {
        private T element;
        private BinaryNode<T> left;
        private BinaryNode<T> right;

        public BinaryNode(T element) {
            this(element, null, null);
        }

        public BinaryNode(T element, BinaryNode<T> left, BinaryNode<T> right) {
            this.element = element;
            this.left = left;
            this.right = right;
        }
    }
}

AvlTree

接下来我们来看一种平衡二叉查找树——AVL Tree。这是一种高度平衡的树:任意节点左右两棵子树的高度差都不超过1。通过在插入(以及删除)之后进行旋转,我们可以始终保持树的平衡。
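旋转的四种情况(与代码对应):

  • rotateWithLeftChild —— LL(左左)情况:新节点插入在失衡节点左孩子的左子树中,对失衡节点做一次右旋。

  • rotateRightChild —— RR(右右)情况:新节点插入在失衡节点右孩子的右子树中,对失衡节点做一次左旋。

  • doubleWithLeftChild —— LR(左右)情况:新节点插入在失衡节点左孩子的右子树中,先对左孩子左旋,再对失衡节点右旋。

  • doubleWithRightChild —— RL(右左)情况:新节点插入在失衡节点右孩子的左子树中,先对右孩子右旋,再对失衡节点左旋。
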
/**
 * A self-balancing AVL tree: for every node, the heights of its two subtrees
 * differ by at most {@link #ALLOWED_IMBALANCE}. Inserting an element that is
 * already present is a no-op.
 *
 * <p>NOTE(review): the public {@code insert}/{@code remove}/{@code find}/
 * {@code min}/{@code max} methods return the private {@code AvlNode} type, so
 * callers outside this class can only compare the result against {@code null}.
 *
 * @param <T> element type, ordered by its natural ordering
 */
public class AvlTree<T extends Comparable<T>> {

    /** Maximum allowed height difference between the two subtrees of a node. */
    private static final int ALLOWED_IMBALANCE = 1;

    /** Root node, or {@code null} when the tree is empty. */
    private AvlNode<T> root;

    /** Creates an empty tree. */
    public AvlTree() {
        this.root = null;
    }

    /**
     * Returns the height of the tree.
     *
     * @return -1 for an empty tree, 0 for a single node
     */
    public int height() {
        return height(root);
    }

    /**
     * @return {@code true} if the tree holds no elements
     */
    public boolean isEmpty() {
        return root == null;
    }

    /**
     * Inserts {@code x} and rebalances; does nothing if an equal element exists.
     *
     * @param x the item to insert
     * @return the root node after insertion
     * @throws NullPointerException if {@code x} is null
     */
    public AvlNode<T> insert(T x) {
        if (x == null) {
            throw new NullPointerException("x");
        }
        root = insert(x, root);
        return root;
    }

    /**
     * Removes the element equal to {@code x} and rebalances; does nothing if it
     * is absent or {@code x} is null.
     *
     * @param x the item to remove
     * @return the root node after removal, or {@code null} if the tree is now empty
     */
    public AvlNode<T> remove(T x) {
        root = remove(x, root);
        return root;
    }

    /**
     * @param x the item to search for
     * @return the node holding an element equal to {@code x}, or {@code null}
     */
    public AvlNode<T> find(T x) {
        return find(x, root);
    }

    /**
     * @return the node holding the largest element, or {@code null} if empty
     */
    public AvlNode<T> max() {
        return findMax(root);
    }

    /**
     * @return the node holding the smallest element, or {@code null} if empty
     */
    public AvlNode<T> min() {
        return findMin(root);
    }

    /** Prints every element to standard output, one per line, in ascending order. */
    public void printTree() {
        printTree(root);
    }

    /**
     * Inserts into a subtree; equal keys are ignored.
     *
     * @param x the item to insert
     * @param t the node that roots the subtree (may be {@code null})
     * @return the new, rebalanced root of the subtree
     */
    private AvlNode<T> insert(T x, AvlNode<T> t) {
        if (t == null) {
            return new AvlNode<>(x);
        }

        int compareResult = x.compareTo(t.element);
        if (compareResult < 0) {
            t.left = insert(x, t.left);
        } else if (compareResult > 0) {
            t.right = insert(x, t.right);
        }
        // compareResult == 0: duplicate, nothing to insert.
        return balance(t);
    }

    /**
     * Removes from a subtree.
     *
     * @param x the item to remove (a {@code null} item leaves the subtree unchanged)
     * @param t the node that roots the subtree (may be {@code null})
     * @return the new, rebalanced root of the subtree
     */
    private AvlNode<T> remove(T x, AvlNode<T> t) {
        if (t == null || x == null) {
            return t;
        }

        int cmp = x.compareTo(t.element);
        if (cmp < 0) {
            t.left = remove(x, t.left);
        } else if (cmp > 0) {
            t.right = remove(x, t.right);
        } else if (t.left != null && t.right != null) {
            // Two children: replace with the predecessor or successor taken from
            // the taller side, then delete that element from the same side.
            if (height(t.left) > height(t.right)) {
                AvlNode<T> max = findMax(t.left);
                t.element = max.element;
                t.left = remove(max.element, t.left);
            } else {
                AvlNode<T> min = findMin(t.right);
                t.element = min.element;
                t.right = remove(min.element, t.right);
            }
        } else {
            // Zero or one child: splice the node out.
            t = t.left != null ? t.left : t.right;
        }
        return balance(t);
    }

    /**
     * Restores the AVL property at {@code t} (whose subtrees are already AVL
     * trees) and recomputes its height.
     *
     * @param t the node to rebalance (may be {@code null})
     * @return the new root of this subtree
     */
    private AvlNode<T> balance(AvlNode<T> t) {
        if (t == null) {
            return null;
        }

        if (height(t.left) - height(t.right) > ALLOWED_IMBALANCE) {
            // Left-heavy: LL -> single rotation, LR -> double rotation. ">=" picks
            // the single rotation on a tie, which can only arise after a removal.
            if (height(t.left.left) >= height(t.left.right)) {
                t = rotateWithLeftChild(t);
            } else {
                t = doubleWithLeftChild(t);
            }
        } else if (height(t.right) - height(t.left) > ALLOWED_IMBALANCE) {
            // Right-heavy: RR -> single rotation, RL -> double rotation.
            if (height(t.right.right) >= height(t.right.left)) {
                t = rotateRightChild(t);
            } else {
                t = doubleWithRightChild(t);
            }
        }

        t.height = Math.max(height(t.left), height(t.right)) + 1;
        return t;
    }

    /**
     * @return the stored height of {@code t}, or -1 for {@code null}
     */
    private int height(AvlNode<T> t) {
        return t == null ? -1 : t.height;
    }

    /**
     * @return the leftmost node of the subtree, or {@code null} if {@code t} is null
     */
    private AvlNode<T> findMin(AvlNode<T> t) {
        if (t == null) {
            return null;
        }
        while (t.left != null) {
            t = t.left;
        }
        return t;
    }

    /**
     * @return the rightmost node of the subtree, or {@code null} if {@code t} is null
     */
    private AvlNode<T> findMax(AvlNode<T> t) {
        if (t == null) {
            return null;
        }
        while (t.right != null) {
            t = t.right;
        }
        return t;
    }

    /**
     * Iterative lookup.
     *
     * @param x the item to search for
     * @param t the node that roots the subtree
     * @return the matching node, or {@code null} if there is no match
     */
    private AvlNode<T> find(T x, AvlNode<T> t) {
        while (t != null) {
            int cmp = x.compareTo(t.element);
            if (cmp < 0) {
                t = t.left;
            } else if (cmp > 0) {
                t = t.right;
            } else {
                return t; // match
            }
        }
        return null; // no match
    }

    /** In-order traversal: prints the subtree's elements in ascending order. */
    private void printTree(AvlNode<T> t) {
        if (t != null) {
            printTree(t.left);
            System.out.println(t.element);
            printTree(t.right);
        }
    }

    /**
     * LL case: single right rotation around {@code k2} using its left child.
     *
     * @param k2 the unbalanced node
     * @return the new subtree root (the former left child)
     */
    private AvlNode<T> rotateWithLeftChild(AvlNode<T> k2) {
        AvlNode<T> k1 = k2.left;
        k2.left = k1.right;
        k1.right = k2;
        k2.height = Math.max(height(k2.left), height(k2.right)) + 1;
        k1.height = Math.max(height(k1.left), k2.height) + 1;
        return k1;
    }

    /**
     * RR case: single left rotation around {@code k1} using its right child.
     *
     * @param k1 the unbalanced node
     * @return the new subtree root (the former right child)
     */
    private AvlNode<T> rotateRightChild(AvlNode<T> k1) {
        AvlNode<T> k2 = k1.right;
        k1.right = k2.left;
        k2.left = k1;
        k1.height = Math.max(height(k1.left), height(k1.right)) + 1;
        k2.height = Math.max(height(k2.right), k1.height) + 1;
        return k2;
    }

    /**
     * LR case: left-rotate the left child, then right-rotate {@code k3}.
     *
     * @param k3 the unbalanced node
     * @return the new subtree root
     */
    private AvlNode<T> doubleWithLeftChild(AvlNode<T> k3) {
        k3.left = rotateRightChild(k3.left);
        return rotateWithLeftChild(k3);
    }

    /**
     * RL case: right-rotate the right child, then left-rotate {@code k2}.
     *
     * @param k2 the unbalanced node
     * @return the new subtree root
     */
    private AvlNode<T> doubleWithRightChild(AvlNode<T> k2) {
        k2.right = rotateWithLeftChild(k2.right);
        return rotateRightChild(k2);
    }

    /**
     * Tree node; {@code height} is 0 for a leaf.
     *
     * @param <T> element type
     */
    private static class AvlNode<T> {
        private T element;
        private AvlNode<T> left;
        private AvlNode<T> right;
        private int height;

        public AvlNode(T element) {
            this(element, null, null);
        }

        public AvlNode(T element, AvlNode<T> left, AvlNode<T> right) {
            this.element = element;
            this.left = left;
            this.right = right;
            this.height = 0;
        }
    }
}

红黑树

AVL Tree的结构比RB-Tree更加严格地平衡,因此插入和删除node时更容易破坏它的平衡条件。在大量数据需要插入或删除时,AVL需要rebalance的频率也就更高;其中remove尤其昂贵,可能需要沿着查找路径一路旋转到根。所以在需要大量插入和删除node的场景下,RB-Tree的效率更高。反过来,由于AVL更加平衡、树高更低,AVL的search效率更高。
在学习RB-Tree之前,我们还要来看一看它和另一种树——2-3 Tree——的对应关系。


2-3 Tree

红链把两个2节点连接在一起,共同表示2-3树中的一个3节点;黑链则对应2-3树中的普通链接。把红链"放平"之后,红黑树恰好就是一棵完美平衡的2-3树。这样一看,瞬间就清晰明了了。
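下面实现的是左倾红黑树(Left-Leaning Red-Black Tree,所有红链都向左倾斜),它只需要两种旋转:

  • rotateLeft —— 左旋:把一条右倾的红链转为左倾。

  • rotateRight —— 右旋:把一条左倾的红链转为右倾(例如出现两条连续的左倾红链时)。
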
/**
 * A left-leaning red-black tree (LLRB).
 *
 * <p>The tree is a binary search tree that encodes a 2-3 tree: a red link
 * joins a node to its parent so the pair represents one 2-3 "3-node", while
 * black links are ordinary 2-3 links. The operations below keep red links
 * leaning left, never attach two red links to one node, and keep the number
 * of black links equal on every root-to-null path.
 *
 * <p>Putting an element that compares equal to a stored one replaces it.
 * Deleting an absent element, or deleting from an empty tree, is a no-op.
 *
 * @param <T> element type, ordered by its natural ordering
 */
public class RedBlackTree<T extends Comparable<T>> {

    /** Color of the link from a node's parent to that node. */
    private static final boolean RED = true;
    private static final boolean BLACK = false;

    /** Root node, or {@code null} when the tree is empty. */
    private Node<T> root;

    /** Removes the smallest element; does nothing if the tree is empty. */
    public void deleteMin() {
        if (isEmpty()) {
            return;
        }
        // With two black children, temporarily make the root red so that the
        // color flips on the way down have a red link to push into the subtree.
        if (!isRed(root.left) && !isRed(root.right)) {
            root.color = RED;
        }
        root = deleteMin(root);
        if (!isEmpty()) {
            root.color = BLACK;
        }
    }

    /** Removes the largest element; does nothing if the tree is empty. */
    public void deleteMax() {
        if (isEmpty()) {
            return;
        }
        if (!isRed(root.left) && !isRed(root.right)) {
            root.color = RED;
        }
        root = deleteMax(root);
        if (!isEmpty()) {
            root.color = BLACK;
        }
    }

    /**
     * Inserts {@code x}, replacing an equal element if one is already stored.
     *
     * @param x the item to insert
     * @throws NullPointerException if {@code x} is null
     */
    public void put(T x) {
        if (x == null) {
            throw new NullPointerException("x");
        }
        root = put(x, root);
        root.color = BLACK; // the root is always black
    }

    /**
     * Removes the element equal to {@code x}; does nothing if it is absent.
     *
     * @param x the item to remove
     */
    public void delete(T x) {
        // The recursive delete assumes the key is present; without this guard it
        // would walk off a null link for a missing key.
        if (!contains(x)) {
            return;
        }
        if (!isRed(root.left) && !isRed(root.right)) {
            root.color = RED;
        }
        root = delete(x, root);
        if (!isEmpty()) {
            root.color = BLACK;
        }
    }

    /**
     * @param x the item to search for
     * @return {@code true} if an element equal to {@code x} is stored;
     *         {@code false} for a {@code null} argument
     */
    public boolean contains(T x) {
        return x != null && find(x, root) != null;
    }

    /**
     * @return the number of stored elements
     */
    public int size() {
        return size(root);
    }

    /**
     * @return {@code true} if the tree holds no elements
     */
    public boolean isEmpty() {
        return root == null;
    }

    /**
     * Whether the link pointing to {@code h} is red.
     *
     * @param h a node (may be {@code null}; null links count as black)
     * @return {@code true} if {@code h} is non-null and red
     */
    private boolean isRed(Node<T> h) {
        return h != null && h.color == RED;
    }

    /**
     * Null-safe subtree size.
     *
     * @return the number of nodes rooted at {@code h}, 0 for {@code null}
     */
    private int size(Node<T> h) {
        return h == null ? 0 : h.size;
    }

    /**
     * Inverts the colors of {@code h} and both of its children. Insertion uses
     * this to split a node with two red children; deletion uses it to borrow
     * a red link from the parent. Both children must be non-null.
     */
    private void flipColors(Node<T> h) {
        h.color = !h.color;
        h.left.color = !h.left.color;
        h.right.color = !h.right.color;
    }

    /**
     * Inserts into the subtree rooted at {@code h} and fixes up the colors on
     * the way back up.
     *
     * @param t the item to insert
     * @param h the node that roots the subtree (may be {@code null})
     * @return the new root of the subtree
     */
    private Node<T> put(T t, Node<T> h) {
        if (h == null) {
            return new Node<>(t, 1, RED); // new nodes are attached with a red link
        }
        int cmp = t.compareTo(h.element);
        if (cmp < 0) {
            h.left = put(t, h.left);
        } else if (cmp > 0) {
            h.right = put(t, h.right);
        } else {
            h.element = t;
        }

        if (isRed(h.right) && !isRed(h.left)) {
            h = rotateLeft(h);          // lean a right red link to the left
        }
        if (isRed(h.left) && isRed(h.left.left)) {
            h = rotateRight(h);         // two left reds in a row
        }
        if (isRed(h.left) && isRed(h.right)) {
            flipColors(h);              // split into two black children
        }

        h.size = size(h.left) + size(h.right) + 1;
        return h;
    }

    /**
     * Iterative lookup.
     *
     * @param x the item to search for (non-null)
     * @param h the node that roots the subtree
     * @return the matching node, or {@code null} if there is none
     */
    private Node<T> find(T x, Node<T> h) {
        Node<T> current = h;
        while (current != null) {
            int cmp = x.compareTo(current.element);
            if (cmp < 0) {
                current = current.left;
            } else if (cmp > 0) {
                current = current.right;
            } else {
                return current;
            }
        }
        return null;
    }

    /**
     * Deletes {@code x} from the subtree rooted at {@code h}. Precondition:
     * {@code x} is present in the subtree.
     *
     * @return the new root of the subtree
     */
    private Node<T> delete(T x, Node<T> h) {
        if (x.compareTo(h.element) < 0) {
            if (!isRed(h.left) && !isRed(h.left.left)) {
                h = moveRedLeft(h);
            }
            h.left = delete(x, h.left);
        } else {
            if (isRed(h.left)) {
                h = rotateRight(h);
            }
            if (x.compareTo(h.element) == 0 && h.right == null) {
                return null;
            }
            if (!isRed(h.right) && !isRed(h.right.left)) {
                h = moveRedRight(h);
            }
            if (x.compareTo(h.element) == 0) {
                // Replace with the in-order successor, then delete the successor.
                Node<T> successor = min(h.right);
                h.element = successor.element;
                h.right = deleteMin(h.right);
            } else {
                h.right = delete(x, h.right);
            }
        }
        return balance(h);
    }

    /**
     * Deletes the smallest node of a non-empty subtree.
     *
     * @return the new root of the subtree
     */
    private Node<T> deleteMin(Node<T> h) {
        if (h.left == null) {
            return null;
        }
        if (!isRed(h.left) && !isRed(h.left.left)) {
            h = moveRedLeft(h);
        }
        h.left = deleteMin(h.left);
        return balance(h);
    }

    /**
     * Deletes the largest node of a non-empty subtree.
     *
     * @return the new root of the subtree
     */
    private Node<T> deleteMax(Node<T> h) {
        if (isRed(h.left)) {
            h = rotateRight(h);
        }
        if (h.right == null) {
            return null;
        }
        if (!isRed(h.right) && !isRed(h.right.left)) {
            h = moveRedRight(h);
        }
        h.right = deleteMax(h.right);
        return balance(h);
    }

    /** @return the leftmost node of a non-empty subtree */
    private Node<T> min(Node<T> h) {
        while (h.left != null) {
            h = h.left;
        }
        return h;
    }

    /** @return the rightmost node of a non-empty subtree */
    private Node<T> max(Node<T> h) {
        while (h.right != null) {
            h = h.right;
        }
        return h;
    }

    /**
     * Makes {@code h.right} or one of its children red, so that deletion can
     * continue into the right subtree.
     */
    private Node<T> moveRedRight(Node<T> h) {
        flipColors(h);
        if (isRed(h.left.left)) {
            h = rotateRight(h);
            flipColors(h);
        }
        return h;
    }

    /**
     * Makes {@code h.left} or one of its children red, so that deletion can
     * continue into the left subtree.
     */
    private Node<T> moveRedLeft(Node<T> h) {
        flipColors(h);
        if (isRed(h.right.left)) {
            h.right = rotateRight(h.right);
            h = rotateLeft(h);
            flipColors(h);
        }
        return h;
    }

    /**
     * Rotates the red link between {@code h} and its right child to the left.
     *
     * @return the new subtree root (the former right child)
     */
    private Node<T> rotateLeft(Node<T> h) {
        Node<T> x = h.right;
        h.right = x.left;
        x.left = h;
        x.color = h.color;
        h.color = RED;
        x.size = h.size;
        h.size = 1 + size(h.left) + size(h.right);
        return x;
    }

    /**
     * Rotates the red link between {@code h} and its left child to the right.
     *
     * @return the new subtree root (the former left child)
     */
    private Node<T> rotateRight(Node<T> h) {
        Node<T> x = h.left;
        h.left = x.right;
        x.right = h;
        x.color = h.color;
        h.color = RED;
        x.size = h.size;
        h.size = 1 + size(h.left) + size(h.right);
        return x;
    }

    /**
     * Restores the left-leaning red-black invariants at {@code h} on the way
     * back up from a deletion and recomputes its size.
     */
    private Node<T> balance(Node<T> h) {
        if (isRed(h.right) && !isRed(h.left)) {
            h = rotateLeft(h);
        }
        if (isRed(h.left) && isRed(h.left.left)) {
            h = rotateRight(h);
        }
        if (isRed(h.left) && isRed(h.right)) {
            flipColors(h);
        }
        h.size = size(h.left) + size(h.right) + 1;
        return h;
    }

    /**
     * Tree node. {@code color} is the color of the link from the parent;
     * {@code size} counts the nodes in the subtree rooted here.
     *
     * @param <T> element type
     */
    private static class Node<T> {
        private T element;
        private Node<T> left;
        private Node<T> right;
        private int size;
        private boolean color;

        public Node(T element, int size, boolean color) {
            this.element = element;
            this.size = size;
            this.color = color;
        }
    }
}
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容