题目描述
难度分:2400
你有一棵树,一开始有4个节点,编号为1,2,3,4,其中2,3,4都和1相连。
输入q(1≤q≤5×105)表示有q次操作。每次操作,输入v(1≤v≤ 当前树的大小),保证v是叶子。
在叶子v的下面添加两个新的节点与v相连,编号分别为n+1和n+2,其中n是当前树的大小。
每次操作后,输出树的直径长度。
输入样例
5
2
3
4
8
5
输出样例
3
4
4
5
6
算法
倍增求LCA
如果加入一个新的节点后,树的直径变长了,那么新直径的一个端点肯定是新加入的这个节点。
初始化直径的长度diameter=2,两个端点分别为end1和end2。每次加入一个节点cur的时候,如果这个节点到end1和end2中任意一点的距离超过了diameter,就可以把直径长度更新为这个距离。比如cur到end1的距离超过了diameter,那直径长度就更新为cur到end1的距离,end2就更新为cur。
可以证明任何直径外的点y都不会比cur到直径端点的距离更长。证明:假设cur到end1的距离长于cur到end2的距离。如果cur到一个直径外的点y距离比cur到end1的距离还长,则end1到end2就不应该是直径,直径应该是y到end2才对。
这样一来,先基于初始的4个节点预处理出倍增的dist、depth、fa三个数组。然后每加入一个新的节点cur,就更新这三个数组,计算新节点到上一轮直径端点end1和end2的距离,从而判断直径端点和长度该如何更新。
复杂度分析
时间复杂度
倍增初始化的时间复杂度为O(nlog2U),U为q个询问完成后的树节点总数,初始情况下n=4很小。接下来处理每个询问,新加入一个节点就要更新一次fa数组、dist数组和depth数组,时间复杂度为O(log2U),每次加入两个节点,时间复杂度仍然是这个级别。求新加入节点与老的直径端点end1和end2的距离,时间复杂度为O(log2U)。因此,处理q个询问的时间复杂度为O(qlog2U)。
综上,时间复杂度和树的节点个数并没有关系,只跟预估的树节点总数有关,时间复杂度为O(nlog2U+qlog2U)。
空间复杂度
空间消耗就是倍增求LCA
的辅助数组消耗,dist和depth数组是线性空间,为O(U)。fa数组还需要考虑到每个节点距离根节点的路径长度,空间消耗为O(Ulog2U)。因此,整个算法的额外空间复杂度为O(Ulog2U)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1000010, M = 20;
int q, v, dist[N], depth[N], fa[N][M];
void bfs(unordered_map<int, vector<int>>& graph, int root, int d) {
depth[root] = 1;
dist[root] = 0;
fa[root][0] = 0;
queue<int> q;
q.push(root);
while(!q.empty()) {
int cur = q.front();
q.pop();
for(int nxt: graph[cur]) {
if(depth[nxt]) continue;
q.push(nxt);
depth[nxt] = depth[cur] + 1;
dist[nxt] = dist[cur] + 1;
fa[nxt][0] = cur;
for(int j = 1; j < M; j++) {
fa[nxt][j] = fa[fa[nxt][j - 1]][j - 1];
}
}
}
}
int lca(int a, int b) {
if(depth[a] < depth[b]) swap(a, b);
for(int i = M - 1; i >= 0; i--) {
if(depth[fa[a][i]] >= depth[b]) {
a = fa[a][i];
}
}
if(a == b) return a;
for(int i = M - 1; i >= 0; i--) {
if(fa[a][i] != fa[b][i]) {
a = fa[a][i];
b = fa[b][i];
}
}
return fa[a][0];
}
int get(int x, int y) {
return dist[x] + dist[y] - (dist[lca(x, y)]<<1);
}
int main() {
int end1 = 2, end2 = 3, diameter = 2;
scanf("%d", &q);
unordered_map<int, vector<int>> graph;
graph[1].push_back(2);
graph[2].push_back(1);
graph[1].push_back(3);
graph[3].push_back(1);
graph[1].push_back(4);
graph[4].push_back(1);
bfs(graph, 1, 0);
int n = 4;
for(int k = 1; k <= q; k++) {
scanf("%d", &v);
for(int i = 1; i <= 2; i++) {
int cur = n + i;
fa[cur][0] = v;
dist[cur] = dist[v] + 1;
depth[cur] = depth[v] + 1;
for(int j = 1; j < M; j++) {
fa[cur][j] = fa[fa[cur][j - 1]][j - 1];
}
int d1 = get(cur, end1), d2 = get(cur, end2);
if(d1 >= d2) {
if(d1 > diameter) {
end2 = cur;
diameter = d1;
}
}else {
if(d2 > diameter) {
end1 = cur;
diameter = d2;
}
}
}
n += 2;
printf("%d\n", diameter);
}
return 0;
}