理论:线段树+区段更新+lazy思想
思路:看到题目中说要更新数组中一个区段的数,还有查询求和就可以往线段树的思路想了,这题比较特殊在于,它更新数组中一个区段的数时,更新的不是常数,而是更新一个从u到v中 所有数都要+/-(i+1)(i+2)(i+3)的一个递增数列。
我们可以发现
所以和寻常的线段树中的区段更新不同的是,这回我们要把lazy数用一个数组存储,然后再通过循环得到我们的答案。
其他的就是线段树的模版了。
AC代码:
//
// Created by xiaozhang on 2019/5/14.
//
#include <sstream>
#include <fstream>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iostream>
#include <iomanip>
#include <algorithm>
#include <vector>
#include <set>
#include <stack>
#include <map>
#include <string>
#include <queue>
using namespace std;
#define ll long long
#define max 200005
const ll mod=1000000007;
int n,m;
struct node{
ll sum,lazy[4];
}tree[max*4];
ll sum[4][max];
ll find_val(int k,ll x)
{
if(k == 0) return (-1LL*x*x*x + 6*x*x - 11*x + 6) % mod;
if(k == 1) return (3*x*x - 12*x + 11) % mod;
if(k == 2) return -3*x + 6;
return 1;
}
void add(ll &a,ll b)
{
a=(a+b)%mod;
if(a<0)
a+=mod;
}
void pushdown(int i,int l, int r) //自顶向下更新lazy数组和给结点加上lazy数组的值
{
int mid=(l+r)/2;
int lson=i*2;
int rson=i*2+1;
for(int k=0;k<=3;k++)
{
add(tree[lson].lazy[k],tree[i].lazy[k]); //给左右孩子传递lazy
add(tree[rson].lazy[k],tree[i].lazy[k]);
add(tree[lson].sum, tree[i].lazy[k]*(sum[k][mid] - sum[k][l-1] + mod));
add(tree[rson].sum, tree[i].lazy[k]*(sum[k][r] - sum[k][mid] + mod));
tree[i].lazy[k]=0; //把父节点的lazy归0
}
}
void update(int i,int l, int r, int u, int v, int flag)
{
if(l>v||r<u)return; //若查询区间不在当前区间内,直接return
if(u<=l&&r<=v){
for(int k=0;k<=3;k++)
{
ll val=find_val(k,u)*flag;
add(tree[i].sum,val*(sum[k][r]-sum[k][l-1]+mod));
add(tree[i].lazy[k],val);
}
return;
}
int mid=(l+r)/2;
pushdown(i,l,r);
update(i<<1,l,mid,u,v,flag); //更新左子树
update((i<<1)|1,mid+1,r,u,v,flag); //更新右子树
tree[i].sum=(tree[i<<1].sum+tree[(i<<1)|1].sum)%mod; //自顶向上更新区间和
}
ll query(int i, int l, int r,int u, int v)
{
if (l > v || r < u) return 0;//查询结点和区间没有公共点
if (u <= l && r <= v) return tree[i].sum;//查询区间包含查询结点
int mid=(l+r)/2;
pushdown(i,l,r);
return (query(i*2, l, mid, u, v) + query(i*2+1, mid+1, r, u, v)) % mod;
}
int main()
{
cin>>n>>m;
for(int i=1;i<=n;i++)
{
sum[0][i] = sum[0][i-1] + 1;
sum[1][i] = (sum[1][i-1] + i) % mod;
sum[2][i] = (sum[2][i-1] + 1LL*i*i) % mod;
sum[3][i] = (sum[3][i-1] + 1LL*i*i*i) % mod;
}
int sign,u,v;
while(m--)
{
cin>>sign>>u>>v;
if(sign==1)
{
update(1,1,n,u,v,+1);
}
else if(sign==2)
{
update(1,1,n,u,v,-1);
}
else if(sign==0)
{
ll ans=query(1,1,n,u,v);
cout<<ans<<endl;
}
}
return 0;
}