dsu on tree
介绍
dsu on tree是一种用来解决子树问题的算法,比如问你子树中节点权值位于[l,r]的个数;
还有比如问你子树中颜色出现次数最多的是那种颜色.这里介绍的会简略一些
思想
求出树的所有重儿子和轻儿子后:
1.遍历轻儿子
2.遍历重儿子
3.统计亲儿子的贡献并加到重儿子上
4.统计答案并且删除轻儿子的贡献
解释
1.删除轻儿子的贡献是为了避免遍历父亲的时候轻儿子多算了一次
2.删轻儿子的贡献会留下重儿子的贡献,于是之后遍历父节点的时候重儿子的贡献就不用计算了
3.复杂度的证明:有点类并查集的按秩合并,每个点最终都会被合并到重儿子上,合并的次数大概为logn次
所以算法的时间复杂度大概为nlogn
例题
CF600E
代码如下
//dsu on tree
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10;
int h[N],e[N*2],ne[N*2],idx;
void add(int a, int b) // 添加一条边a->b
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
int sz[N],son[N];
void dfs(int u,int f)
{
sz[u]=1;
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j==f) continue;
dfs(j,u);
sz[u]+=sz[j];
if(sz[son[u]]<sz[j]) son[u]=j;
}
}
int col[N],cnt[N];
ll ans[N],sum;
int flag,maxn;
void count(int u,int f,int val)
{
cnt[col[u]]+=val;
if(cnt[col[u]]>maxn)
sum=col[u],maxn=cnt[col[u]];
else if(cnt[col[u]]==maxn) sum+=col[u];
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j==f||j==flag) continue;
count(j,u,val);
}
}
void dfs(int u,int f,int keep)
{
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j==f||j==son[u]) continue;
dfs(j,u,false);
}
if(son[u])
{
dfs(son[u],u,true);
flag=son[u];
}
count(u,f,1);
flag=0;
ans[u]=sum;
if(!keep)
{
count(u,f,-1);
sum=maxn=0;
}
}
int main()
{
int n;scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",&col[i]);
memset(h,-1,sizeof h);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs(1,-1);
dfs(1,0,0);
for(int i=1;i<=n;i++)
printf("%lld ",ans[i]);
return 0;
}
我的模板
void dfs(int u,int f,int keep)
{
for(int i=h[u];i!=-1;i=ne[i])
{
int j=e[i];
if(j==f||j==son[u]) continue;
dfs(j,u,false);
}
if(son[u])
{
dfs(son[u],u,true);
flag=son[u];
}
count();//统计轻儿子的贡献
flag=0;
if(!keep)
{
count();//删除轻儿子的贡献
}
}
多校9E
代码如下:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
using ll = long long;
constexpr int N = 1e5 + 5;
int n, dn = 3e5;
int T[N], c[N * 3], son[N], siz[N];
int ans[N], l[N], r[N], tt[N << 2];
int fa[N][20],flag;
bool go[N];
vector<int> vt[N], id[N];
int lowbit(int x) {
return x & -x;
}
void add(int x, int v) {
while(x <= dn) {
c[x] += v;
x += lowbit(x);
}
}
int ask(int x) {
int ret = 0;
while(x) {
ret += c[x];
x -= lowbit(x);
}
return ret;
}
void dfs1(int u,int f)
{
fa[u][0]=f;
for(int i=1;i<=19;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
siz[u]=1;
for(auto v:vt[u])
{
if(v==f) continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
}
}
int fd(int u,int up)
{
for(int i=19;i>=0;i--)
if(T[fa[u][i]]<=up) u=fa[u][i];
return u;
}
void count(int u,int f,int val)
{
add(T[u],val);
for(auto v:vt[u])
{
if(v==f||v==flag) continue;
count(v,u,val);
}
}
void dfs2(int u,int f,bool keep)
{
for(auto v:vt[u])
{
if(v==f||v==son[u]) continue;
dfs2(v,u,false);
}
if(son[u])
{
dfs2(son[u],u,true);
flag=son[u];
}
count(u,f,1);
if(go[u] == 0)
for(int v : id[u]) {
// if(u == 5) {
// printf("hhh %d %d %d\n", l[v], r[v], ask(r[v]));
// }
go[u] = 1;
int L = l[v], R = r[v];
ans[v] = ask(R) - ask(L - 1);
}
flag=0;
if(!keep)
count(u,f,-1);
}
int main() {
scanf("%d", &n);
for(int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
vt[u].push_back(v);
vt[v].push_back(u);
}
int cnt = 0;
for(int i = 1; i <= n; ++i) scanf("%d", &T[i]), tt[++cnt] = T[i];
dfs1(1, 1);
int q, x;
scanf("%d", &q);
for(int i = 1; i <= q; ++i) {
scanf("%d%d%d", &x, &l[i], &r[i]);
tt[++cnt] = l[i];
tt[++cnt] = r[i];
if(T[x] < l[i] || T[x] > r[i]) continue;
int gf = fd(x, r[i]);
// printf("hh %d\n", gf);
id[gf].push_back(i);
}
sort(tt + 1, tt + 1 + cnt);
int len = unique(tt + 1, tt + 1 + cnt) - tt - 1; dn = len;
for(int i = 1; i <= n; ++i) T[i] = lower_bound(tt + 1, tt + len + 1, T[i]) - tt;
for(int i = 1; i <= q; ++i) {
l[i] = lower_bound(tt + 1, tt + 1 + len, l[i]) - tt;
r[i] = lower_bound(tt + 1, tt + 1 + len, r[i]) - tt;
}
dfs2(1, 0, true);
for(int i = 1; i <= q; ++i) printf("%d\n", ans[i]);
}