题目思路:
线段树 + 树链剖分。重点是处理好线段树的区间合并:合并左右两个子区间时,若左区间右端点的颜色与右区间左端点的颜色相同,则颜色段数为两者之和减 1。同理,在沿树链求和时,相邻两段链的连接端点(链顶与其父节点)也需要进行合并判断:如果颜色相同,合并后的答案等于 ans-1,否则等于 ans。
代码:
#include<bits/stdc++.h>
using namespace std;
using ll = long long;

// Capacity for per-vertex arrays; vertices are indexed from 1, so n must stay below maxn.
constexpr ll maxn = 100000 + 5;

// Adjacency list: head[u] = index of u's most recently added edge; cnt = next free edge slot.
ll head[maxn], cnt;
// Heavy-light decomposition: top = chain head, son = heavy child, deep = depth, pre = parent.
ll top[maxn], son[maxn], deep[maxn], pre[maxn];
// color = initial vertex color, sizx = subtree size.
ll color[maxn], sizx[maxn];
// dfn = DFS-order position, cnx = DFS counter, w[i] = color of the vertex at DFS position i.
ll dfn[maxn], cnx, w[maxn];
// Edge storage for the adjacency list ("chained forward star").
// A tree on n vertices has n-1 undirected edges, and each is stored twice
// (u->v and v->u), so 2*maxn slots are always enough. The previous size
// maxn<<3 allocated roughly four times the memory that can ever be used.
struct node
{
    ll to;   // target vertex of this edge
    ll nex;  // index of the next edge leaving the same source vertex
} edge[maxn<<1];
// Append the directed edge u -> v by inserting it at the front of u's list.
void add(ll u,ll v)
{
    node &slot = edge[cnt];
    slot.to = v;
    slot.nex = head[u];
    head[u] = cnt;
    ++cnt;
}
// First heavy-light pass over u's subtree (fa = parent of u).
// Fills in the parent (pre), depth (deep), subtree size (sizx) and heavy
// child (son) of every vertex.
void dfs1(ll u,ll fa)
{
    pre[u] = fa;
    deep[u] = deep[fa] + 1;
    sizx[u] = 1;
    ll heaviest = -1;  // largest child-subtree size seen so far
    for (ll e = head[u]; e != -1; e = edge[e].nex)
    {
        const ll child = edge[e].to;
        if (child == fa)
            continue;
        dfs1(child, u);
        sizx[u] += sizx[child];
        // A strict comparison keeps the first child among equally heavy ones.
        if (sizx[child] > heaviest)
        {
            heaviest = sizx[child];
            son[u] = child;
        }
    }
}
// Second heavy-light pass: u joins the chain whose head is t.
// The heavy child is visited first, so each heavy chain occupies a
// contiguous range of DFS positions (dfn). Initial colors are copied into
// w[] in DFS order so the segment tree can be built over it.
void dfs2(ll u,ll t)
{
    top[u] = t;
    cnx += 1;
    dfn[u] = cnx;
    w[cnx] = color[u];
    const ll heavy = son[u];
    if (heavy == 0)       // no heavy child recorded: u has no children
        return;
    dfs2(heavy, t);       // the heavy child extends the current chain
    for (ll e = head[u]; e != -1; e = edge[e].nex)
    {
        const ll v = edge[e].to;
        if (v != pre[u] && v != heavy)
            dfs2(v, v);   // every light child starts a new chain
    }
}
// Segment-tree node covering DFS positions [l, r].
struct vain
{
    ll l, r;    // covered interval of DFS positions
    ll sum;     // number of maximal same-color segments in [l, r]
    ll lazy;    // pending "assign one color to the whole interval" tag for the children
    ll lc, rc;  // color at the left end / right end of the interval
} tr[maxn<<2];
// Rebuild node k from its two children. Segment counts add up, minus one
// when the colors at the shared boundary are equal, because the two
// touching segments are really a single segment.
void pushup(ll k)
{
    const vain &lhs = tr[k<<1];
    const vain &rhs = tr[k<<1|1];
    tr[k].sum = lhs.sum + rhs.sum - (lhs.rc == rhs.lc ? 1 : 0);
    tr[k].lc = lhs.lc;
    tr[k].rc = rhs.rc;
}
void pushdown(ll k)
{
if(tr[k].lazy)
{
tr[k<<1].lc=tr[k<<1].rc=tr[k].lazy;
tr[k<<1|1].lc=tr[k<<1|1].rc=tr[k].lazy;
tr[k<<1].lazy=tr[k<<1|1].lazy=tr[k].lazy;
tr[k<<1].sum=tr[k<<1|1].sum=1;
tr[k].lazy=0;
}
}
void build(ll k,ll l,ll r)
{
tr[k].l=l,tr[k].r=r;
if(l==r)
{
tr[k].lc=tr[k].rc=w[l];
tr[k].sum=1;
return ;
}
ll mid=l+r>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
pushup(k);
}
void modify(ll k,ll l,ll r,ll c)
{
if(tr[k].l>=l&&tr[k].r<=r)
{
tr[k].lazy=c;
tr[k].lc=tr[k].rc=c;
tr[k].sum=1;
return ;
}
pushdown(k);
ll mid=tr[k].l+tr[k].r>>1;
if(mid>=r)
modify(k<<1,l,r,c);
else if(mid<l)
modify(k<<1|1,l,r,c);
else
modify(k<<1,l,mid,c),modify(k<<1|1,mid+1,r,c);
pushup(k);
}
// Point query: return the color stored at DFS position x (searching from node k).
// Pending assignments on the way down are pushed to the children. No pushup
// is needed because a read does not change any node's summary. The previous
// trailing pushup(k) came after both returns and was unreachable dead code.
ll qid(ll k,ll x)
{
    if (tr[k].l == tr[k].r)
        return tr[k].lc;
    pushdown(k);
    const ll mid = (tr[k].l + tr[k].r) >> 1;
    return x <= mid ? qid(k<<1, x) : qid(k<<1|1, x);
}
// Count the maximal same-color segments among DFS positions [l, r] within node k.
ll ask(ll k,ll l,ll r)
{
    const vain &cur = tr[k];
    if (l <= cur.l && cur.r <= r)
        return cur.sum;
    pushdown(k);
    const ll mid = (cur.l + cur.r) >> 1;
    if (r <= mid)
        return ask(k<<1, l, r);
    if (l > mid)
        return ask(k<<1|1, l, r);
    // The query straddles mid: combine both halves. Position mid (the left
    // child's right end) touches mid+1 (the right child's left end).
    const ll left_part = ask(k<<1, l, mid);
    const ll right_part = ask(k<<1|1, mid + 1, r);
    const bool joined = tr[k<<1].rc == tr[k<<1|1].lc;
    return left_part + right_part - (joined ? 1 : 0);
}
// Paint every vertex on the tree path x-y with color c.
// The two endpoints climb heavy chains; each chain piece is a contiguous
// range of DFS positions, so it is updated with one range assignment.
void opt1(ll x,ll y,ll c)
{
    for (; top[x] != top[y]; x = pre[top[x]])
    {
        // Always lift the endpoint whose chain head is deeper.
        if (deep[top[x]] < deep[top[y]])
            swap(x, y);
        modify(1, dfn[top[x]], dfn[x], c);
    }
    // Both endpoints are now on one chain; paint from the shallower one down.
    const ll upper = deep[x] > deep[y] ? y : x;
    const ll lower = deep[x] > deep[y] ? x : y;
    modify(1, dfn[upper], dfn[lower], c);
}
// Count the maximal same-color segments on the tree path x-y.
ll opt2(ll x,ll y)
{
    ll segments = 0;
    while (top[x] != top[y])
    {
        if (deep[top[x]] < deep[top[y]])
            swap(x, y);
        const ll head_pos = dfn[top[x]];
        segments += ask(1, head_pos, dfn[x]);
        // A chain head and its parent are adjacent on the path. If they share
        // a color, their segments merge into one, so subtract one.
        if (qid(1, head_pos) == qid(1, dfn[pre[top[x]]]))
            --segments;
        x = pre[top[x]];
    }
    if (deep[x] > deep[y])
        swap(x, y);
    return segments + ask(1, dfn[x], dfn[y]);
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
memset(head,-1,sizeof(head));
ll n,q;
cin>>n>>q;
for(ll i=1; i<=n; i++)
{
cin>>color[i];
}
ll u,v;
for(ll i=1; i<n; i++)
{
cin>>u>>v,add(u,v),add(v,u);
}
dfs1(1,0);
dfs2(1,1);
build(1,1,n);
while(q--)
{
char s;
ll x,y,z;
cin>>s;
if(s=='Q')
{
cin>>x>>y;
cout<<opt2(x,y)<<endl;
}
else
{
cin>>x>>y>>z;
opt1(x,y,z);
}
}
}