题目描述
难度分:2100
输入n,m(0≤m<n≤3×105)和q(1≤q≤3×105)。
然后输入一个森林的m条边。注:森林由多棵树组成。然后输入q个询问,格式如下:
1 x
:输出节点x所在的树的直径。2 x y
:如果节点x和节点y。
在同一棵树,什么也不做;否则在这两棵树上各选一个点,连一条边,要求连边后,得到的新树的直径最小。所有输入的节点编号均从1开始。
输入样例
6 0 6
2 1 2
2 3 4
2 5 6
2 3 2
2 5 3
1 1
输出样例
4
算法
并查集+DFS
用并查集来维护节点之间的连通性,并查集的根节点作为树的根节点。先DFS
每棵树,预处理出每棵树的直径长度(老套路,先DFS
求一个距离根节点root最远的点x,再DFS
求距离x最远的点y,这样x到y的路径就是这棵树的直径)。对每棵树都追出直径上的节点,将直径的中点存入mid数组中,mid[root]表示以root为根的树的直径中点,直径长度存入dlen数组中,dlen[root]是对应的直径长度。
有个比较显然的贪心策略,就是在合并两棵树的时候,将这两棵树的直径中点连接起来,这样能使新树的直径最小。对于直径为d1和d2的两棵树,合并后的直径为max(d1,d2,⌈d12⌉+⌈d22⌉+1),其中⌈.⌉表示对.向上取整。
接下来就可以在线处理每个询问了,如果x和y不属于同一棵树,就按照上述规则更新新树的直径长度和中点:
- 如果新直径长度为d1,那新树的直径中点就是x树的直径中点。
- 如果新直径长度为d2,那新树的直径中点就是y树的直径中点。
- 否则新树的直径长度取决于x和y两棵树的直径谁更长,新直径中点是较长那条直径的中点。
复杂度分析
时间复杂度
并查集merge的时间复杂度看成是log级别,那么建图时的时间复杂度就是O(mlog2n)。然后遍历每棵子树,对每棵子树做两遍DFS
进行预处理,时间复杂度为O(n+m)。最后在线处理每条询问,每个询问的瓶颈在于并查集的find操作,时间复杂度可以看成O(log2n)级别。综上,整个算法的时间复杂度为O(n+m+(m+q)log2n)。
空间复杂度
空间瓶颈在于图的邻接表,复杂度为O(n+m),其余数组都是线性的,为O(n)。因此,整个算法的额外空间复杂度为O(n+m)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 300010;
int n, m, q, p[N], mid[N], pre[N], dlen[N], dist[N];
vector<int> graph[N];
struct custom_hash {
static uint64_t splitmix64(uint64_t x) {
x += 0x9e3779b97f4a7c15;
x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
return x ^ (x >> 31);
}
size_t operator()(uint64_t x) const {
static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
return splitmix64(x + FIXED_RANDOM);
}
};
int find(int x) {
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
void merge(int x, int y) {
int rx = find(x), ry = find(y);
if(rx != ry) {
p[rx] = ry;
}
}
void dfs(int u, int fa) {
for(auto v: graph[u]){
if(v == fa) continue;
pre[v] = u;
dist[v] = dist[u] + 1;
dfs(v, u);
}
}
int main() {
scanf("%d%d%d", &n, &m, &q);
for(int i = 1; i <= n; i++) {
graph[i].clear();
p[i] = i;
mid[i] = 0;
}
for(int i = 1; i <= m; i++) {
int x, y;
scanf("%d%d", &x, &y);
graph[x].push_back(y);
graph[y].push_back(x);
merge(x, y);
}
unordered_map<int, vector<int>> groups;
for(int i = 1; i <= n; i++) {
groups[find(i)].push_back(i);
}
for(auto&[root, group]: groups) {
for(int node: group) {
dist[node] = pre[node] = 0;
}
// 求距离root最远的点x
dfs(root, -1);
int maxdist = 0, x = root;
for(int node: group) {
if(dist[node] > maxdist) {
maxdist = dist[node], x = node;
}
}
// 求距离x最远的点
maxdist = 0;
int cur = x;
for(int node: group) {
dist[node] = pre[node] = 0;
}
dfs(x, -1);
for(int node: group) {
if(dist[node] > maxdist) {
maxdist = dist[node], x = node;
}
}
// 把整个直径追出来
vector<int> path;
cur = x;
while(cur) {
path.push_back(cur);
cur = pre[cur];
}
int len = path.size();
int mid_node = path[(len + 1)/2 - 1];
mid[root] = mid_node; // 树root的直径中点
dlen[root] = len - 1;
}
for(int i = 1; i <= q; i++) {
int tp;
scanf("%d", &tp);
if(tp == 1) {
int x;
scanf("%d", &x);
printf("%d\n", dlen[find(x)]);
}else {
int x, y;
scanf("%d%d", &x, &y);
if(find(x) != find(y)) {
int xmid = mid[find(x)], ymid = mid[find(y)];
int d1 = dlen[find(x)], d2 = dlen[find(y)];
merge(find(x), find(y));
dlen[find(x)] = max((d1 + 1 >> 1) + 1 + (d2 + 1 >> 1), max(d1, d2));
if(dlen[find(x)] == d1) {
mid[find(x)] = xmid;
}else if(dlen[find(x)] == d2) {
mid[find(x)] = ymid;
}else {
mid[find(x)] = d1 >= d2? xmid: ymid;
}
}
}
}
return 0;
}