树上莫队
例题
给定一棵树,每个节点有一个数值,m次询问
每次询问包含两个节点a,b
求a到b的最短路径上的节点有多少个不同的数值
欧拉序
对于以u为根的子树
欧拉序为u
+以所有子节点为根的子树的欧拉序
+u
同时,第一次出现u的位置记为first[u],最后一次出现u的位置记为last[u]
思想
给定两节点a,b(a的深度<=b的深度)
- 若a为b的祖节点,则答案记为
[first[a], first[b]]区间中只出现1次的数值的数量
- 否则答案记为
a,b两点的最近公共祖+[last[a],first[b]]区间中只出现1次的数值的数量
Code
#include <bits/stdc++.h>
using namespace std;
const int N = 160010;
int n, m, len;
int h[N], e[N], ne[N], idx;
int w[N], seq[N], first[N], last[N], top;
int q[N], dep[N], fa[N][25], cnt[N];
int ans[N];
bool st[N];
vector<int> nums;
struct Node {
int id, l, r, p;
}Q[N];
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
int get(int i)
{
return i / len;
}
bool cmp(Node a, Node b)
{
int al = get(a.l), bl = get(b.l);
if (al != bl)return al < bl;
return a.r < b.r;
}
void dfs(int u, int fa)
{
seq[++ top] = u;//seq表示欧拉序
first[u] = top;//first[u]表示u第一次出现的位置
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j != fa)dfs(j, u);
}
seq[++ top] = u;
last[u] = top;//last[u]表示u最后一次出现的位置
}
void bfs(int s)
{
memset(dep, -1, sizeof dep);
int hh = 0, tt = 0;
q[0] = s, dep[s] = 0;
while (hh <= tt)
{
int t = q[hh ++];
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if (dep[j] == -1)
{
dep[j] = dep[t] + 1;
fa[j][0] = t, q[++ tt] = j;
for (int k = 1; k <= 20; k ++ )
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
int lca(int a, int b)
{
if (dep[a] < dep[b])swap(a, b);
for (int i = 20; i >= 0; i -- )
if (dep[fa[a][i]] >= dep[b])
a = fa[a][i];
if (a == b)return a;
for (int i = 20; i >= 0; i -- )
if (fa[a][i] != fa[b][i])
a = fa[a][i], b = fa[b][i];
return fa[a][0];
}
void change(int x, int& res)
{
st[x] ^= 1;//满足只出现1次
if (st[x] == 0)
{
cnt[w[x]] -- ;
if (!cnt[w[x]])res -- ;
}
else
{
if (!cnt[w[x]])res ++ ;
cnt[w[x]] ++ ;
}
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i ++ )scanf("%d", &w[i]), nums.push_back(w[i]);
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
for (int i = 1; i <= n; i ++ )
w[i] = lower_bound(nums.begin(), nums.end(), w[i]) - nums.begin();
memset(h, -1, sizeof h);
for (int i = 1; i < n; i ++ )
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
bfs(1), dfs(1, -1);
len = sqrt(top);
for (int i = 1; i <= m; i ++ )
{
int x, y;
scanf("%d%d", &x, &y);
if (first[x] > first[y])swap(x, y);
int p = lca(x, y);
if (x == p)Q[i] = {i, first[x], first[y]};
else Q[i] = {i, last[x], first[y], p};
}
sort(Q + 1, Q + m + 1, cmp);
for (int k = 1, i = 1, j = 0, res = 0; k <= m; k ++ )
{
int id = Q[k].id, l = Q[k].l, r = Q[k].r, p = Q[k].p;
while (i < l)change(seq[i ++], res);
while (i > l)change(seq[-- i], res);
while (j < r)change(seq[++ j], res);
while (j > r)change(seq[j --], res);
if (p)change(p, res);
ans[id] = res;
if (p)change(p, res);
}
for (int i = 1; i <= m; i ++ )printf("%d\n", ans[i]);
return 0;
}