看上去这个题目限制很玄学,但实际上可以通过转化变成三个简单的部分。
即:设 $P1$ 为满足 $u$ 是 $u \to v$ 上编号最小的点对数量,$P2$ 为满足 $v$ 是 $u \to v$ 上编号最大的点对数量,$P3$ 为同时满足前两个条件的点对数量。
则显然答案为:$P1 + P2 - 2 \times P3$。
考虑建立两棵点权 Kruskal 重构树,一棵大根一棵小根。那么对于 $P1,P2$,答案是好算的,即每个点在两棵树中的子树大小之和减去它本身 $2$ 的贡献。
对于 $P3$,其实就是在两棵树中计算有祖孙关系并且祖孙关系相反的点对数量。
这个可以用树状数组维护,是简单的。
于是就做完了。
#include <bits/stdc++.h>
using namespace std;
const int N = 4e5 + 15, M = N << 1;
int n, m;
vector<int> son[N];
int p[N];
int find(int x) { return (p[x] == x) ? x : p[x] = find(p[x]); }
struct BIT {
int tr[N];
void add(int x, int d) { for ( ; x < N; x += x & -x) tr[x] += d; }
int ask(int x) {
int res = 0;
for ( ; x ; x -= x & -x) res += tr[x];
return res;
}
int query(int l, int r) { return ask(r) - ask(l - 1); }
} tr;
struct Tree {
int h[N], e[M], ne[M], idx;
void init() { memset(h, -1, sizeof h); idx = 0, tot = 0; }
void add(int a, int b) { e[idx] = b, ne[idx] = h[a], h[a] = idx++; }
int dfn[N], dep[N], sz[N], tot;
void dfs(int u, int father) {
sz[u] = 1, dfn[u] = ++tot, dep[u] = dep[father] + 1;
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v == father) continue;
dfs(v, u);
sz[u] += sz[v];
}
}
} mn, mx;
long long ans = 0, res = 0;
void dfs(int u, int father) {
res += tr.query(mx.dfn[u], mx.dfn[u] + mx.sz[u] - 1);
tr.add(mx.dfn[u], 1);
for (int i = mn.h[u]; ~i; i = mn.ne[i]) {
int v = mn.e[i];
if (v == father) continue;
dfs(v, u);
}
tr.add(mx.dfn[u], -1);
}
int main() {
scanf("%d", &n);
for (int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
son[u].push_back(v);
son[v].push_back(u);
}
mn.init(), mx.init();
for (int i = 1; i <= n; i++) p[i] = i;
for (int u = 1; u <= n; u++)
for (int v : son[u]) if (v < u) {
int fv = find(v);
mx.add(u, fv), mx.add(fv, u);
p[fv] = u;
}
for (int i = 1; i <= n; i++) p[i] = i;
for (int u = n; u >= 1; u--)
for (int v : son[u]) if (v > u) {
int fv = find(v);
mn.add(u, fv), mn.add(fv, u);
p[fv] = u;
}
mn.dfs(1, 0), mx.dfs(n, 0);
for (int i = 1; i <= n; i++) ans += mn.sz[i] + mx.sz[i] - 2;
dfs(1, 0), ans -= 2 * res;
printf("%lld\n", ans);
scanf("%d", &m);
while (m--) {
int x; scanf("%d", &x);
ans += n - mn.dep[x];
++n, mn.dep[n] = mn.dep[x] + 1;
printf("%lld\n", ans);
}
return 0;
}