这是一篇教学,同样是我的学习笔记,第一篇,较为认真。
本文保证符合洛谷题解审核制度,如有错误,欢迎指正。
其中的错误,包括但不限于格式错误,图片爆炸,内容错误。
本文少量使用 texttt
,如果看着不习惯,作者感到很抱歉。如果这是不符合题解审核规范的,请及时告知。
本文以一个区间加区间查询和的线段树为例。
首先有如下宏定义。
#define now tr[u]
#define ls tr[u<<1]
#define rs tr[u<<1|1]
定义节点 $u$ 的左儿子是 $2u$,右儿子是 $2u+1$。
线段树的每一个节点代表了一个区间,我们需要做的就是动态维护这些区间。
线段树令左儿子与右儿子可以完全拼成父亲区间,这样可以让每一个区间可以表示为 $O(\log n)$ 个区间。(好好想想为什么这是对的)
定义节点为如下格式。
struct Node{
int l,r;
int sum,add;//表示区间和与懒惰标记(这个等一会儿会讲)
}tr[N<<2];
线段树需要开四倍空间。
在修改了子区间的信息后,势必会影响父亲区间的内容,因此修改后需要进行 pushup。
void pushup(int u)
{
now.sum=ls.sum+rs.sum;
}
我们在修改了一个区间后,会修改许许多多的子区间,但是如果这些子区间没有被用到,岂不是不用修改了?
这就是懒惰标记的思想了,在使用时,记得先下推标记哦。
pushdown 代码如下。
void pushdown(int u)
{
if(now.add)
{
ls.sum+=(ls.r-ls.l+1)*now.add;
rs.sum+=(rs.r-rs.l+1)*now.add;
ls.add+=now.add;
rs.add+=now.add;
now.add=0;
}
}
接下来就是建树了,我们将区间 $[l,r]$ 分为 $[l,mid]$ 与 $[mid+1,r]$ 并递归处理,当区间非法就停止处理。
建树代码如下。
void build(int u,int l,int r)
{
if(l==r) tr[u]={l,r,a[l],0};//平凡的区间
else
{
tr[u]={l,r,0,0};
int mid=l+r>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);//递归处理
pushup(u);//记得 pushup
}
}
接下来就是修改了。
我们考虑修改的时候分两种情况。
如果修改区间包含了本区间,直接修改本区间并加上懒标记。
否则,在其左子区间和右子区间内递归处理。
修改代码如下。
void update(int u,int l,int r,int d)
{
if(now.l>=l&&now.r<=r)
{
now.sum+=(now.r-now.l+1)*d;//修改本区间
now.add+=d;//加上懒标记
}
else
{
pushdown(u);//下推标记
int mid=now.l+now.r>>1;
if(l<=mid) update(u<<1,l,r,d);//如果左区间与修改区间有相交就递归
if(r>mid) update(u<<1|1,l,r,d);//同上理
pushup(u);//记得 pushup
}
}
查询类似修改,记得下推标记。
修改代码如下。
int query(int u,int l,int r)
{
if(now.l>=l&&now.r<=r) return now.sum;//这是平凡的
else
{
pushdown(u);//下推标记
int mid=now.l+now.r>>1,res=0;
if(l<=mid) res=query(u<<1,l,r);//如果左区间与修改区间有相交就递归
if(r>mid) res+=query(u<<1|1,l,r);//同上理
return res;
}
}
整体代码如下。
// Problem: P3372 【模板】线段树 1
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P3372
// Memory Limit: 125 MB
// Time Limit: 1000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define pb push_back
#define MT int TTT=R;while(TTT--)
#define pc putchar
#define R read()
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define rep(i,a,b) for(int i=a;i>=b;i--)
#define m1(a,b) memset(a,b,sizeof a)
namespace IO
{
inline int read()
{
int x=0;
char ch=getchar();
bool f=0;
while(!isdigit(ch)){if(ch=='-') f=1;ch=getchar();}
while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
if(f) x=-x;
return x;
}
template<typename T> inline void write(T x)
{
if(x<0)
{
pc('-');
x=-x;
}
if(x>9) write(x/10);
pc(x%10+'0');
}
};
namespace math
{
inline int gcd(int a,int b)
{
return b?gcd(b,a%b):a;
}
inline int qmi(int a,int b,int p)
{
int res=1;
while(b)
{
if(b&1) res=res*a%p;
a=a*a%p;
b>>=1;
}
return res;
}
inline int inv(int a,int p)
{
return qmi(a,p-2,p);
}
const int MAXN=2e6+10;
int my_fac[MAXN],my_inv[MAXN];
void init_binom(int mod)
{
my_fac[0]=1;fo(i,1,min(MAXN,mod)-1) my_fac[i]=my_fac[i-1]*i%mod;
my_inv[min(MAXN,mod)-1]=qmi(my_fac[min(MAXN,mod)-1],mod-2,mod);rep(i,min(MAXN,mod)-2,0) my_inv[i]=my_inv[i+1]*(i+1)%mod;
}
int binom(int a,int b,int mod)
{
return my_fac[a]*my_inv[b]%mod*my_inv[a-b]%mod;
}
}
using namespace IO;
using namespace math;
#define now tr[u]
#define ls tr[u<<1]
#define rs tr[u<<1|1]
const int N=1e5+10;
int n,m;
int a[N];
struct Node{
int l,r;
int sum,add;
}tr[N<<2];
void pushup(int u)
{
now.sum=ls.sum+rs.sum;
}
void pushdown(int u)
{
if(now.add)
{
ls.sum+=(ls.r-ls.l+1)*now.add;
rs.sum+=(rs.r-rs.l+1)*now.add;
ls.add+=now.add;
rs.add+=now.add;
now.add=0;
}
}
void build(int u,int l,int r)
{
if(l==r) tr[u]={l,r,a[l],0};
else
{
tr[u]={l,r};
int mid=l+r>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
pushup(u);
}
}
void update(int u,int l,int r,int d)
{
if(now.l>=l&&now.r<=r)
{
now.sum+=(now.r-now.l+1)*d;
now.add+=d;
}
else
{
pushdown(u);
int mid=now.l+now.r>>1;
if(l<=mid) update(u<<1,l,r,d);
if(r>mid) update(u<<1|1,l,r,d);
pushup(u);
}
}
int query(int u,int l,int r)
{
if(now.l>=l&&now.r<=r) return now.sum;
else
{
pushdown(u);
int mid=now.l+now.r>>1,res=0;
if(l<=mid) res=query(u<<1,l,r);
if(r>mid) res+=query(u<<1|1,l,r);
return res;
}
}
signed main(){
n=R,m=R;
fo(i,1,n) a[i]=R;
build(1,1,n);
while(m--)
{
int opt=R,x=R,y=R;
if(opt==1)
{
int k=R;
update(1,x,y,k);
}
else write(query(1,x,y)),puts("");
}
}
后记
建议读者自己手动敲几遍加深记忆。
当然根据退役选手 sto_5k_orz 的说法,BIT 其实是几乎万能的。
查理线段树
令
#define ls tr[mid<<1]
#define rs tr[mid<<1|1]
只需要两倍空间即可。
但是你可能要牺牲一些 cache 问题。
大佬tql