线段树Segment Tree

  • 线段树的构建
  • 查询
  • 更新
  1. 指针实现
public class SegmentTreeNode {
    public int start;
    public int end;
    public int max;
    public SegmentTreeNode left;
    public SegmentTreeNode right;

    public SegmentTreeNode(int start, int end, int max) {
        this.start = start;
        this.end = end;
        this.max = max;
        this.left = null;
        this.right = null;
    }

    //构建线段树
    public SegmentTreeNode build(int[] nums) {
        return buildHelper(0, nums.length - 1, nums);
    }

    public SegmentTreeNode buildHelper(int left, int right, int[] nums) {
        if (left > right) {
            return null;
        }
        //节点区间的值为其包含的元素的最大值,取左边界
        SegmentTreeNode root = new SegmentTreeNode(left, right, nums[left]);
        if (left == right) {
            //递归终止条件
            //如果只有一个元素,则其左右子节点为空,节点值为左边界值
            return root;
        }
        int mid = left + (right - left) / 2;
        root.left = buildHelper(left, mid, nums);
        root.right = buildHelper(mid + 1, right, nums);
        //根据左右子节点的max更新当前节点的max
        root.max = Math.max(root.left.max, root.right.max);
        return root;
    }

    //查询区间[left, right]的最大值
    public int query(SegmentTreeNode root, int start, int end) {
        if (root == null || start > root.end || end < root.start) {
            return Integer.MIN_VALUE;
        }
        //如果查询区间包含当前节点区间,则当前节点即为所求
        if (start <= root.start && root.end <= end) {
            return root.max;
        }
        //递归查询左右子节点
        int mid = root.start + (root.end - root.start) / 2;
        int leftMax = query(root.left, start, end);
        int rightMax = query(root.right, start, end);
        return Math.max(leftMax, rightMax);
    }
    //更新节点值 - 将索引index的值修改为val
    public void modify(SegmentTreeNode root, int index, int val) {
        if (root.start == root.end && root.start == index) {
            root.max = val;
            return;
        }
        //将当前区间分割为左右两个区间,mid为分割线
        int mid = root.start + (root.end - root.start)/ 2;
        //判断index落在哪个区间
        if (index <= mid) {
            //index在左子区间,递归更新左子区间
            modify(root.left, index, val);
            root.max = Math.max(root.left.max, root.right.max);
        } else {
            //index在右子区间,递归更新右子区间
            modify(root.right, index, val);
            root.max = Math.max(root.left.max, root.right.max);
        }
    }

}

  1. 数组实现 - 灵感来自正月点灯笼:
public class SegTree {
    private int[] data;
    private int[] tree;
    public SegTree(int[] nums) {
        this.data = nums;
        this.tree = new int[nums.length * 4];
        build(0, 0, nums.length - 1);
    }

    public void build(int node, int start, int end) {
        if (start == end) {
            tree[node] = data[start];
            return;
        }
        // 计算出当前节点的区间范围
        int mid = start + (end - start) / 2;
        int leftNode =  2 * node + 1;
        int rightNode = 2 * node + 2;
        build(leftNode, start, mid);
        build(rightNode, mid + 1, end);
        tree[node] = tree[leftNode] + tree[rightNode];
    }

    public void update(int node, int start, int end, int index, int val) {
        if (start == end){
            tree[node] = val;
            data[index] = val;
            return;
        }
        // 计算出当前节点的区间范围
        int mid = start + (end - start) / 2;
        int leftNode = 2 * node + 1;
        int rightNode = 2 * node + 2;
        if (index >= start && index <= mid) {
            update(leftNode, start, mid, index, val);
        } else {
            update(rightNode, mid + 1, end, index, val);
        }
        tree[node] = tree[leftNode] + tree[rightNode];
    }

    public int query(int node, int start, int end, int l, int r) {
        //node:节点编号;start,end:节点表示的区间;l,r:要查询的区间
        System.out.println("start: " +  start + " end: " +  end);
        if (start > r || end < l) {
            return 0;
        }
        if (l <= start && end <= r) {
            return tree[node];
        }
        int mid = start + (end - start) / 2;
        int leftNode =  2 * node + 1;
        int rightNode = 2 * node + 2;
        int sumLeft = query(leftNode, start, mid, l, r);
        int sumRight = query(rightNode, mid + 1, end, l, r);
        return sumLeft + sumRight;
    }

    public static void main(String[] args) {
        int[] nums = {1, 3, 5, 7, 9, 11};
        SegTree segTree = new SegTree(nums);
        for (int i = 0; i < segTree.tree.length; i++) {
            System.out.println(i + " - " + segTree.tree[i]);
        }

        segTree.update(0, 0, nums.length - 1, 0, 10);
        System.out.println("Update 0 to 10:=======");
        for (int i = 0; i < segTree.tree.length; i++) {
            System.out.println(i + " - " + segTree.tree[i]);
        }

        System.out.println("Query 0 to 2:=======");
        int query = segTree.query(0, 0, nums.length - 1, 0, 2);
        System.out.println(query);
    }
}


相关链接:
指针实现:构建、搜索、更新
https://baijiahao.baidu.com/s?id=1736339086704827934&wfr=spider&for=pc
数组实现
https://blog.csdn.net/myRealization/article/details/105130003
https://www.bilibili.com/video/BV1cb411t7AM/?spm_id_from=333.1007.top_right_bar_window_history.content.click&vd_source=ed89f81ec70f5a5933f8a8a3b71dbcc0

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容