树上计数问题2
解题思路
本题有若干个询问,每个询问会问我们树上某一段路径中不同权值的个数。
本题如果将树换成一个序列,那么就是一个经典的基础莫队问题,现在我们等于是要在树上做莫队算法,这就是一个树上莫队问题。
这里有一个通用的做法能将树上的问题统一变成区间中的问题。我们将这棵树按照欧拉序列的形式写成一个序列。欧拉序列就是按照深度优先遍历的方式遍历整个树,每个节点在进入和出去的时候都需要加入序列一次。因此欧拉序列中每个节点应该出现两次。
然后我们就能将树上的任意一段路径 $(x, y)$ 对应到欧拉序列中的任意一段区间。
设 $first[x]$ 表示 $x$ 第一次出现的位置,$last[x]$ 表示 $x$ 最后一次出现的位置。并令 $(x, y)$ 满足 $first[x] < first[y]$,即 $x$ 比 $y$ 先遍历到。
此时如果 $lca(x, y) == x$,说明这段路径是从 $x$ 往下遍历到 $y$,可以发现此时这段路径上的所有点都是往下搜索的时候发现的,并且还没有进行回溯,因此此时 $(x, y)$ 之间的路径就能对应欧拉路径中 $[first[x], first[y]]$ 中只出现一次的点,因为出现两次的点就说明进入这个点所在的分支后又出来了,说明这个点是不属于我们要找的路径中的。
而如果 $lca(x, y) != x$,说明这段路径是由 $x$ 往上回溯到 $lca(x, y)$ 的路径和从 $lca(x, y)$ 往下搜索到 $y$ 的路径组成,因此此时 $(x, y)$ 之间的路径就能对应欧拉路径中 $[last[x], first[y]]$ 中只出现一次的点以及 $lca(x, y)$,因为 $x$ 和 $y$ 都是 $lca(x, y)$ 的子节点,因此从 $x$ 搜索到 $y$ 为止 $lca(x, y)$ 都没有进行回溯,因此 $lca(x, y)$ 是不会出现在 $[last[x], first[y]]$ 中的,所以需要手动加上。
通过以上的转化,每一个树上询问就变成了求某一段区间中只出现一次的数中不同权值的数量。在没有只出现一次的限制时,我们可以用一个 $cnt$ 数组记录每个数出现的次数,用一个 $res$ 记录当前 $cnt$ 数组中不同数的个数,如果加上只出现一次的限制,相当于将出现两次的数也看作没有出现过,因此我们可以额外用一个 $st$ 数组来记录每个节点是否只出现过一次,$1$ 表示出现一次,$0$ 表示出现 $0$ 次或 $2$ 次。而我们在执行莫队算法的过程中,一个节点 $x$ 如果出现一次,$st[x]$ 就从 $0$ 变成 $1$,如果出现两次,就又从 $1$ 变成 $0$,其实就是异或运算,反过来也是一样,所以当我们加入或删除 $x$ 的时候,我们就令 $st[x]~xor~1$,如果 $st[x]$ 操作后变成 $1$,说明加入了这个数,要令 $cnt[w[x]]+1$,而如果 $cnt[w[x]]+1$ 后为 $1$,说明权值 $w[x]$ 第一个出现,对应的要令 $res+1$。如果 $st[x]$ 操作后变成 $0$,说明删除了这个数,要令 $cnt[w[x]]-1$,而如果 $cnt[w[x]]-1$ 后为 $0$,说明权值 $w[x]$ 全部移除了,要令 $res-1$。这样就能用 $O(1)$ 的时间加入或删除一个数,从而通过上一个查询的信息计算出当前查询的信息。
另外由于本题的权值范围在 $int$ 以内,因此需要对权值进行离散化。
以上就是本题的全部思路,简单来说就是将树上路径查询转化成区间查询,然后用基础莫队算法来实现。
C++ 代码
#include <iostream>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 40010, M = 100010;
struct Query
{
int id, l, r, p; //询问的编号、左端点、右端点、需要额外加入的数
}q[M];
int n, m, len; //节点个数、询问次数、块的长度
int h[N], e[N * 2], w[N * 2], ne[N * 2], idx; //邻接表
int que[N], d[N], fa[N][16]; //LCA 相关数组
int seq[N * 2], first[N], last[N], top; //欧拉序列、每个数第一次出现的位置、每个数最后一次出现的位置
int st[N], cnt[N]; //维护区间中每个节点是否只出现一次、维护区间中每个权值出现的次数
int res[M], id[N * 2]; //每个询问的答案、每个下标所在的块编号
vector<int> nums; //离散化
int find(int x) //返回每个数离散化后的结果
{
int l = 0, r = nums.size() - 1;
while(l < r)
{
int mid = l + r + 1 >> 1;
if(nums[mid] <= x) l = mid;
else r = mid - 1;
}
return r;
}
void add_edge(int a, int b) //添加边
{
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
void dfs(int u, int fa) //深搜预处理 seq[], first[], last[]
{
seq[++top] = u;
first[u] = top;
for(int i = h[u]; i != -1; i = ne[i])
{
int j = e[i];
if(j == fa) continue;
dfs(j, u);
}
seq[++top] = u;
last[u] = top;
}
void bfs() //预处理 d[], fa[][]
{
memset(d, 0, sizeof d);
d[1] = 1;
int hh = 0, tt = 0;
que[0] = 1;
while(hh <= tt)
{
int t = que[hh++];
for(int i = h[t]; i != -1; i = ne[i])
{
int j = e[i];
if(d[j]) continue;
d[j] = d[t] + 1;
fa[j][0] = t;
for(int k = 1; k <= 15; k++) fa[j][k] = fa[fa[j][k - 1]][k - 1];
que[++tt] = j;
}
}
}
int lca(int a, int b) //计算 a 和 b 的最近公共祖先
{
if(d[a] < d[b]) swap(a, b);
for(int k = 15; k >= 0; k--)
if(d[fa[a][k]] >= d[b])
a = fa[a][k];
if(a == b) return a;
for(int k = 15; k >= 0; k--)
if(fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
//先按照左端点所在块编号从小到大排序,再按照右端点从小到大排序
bool cmp(const Query &a, const Query &b)
{
if(id[a.l] != id[b.l]) return id[a.l] < id[b.l];
return a.r < b.r;
}
void add(int x, int &res) //维护区间中加入节点 x
{
st[x] ^= 1;
if(st[x] == 1)
{
cnt[w[x]]++;
if(cnt[w[x]] == 1) res++;
}
else
{
cnt[w[x]]--;
if(cnt[w[x]] == 0) res--;
}
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++)
{
scanf("%d", &w[i]);
nums.push_back(w[i]);
}
//离散化
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
for(int i = 1; i <= n; i++) w[i] = find(w[i]); //将每个数替换成离散化后的数
memset(h, -1, sizeof h); //初始化邻接表
for(int i = 0; i < n - 1; i++)
{
int a, b;
scanf("%d%d", &a, &b);
add_edge(a, b), add_edge(b, a); //添加无向边
}
dfs(1, -1); //深搜预处理 seq[], first[], last[]
bfs(); //预处理 d[], fa[][]
for(int i = 0; i < m; i++)
{
int a, b;
scanf("%d%d", &a, &b);
if(first[a] > first[b]) swap(a, b); //保证 first[a] < first[b];
int p = lca(a, b); //计算 a 和 b 的最近公共祖先
if(a == p) q[i] = {i, first[a], first[b], 0}; //如果 a == p,则查询区间为 [first[a], first[b]]
else q[i] = {i, last[a], first[b], p}; //如果 a != p,则查询区间为 [last[a], first[b]] + p
}
int len = sqrt(top); //计算块的长度
for(int i = 1; i <= top; i++) id[i] = (i - 1) / len; //计算每个下标所在的块编号
//将所有询问先按照左端点所在块编号从小到大排序,再按照右端点从小到大排序
sort(q, q + m, cmp);
//莫队算法
for(int k = 0, i = 1, j = 0, ans = 0; k < m; k++)
{
int num = q[k].id, l = q[k].l, r = q[k].r, p = q[k].p;
//由于加入操作和删除操作写法相同,因此不再重复写
while(j < r) add(seq[++j], ans); //j 右移,加入操作
while(j > r) add(seq[j--], ans); //j 左移,删除操作
while(i < l) add(seq[i++], ans); //i 右移,删除操作
while(i > l) add(seq[--i], ans); //i 左移,加入操作
if(p) add(p, ans); //如果 p != 0,则需要将 p 加入
res[num] = ans; //记录答案
if(p) add(p, ans); //还原
}
for(int i = 0; i < m; i++) printf("%d\n", res[i]);
return 0;
}
LaTeX 出问题了
已修