Advise Category: Algorithm >> 树状数组
Scenario
- 单点更新
- 区间求和(前缀和)
Objects
- 待处理数组,a[1...n]
- 待维护树状数组,c[1...n]
- 结果数组(前缀和)(待处理数组的前缀和),r[1...n]
Ideas
Q:
A1:
每次计算r[i]
都遍历a
的前i
项,一个结果r[i]
一次遍历,n
个结果r[n]
需要n
次遍历(不适合大数据数组的情况)。
A2:(树状)
?如何加速遍历
不与a[i]
为维度进行运算,而是通过维护一个数组(数组)c[1...n]
其中每一项是通过一定规律的m
项和,再通过规律找到前缀和r[i]
的待累积c[i]
并求和。
Method for A2
a[i] -> c[i]:
c[1] = c[0001] = a[1];
c[2] = c[0010] = a[1]+a[2];
c[3] = c[0011] = a[3];
c[4] = c[0100] = a[1]+a[2]+a[3]+a[4];
c[5] = c[0101] = a[5];
c[6] = c[0110] = a[5]+a[6];
c[7] = c[0111] = a[7];
c[8] = c[1000] = a[1]+a[2]+a[3]+a[4]+a[5]+a[6]+a[7]+a[8];
......
Rule:
c[i]=a[i-2^k+1]+a[i-2^k+2]+......a[i];
Note:
-
k
为i
的二进制中从最低位到高位连续零的长度, 例如i=8(1000)
时,k=3
。 - 可以理解为这是一种分类方法,通过维护分类(要么有零,要么没零,有零说明进位了需要把其下的数都考虑在内)数组
c[i]
让前缀求和变得更快。
单点更新
?如果修改a
中的一个元素,c
中的元素如何变化
Assumption:a[3] = a[3] + 1
,从3往后找,直到数组结束。
lowbit(3)=0001=2^0 3+lowbit(3)=04(00100) c[04] += 1
lowbit(4)=0100=2^2 4+lowbit(4)=08(01000) c[08] += 1
lowbit(8)=1000=2^3 8+lowbit(8)=16(10000) c[16] += 1
......
Note:
- 可以看出a[3]变化之后,会涉及到c[4]/c[8]/[16]...的变化,所以需要更新跟随变化的c中的元素。
?lowbit
lowbit(x)是取出x的最低位1(从右往左数第一个1),满足:
int lowbit(x){ return x & (-x); }
Note:
一个数的负数就等于对这个数取反+1
补码和原码必然相反,所以原码有0的部位补码全是1,补码再+1之后由于进位那么最末尾的1和原码最右边的1一定是同一个位置
刚好等于
2^k
,k
为x
的二进制中从最低位到高位连续零的长度
Code:
void update(int x,int y,int n){
for(int i=x; i <= n; i += lowbit(i)) //x为更新的位置,y为更新后的数,n为数组最大值
c[i] += y;
}
区间求和
e.g 求r[5]
Disappear:
c[4]=a[1]+a[2]+a[3]+a[4];
c[5]=a[5];
sum(i = 5) = c[4] + c[5];
sum(i = 101) = c[(100)] + c[(101)];
Note:
- 首次从101,减去最低位的1就是100,刚好是单点更新的你操作。
Code:
int sum(int x){
int ans = 0;
for(int i = x; i >= 0; i -= lowbit(i))
ans += c[i];
return ans;
}
Example
leet-code:
Ans:
class Solution{
public List<Integer> countSmaller3(int[] nums){
if(nums == null) return null;
if(nums.length == 0) return new ArrayList<>();
List<Integer> result = new ArrayList<>();
Set<Integer> set = new HashSet<Integer>();
for (int i = 0; i < nums.length; i++) {
set.add(nums[i]);
}
int[] c = Arrays.stream(set.toArray()).sorted().mapToInt(e -> Integer.parseInt(e.toString())).toArray();
int[] d = new int[c.length + 2];
for (int i = nums.length - 1; i > -1; i--) {
int idx = Arrays.binarySearch(c, nums[i]) + 1;
int s = sum(d, idx - 1);
update(d, idx, 1);
result.add(s);
}
Collections.reverse(result);
return result;
}
private int lowbit(int x){
return x & (-x);
}
private void update(int[] c, int i, int delta){
while( i <= c.length - 1 ){
c[i] += delta;
i += lowbit(i);
}
}
private int sum(int[] c, int i){
int ans = 0;
while( i > 0 ){
ans += c[i];
i -= lowbit(i);
}
return ans;
}
}