一、什么是线段树
线段树是一种高级数据结构,可以用于解决区间的查询问题,比如可以查询最大值、最小值、最大公约数、最小公倍数等。线段树的树高为O(logn),所以查询和更新操作都是O(logn)。
二、线段树的特性
线段树的叶子结点都是具体的值,且是长度为1的区间,代表的是该闭合区间范围的聚合信息(如最小值、最大值、最小公约数等)。
若i不是叶子节点,则其左子树索引为2i+1,右子树索引为2i+2
。
由满二叉树的性质:
1、最后一层节点数为2^{h-1}
2、树所有节点总数为2^{h}-1
3、树的所有节点数约等于最后一层节点数和其他层节点数之和
所以,
1、当叶子结点个数(即,len(nums)=n)为2的k次方时,可以推断此时二叉树为满二叉树,对于nums来说,初始时就需要申请2n的空间来存储线段树。例如len(nums)=8,即A[0-7],按照图画出二叉树为满二叉树,需要16的长度空间(1+2+4+8=15)。
2、当叶子节点个数(即,len(nums)=n)为2的k次方+1时,就需要2n+2n=4n的空间存储线段树。例如len(nums)=10,如图中情况,就需要40的长度空间(1+2+4+8+16=31,最小的n的倍数是40)。
使用数组实现,空间复杂度为O(n)。
三、线段树的操作
线段树的操作分为3个步骤:
1、从给定的数组构建线段树
使用自下而上的方式构建线段树,先创建孩子节点,根据业务逻辑(题目需求,如求区间和、最大值等)创建两个孩子的根节点。
2、修改元素,更新线段树
如果更新的元素在左半区间,就递归更新左孩子;如果在右半区间,同理。更新之后,里面的小区间值发生变化了,外面的大区间节点也要根据业务逻辑(题目需求)进行更新。
3、使用线段树对区间进行检索
查询的区间可能在左半区间(或是右半区间),那就递归查询左孩子(右孩子)。如果一部分在左边,一部分在右边,则递归查询完左区间和右区间之后,再根据业务逻辑(题目需求)进行合并处理。
四、多种实现
业务逻辑是实现区间和
实现1: 开发merge接口传入业务逻辑
Python实现
class SegmentTree:
def __init__(self, data, merge):
'''
data:传入的数组
merge:处理的业务逻辑,例如求和/最大值/最小值,lambda表达式
'''
self.data = data
self.n = len(data)
# 申请4倍data长度的空间来存线段树节点
self.tree = [None] * (4 * self.n) # 索引i的左孩子索引为2i+1,右孩子为2i+2
self._merge = merge
if self.n:
self._build(0, 0, self.n-1)
def query(self, ql, qr):
'''
返回区间[ql,..,qr]的值
'''
return self._query(0, 0, self.n-1, ql, qr)
def update(self, index, value):
# 将data数组index位置的值更新为value,然后递归更新线段树中被影响的各节点的值
self.data[index] = value
self._update(0, 0, self.n-1, index)
def _build(self, tree_index, l, r):
'''
递归创建线段树
tree_index : 线段树节点在数组中位置
l, r : 该节点表示的区间的左,右边界
'''
if l == r:
self.tree[tree_index] = self.data[l]
return
mid = (l+r) // 2 # 区间中点,对应左孩子区间结束,右孩子区间开头
left, right = 2 * tree_index + 1, 2 * tree_index + 2 # tree_index的左右子树索引
self._build(left, l, mid)
self._build(right, mid+1, r)
self.tree[tree_index] = self._merge(self.tree[left], self.tree[right])
def _query(self, tree_index, l, r, ql, qr):
'''
递归查询区间[ql,..,qr]的值
tree_index : 某个根节点的索引
l, r : 该节点表示的区间的左右边界
ql, qr: 待查询区间的左右边界
'''
if l == ql and r == qr:
return self.tree[tree_index]
mid = (l+r) // 2 # 区间中点,对应左孩子区间结束,右孩子区间开头
left, right = tree_index * 2 + 1, tree_index * 2 + 2
if qr <= mid:
# 查询区间全在左子树
return self._query(left, l, mid, ql, qr)
elif ql > mid:
# 查询区间全在右子树
return self._query(right, mid+1, r, ql, qr)
# 查询区间一部分在左子树一部分在右子树
return self._merge(self._query(left, l, mid, ql, mid),
self._query(right, mid+1, r, mid+1, qr))
def _update(self, tree_index, l, r, index):
'''
tree_index:某个根节点索引
l, r : 此根节点代表区间的左右边界
index : 更新的值的索引
'''
if l == r == index:
self.tree[tree_index] = self.data[index]
return
mid = (l+r)//2
left, right = 2 * tree_index + 1, 2 * tree_index + 2
if index > mid:
# 要更新的区间在右子树
self._update(right, mid+1, r, index)
else:
# 要更新的区间在左子树index<=mid
self._update(left, l, mid, index)
# 里面的小区间变化了,包裹的大区间也要更新
self.tree[tree_index] = self._merge(self.tree[left], self.tree[right])
class NumArray:
def __init__(self, nums: List[int]):
self.segment_tree = SegmentTree(nums, lambda x, y : x + y)
def update(self, i: int, val: int) -> None:
self.segment_tree.update(i, val)
def sumRange(self, i: int, j: int) -> int:
return self.segment_tree.query(i, j)
结果
Java实现
class NumArray {
private interface Merger<E> {
E merge(E a, E b);
}
private class SegmentTree<E> {
private E[] tree;
private E[] data;
private Merger<E> merger;
public SegmentTree(E[] arr, Merger<E> merger){
this.merger = merger;
data = (E[])new Object[arr.length];
for(int i = 0 ; i < arr.length ; i ++)
data[i] = arr[i];
tree = (E[])new Object[4 * arr.length];
buildSegmentTree(0, 0, arr.length - 1);
}
// 在treeIndex的位置创建表示区间[l...r]的线段树
private void buildSegmentTree(int treeIndex, int l, int r){
if(l == r){
tree[treeIndex] = data[l];
return;
}
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
// int mid = (l + r) / 2;
int mid = l + (r - l) / 2;
buildSegmentTree(leftTreeIndex, l, mid);
buildSegmentTree(rightTreeIndex, mid + 1, r);
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
public int getSize(){
return data.length;
}
public E get(int index){
if(index < 0 || index >= data.length)
throw new IllegalArgumentException("Index is illegal.");
return data[index];
}
// 返回完全二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引
private int leftChild(int index){
return 2*index + 1;
}
// 返回完全二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引
private int rightChild(int index){
return 2*index + 2;
}
// 返回区间[queryL, queryR]的值
public E query(int queryL, int queryR){
if(queryL < 0 || queryL >= data.length ||
queryR < 0 || queryR >= data.length || queryL > queryR)
throw new IllegalArgumentException("Index is illegal.");
return query(0, 0, data.length - 1, queryL, queryR);
}
// 在以treeIndex为根的线段树中[l...r]的范围里,搜索区间[queryL...queryR]的值
private E query(int treeIndex, int l, int r, int queryL, int queryR){
if(l == queryL && r == queryR)
return tree[treeIndex];
int mid = l + (r - l) / 2;
// treeIndex的节点分为[l...mid]和[mid+1...r]两部分
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if(queryL >= mid + 1)
return query(rightTreeIndex, mid + 1, r, queryL, queryR);
else if(queryR <= mid)
return query(leftTreeIndex, l, mid, queryL, queryR);
E leftResult = query(leftTreeIndex, l, mid, queryL, mid);
E rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR);
return merger.merge(leftResult, rightResult);
}
// 将index位置的值,更新为e
public void set(int index, E e){
if(index < 0 || index >= data.length)
throw new IllegalArgumentException("Index is illegal");
data[index] = e;
set(0, 0, data.length - 1, index, e);
}
// 在以treeIndex为根的线段树中更新index的值为e
private void set(int treeIndex, int l, int r, int index, E e){
if(l == r){
tree[treeIndex] = e;
return;
}
int mid = l + (r - l) / 2;
// treeIndex的节点分为[l...mid]和[mid+1...r]两部分
int leftTreeIndex = leftChild(treeIndex);
int rightTreeIndex = rightChild(treeIndex);
if(index >= mid + 1)
set(rightTreeIndex, mid + 1, r, index, e);
else // index <= mid
set(leftTreeIndex, l, mid, index, e);
tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
}
@Override
public String toString(){
StringBuilder res = new StringBuilder();
res.append('[');
for(int i = 0 ; i < tree.length ; i ++){
if(tree[i] != null)
res.append(tree[i]);
else
res.append("null");
if(i != tree.length - 1)
res.append(", ");
}
res.append(']');
return res.toString();
}
}
private SegmentTree<Integer> segTree;
public NumArray(int[] nums) {
if(nums.length != 0){
Integer[] data = new Integer[nums.length];
for(int i = 0 ; i < nums.length ; i ++)
data[i] = nums[i];
segTree = new SegmentTree<>(data, (a, b) -> a + b);
}
}
public void update(int i, int val) {
if(segTree == null)
throw new IllegalArgumentException("Error");
segTree.set(i, val);
}
public int sumRange(int i, int j) {
if(segTree == null)
throw new IllegalArgumentException("Error");
return segTree.query(i, j);
}
}
/**
* Your NumArray object will be instantiated and called as such:
* NumArray obj = new NumArray(nums);
* obj.update(i,val);
* int param_2 = obj.sumRange(i,j);
*/
结果
实现2: 二叉树实现,非数组
public class NumArray {
class SegmentTreeNode {
int start, end;
SegmentTreeNode left, right;
int sum;
public SegmentTreeNode(int start, int end) {
this.start = start;
this.end = end;
this.left = null;
this.right = null;
this.sum = 0;
}
}
SegmentTreeNode root = null;
public NumArray(int[] nums) {
root = buildTree(nums, 0, nums.length-1);
}
private SegmentTreeNode buildTree(int[] nums, int start, int end) {
if (start > end) {
return null;
} else {
SegmentTreeNode ret = new SegmentTreeNode(start, end);
if (start == end) {
ret.sum = nums[start];
} else {
int mid = start + (end - start) / 2;
ret.left = buildTree(nums, start, mid);
ret.right = buildTree(nums, mid + 1, end);
ret.sum = ret.left.sum + ret.right.sum;
}
return ret;
}
}
void update(int i, int val) {
update(root, i, val);
}
void update(SegmentTreeNode root, int pos, int val) {
if (root.start == root.end) {
root.sum = val;
} else {
int mid = root.start + (root.end - root.start) / 2;
if (pos <= mid) {
update(root.left, pos, val);
} else {
update(root.right, pos, val);
}
root.sum = root.left.sum + root.right.sum;
}
}
public int sumRange(int i, int j) {
return sumRange(root, i, j);
}
public int sumRange(SegmentTreeNode root, int start, int end) {
if (root.end == end && root.start == start) {
return root.sum;
} else {
int mid = root.start + (root.end - root.start) / 2;
if (end <= mid) {
return sumRange(root.left, start, end);
} else if (start >= mid+1) {
return sumRange(root.right, start, end);
} else {
return sumRange(root.right, mid+1, end) + sumRange(root.left, start, mid);
}
}
}
}
结果
实现3:
树的构造和前面的不一样
注意:左孩子为2i,右孩子为2i+1,初始申请数组为2n
Java实现
public class NumArray {
int[] tree;
int n;
public NumArray(int[] nums) {
if (nums.length > 0) {
n = nums.length;
tree = new int[n * 2];
buildTree(nums);
}
}
private void buildTree(int[] nums) {
// 赋值给叶子结点
for (int i = n, j = 0; i < n * 2; i++, j++) {
tree[i] = nums[j];
}
for (int i = n - 1; i > 0; i--) {
tree[i] = tree[i * 2] + tree[i * 2 + 1];
}
}
// 自下而上,先更新叶子结点,然后一直向上,一直到根节点
void update(int pos, int val) {
pos += n; // 叶子结点
tree[pos] = val;
while (pos > 0) {
int left = pos;
int right = pos;
// left为偶数,right为奇数
if (pos % 2 == 0) {
right = pos + 1;
} else {
left = pos - 1;
}
// parent
tree[pos / 2] = tree[left] + tree[right];
pos /= 2;
}
}
public int sumRange(int l, int r) {
// get leaf with value 'l'
l += n;
// get leaf with value 'r'
r += n;
int sum = 0;
while (l <= r) {
if ((l % 2) == 1) { // l为奇数
sum += tree[l];
l++;
}
if ((r % 2) == 0) { // r为偶数
sum += tree[r];
r--;
}
l /= 2;
r /= 2;
}
return sum;
}
}
复杂度分析
构建线段树:时间复杂度为O(n),空间也是
更新:时间复杂度为O(logn),空间是O(1)
查询:时间复杂度为O(logn),空间是O(1)
结果
参考
1、理论讲解
2、优秀题解
https://leetcode-cn.com/problems/range-sum-query-mutable/solution/xian-duan-shu-segmenttree-shu-zu-shi-xian-by-zhou-/