对这题有疑问的同学建议先看上面这两篇题解,先看第二篇,再看第一篇。主要的讲解在第二篇里面,第一篇
只是在代码上对第二篇的代码进行了优化而已
以下代码在洛谷是过了20个数据点,在acwing过了19个,最后一个数据点过不了,MLE了
还是因为我把这道题当做acwing353的特殊版本来做了,给每个点都开了一棵线段树
所以导致了MLE
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<map>
using namespace std;
const int N = 300010, M = 29 * N;
int leftson[M], rightson[M], sum[M], root[N], cnt = 0, w[N], ans[N];
int n, m;
// 上面这样做可以极大地节省空间复杂度
int father[N], son[N], howbig[N], top[N], timestamp = 0, depth[N], dfn[N], fi[N], ne[N << 1], en[N << 1], index = 0;
void add(int a, int b) {
index++;
ne[index] = fi[a];
fi[a] = index;
en[index] = b;
}
void dfs1(int u, int dad) {
father[u] = dad;
depth[u] = depth[dad] + 1;
howbig[u] = 1;
for (int p = fi[u]; p > 0; p = ne[p]) {
int v = en[p];
if (v == dad) continue;
dfs1(v, u);
howbig[u] += howbig[v];
if (howbig[v] > howbig[son[u]]) son[u] = v;
}
}
void dfs2(int u, int t) {
top[u] = t;
dfn[u] = ++timestamp;
if (son[u]) dfs2(son[u], t);
for (int p = fi[u]; p > 0; p = ne[p]) {
int v = en[p];
if (v == son[u] || v == father[u]) continue;
dfs2(v, v);
}
}
int lca(int a, int b) {
while (top[a] != top[b]) {
if (depth[top[a]] < depth[top[b]]) swap(a, b);
a = father[top[a]];
}
if (depth[a] < depth[b]) return a;
return b;
}
void pushup(int u) {
int left = leftson[u], right = rightson[u];
sum[u] = sum[left] + sum[right];
}
void update(int u, int l, int r, int x, int c) {
if (l == r) {
sum[u] += c;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) {
if (leftson[u] == 0) leftson[u] = ++cnt;
update(leftson[u], l, mid, x, c);
}
else {
if (rightson[u] == 0) rightson[u] = ++cnt;
update(rightson[u], mid + 1, r, x, c);
}
pushup(u);
}
struct Q {
int a, b, x;
}op[N];
int queue[N], head = 1, tail = 0;
void bfs() {
queue[++tail] = 1;
while (tail - head + 1 > 0) {
int u = queue[head++];
for (int p = fi[u]; p > 0; p = ne[p]) {
int v = en[p];
if (depth[v] != depth[u] + 1) continue;
queue[++tail] = v;
}
}
}
void merge(int p, int q, int l, int r) {
if (q == 0) return;
if (l == r) {
sum[p] += sum[q];
return;
}
int mid = (l + r) >> 1;
if (leftson[p] == 0 && leftson[q] != 0) {
leftson[p] = leftson[q];
}
else if (leftson[p] != 0 && leftson[q] != 0) {
merge(leftson[p], leftson[q], l, mid);
}
if (rightson[p] == 0 && rightson[q] != 0) {
rightson[p] = rightson[q];
}
else if (rightson[p] != 0 && rightson[q] != 0) {
merge(rightson[p], rightson[q], mid + 1, r);
}
pushup(p);
}
int query(int u, int l, int r, int x) {
if (u == 0) return 0;
if (l == r) return sum[u];
int mid = (l + r) >> 1;
if (x <= mid) return query(leftson[u], l, mid, x);
return query(rightson[u], mid + 1, r, x);
}
int main() {
scanf_s("%d%d", &n, &m);
for (int i = 1; i <= n - 1; i++) {
int a, b;
scanf_s("%d%d", &a, &b);
add(a, b); add(b, a);
}
dfs1(1, 0); dfs2(1, 1);
for (int i = 1; i <= n; i++) {
scanf_s("%d", &w[i]);
}
for (int i = 1; i <= m; i++) {
scanf_s("%d%d", &op[i].a, &op[i].b);
op[i].x = lca(op[i].a, op[i].b);
}
for (int i = 1; i <= m; i++) {
int a = op[i].a, b = op[i].b, x = op[i].x;
if (root[a] == 0) root[a] = ++cnt;
if (root[b] == 0) root[b] = ++cnt;
if (root[x] == 0) root[x] = ++cnt;
if (root[father[x]] == 0) root[father[x]] = ++cnt;
update(root[a], 1, n << 1, n + depth[a], 1);
update(root[b], 1, n << 1, n + depth[x] * 2 - depth[a], 1);
update(root[x], 1, n << 1, n + depth[x] * 2 - depth[a], -1);
update(root[father[x]], 1, n << 1, n + depth[a], -1);
/*
上面这个操作就是说如果u这个点处于a ~ x中的某个点,那么很明显在u能够观察到这个玩家当且仅当depth[a] - depth[u] = w[u]
如果u是在x ~ b当中的话,能观察到当且仅当depth[a] - depth[x] + depth[u] - depth[x] = w[u]
所以我们可以看成对所有a ~ x上面的点都发放了一个depth[a]这种救济粮,所有x ~ b上面的点都发放了一个depth[x] - 2 * depth[a]这种救济粮
至于x的话你放在哪边都是一样的。查询的时候,要查询u这个点出现在上升段的路径有几条就是查询depth[u] + w[u]咯,查询他出现在下降段的路径有几条自然就是查询depth[u] - w[u]就好,
你把depth[a] - depth[x] + depth[u] - depth[x] = w[u]这个式子化简一下就知道为什么是这个条件了
而上面的n + depth[a]跟n + depth[x] * 2 - depth[a],前面多加了一个n无非就是一种离散化嘛,因为depth[x] * 2 - depth[a]有可能是< 0的
上面这种差分方法很明显就是把x归入到“下降段”了吧
*/
}
bfs();
for (int i = n; i >= 1; i--) {
int u = queue[i];
for (int p = fi[u]; p > 0; p = ne[p]) {
int v = en[p];
if (depth[v] != depth[u] + 1) continue;
if (root[u] == 0) root[u] = ++cnt;
merge(root[u], root[v], 1, n << 1);
}
if (w[u] && depth[u] + w[u] <= n) ans[u] += query(root[u], 1, n << 1, n + depth[u] + w[u]);
ans[u] += query(root[u], 1, n << 1, n + depth[u] - w[u]);
/*
上面这两个query其实就是查询,把u作为“上升段”中的一点的路径跟作为“下降段”中的一点的的路径加起来一共有多少条
那么这个w[u] && depth[u] + w[u] <= n是什么意思呢?就是说当u作为上升段中的一点的时候,在这一点能观察到某个a出发的玩家当且仅当
depth[a] - depth[u] = w[u] <=> depth[a] = depth[u] + w[u]那么我们可以发现,如果depth[u] + w[u] > n的话那他是绝对不可能作为上升段的某点出现的
但是这并不代表着这个时候你query(root[u], 1, n << 1, n + depth[u] + w[u])会得到0,因为depth[u] + w[u]这个位置的数可能是“下降段”贡献出来的,所以你不能要
同理如果w[u] = 0的话这个u要出现在上升段那只能是depth[a] = depth[u]了,那u就是起点呀,说明整条路径都是“下降段”了,根本没有上升段。所以这个时候也不能要*/
}
for (int i = 1; i <= n; i++) {
printf_s("%d ", ans[i]);
}
}
// 上面这个代码在洛谷已经AC了,但是在acwing只能ac19个点,有1个数据点就是ac不了。