题目描述
难度分:2400
输入n(2≤n≤2×105)和一棵n个节点的无权树的n−1条边。节点编号从1开始。
你需要执行n−1次操作,每次操作包含如下三步:
- 选择两个叶节点。
- 把两个叶节点之间的最短距离(简单路径的长度)加到答案中(答案初始为0)。
- 删除这两个叶节点中的一个。
输出答案的最大值,以及n−1行,每行三个数,表示你选的两个叶结点的编号,以及删除的叶结点的编号。
输入样例1
3
1 2
1 3
输出样例1
3
2 3 3
2 1 1
输入样例2
5
1 2
1 3
2 4
2 5
输出样例2
9
3 5 5
4 3 3
4 1 1
4 2 2
算法
贪心+DFS
+倍增
非常难的一个题,学习了好久才理解。
贪心
由于每次都要往答案上累加两点之间的距离,因此肯定是最长距离加的次数越多越好,而树上两点的最长距离就是树的直径。假设st和ed是直径的两个端点,对于不在直径上的点x,每次可以选择一个这样的x和一个直径的端点,x到st的距离更大就往答案累加上x到st的距离,x到ed的距离更大就往答案累加上x到ed的距离。
可以证明任何直径外的点y都不会比x到直径端点的距离更长。证明:假设x到st的距离长于x到ed的距离,如果x到一个直径外的点y距离比x到st的距离还长,则st到ed就不应该是直径,直径应该是y到ed才对。
因此我们可以先按这种办法删除直径外的点,最后再从端点开始把直径上的点一个个删掉。这还只是本题思维上的难度,实现难度其实也很大。
DFS
求树直径
这里用到一种两次DFS
求直径的方法,先随便选定一个节点1,DFS
求出离1最远的节点x。然后再从节点x开始进行DFS
,求出离x最远的节点y,此时x到y就是树的直径。这里也可以反证一下,如果x到y不是直径,则说明y不是距离x最远的点,矛盾。
倍增
接下来随便以一个直径端点作为树根跑一遍倍增,比如st作为根。跑完后从ed开始向上找父节点,用一个布尔数组flag标记整个直径。
模拟
最后就可以开始模拟删点了,第一步是要找到直径之外的叶子节点。我们可以在之前跑DFS
的时候顺便预处理出一个dfn数组,将所有节点按照dfn值进行排序,这样dfn值大的就是叶子节点。按照这个顺序选点,每次选择一个点node,比较它到st和ed两个直径端点的距离,由于之前跑过倍增,因此这两个距离也可以通过LCA快速求得。
删除完直径外的点后,从st或ed开始不断把直径上的点删掉就好。说起来跟简单,其实写起来是想当麻烦的,也很容易写错,详见代码。
复杂度分析
时间复杂度
两次DFS
求树的直径时间复杂度都是O(n);标记直径在最差情况下整棵树就是一条链,时间复杂度也是O(n);倍增的时间复杂度为O(nlog2n);对所有节点按照dfn序排列时间复杂度为O(nlog2n);最后删点由于每个点只会被删一次,所以时间复杂度为O(n)。
综上,算法的瓶颈在于O(nlog2n)的部分,因此这也是算法的时间复杂度。
空间复杂度
空间瓶颈在于倍增时使用的数组f,是O(nlog2n)规模的,其他数组都是O(n)的。因此,整个算法的额外空间复杂度为O(nlog2n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 200010;
vector<int> graph[N];
int n, maxd, id, ts, st, ed; // st和ed是直径的两个端点
int depth[N], dist[N], f[N][35], fa[N], dfn[N], p[N];
bool flag[N];
void dfs(int u, int a) {
dist[u] = dist[a] + 1;
fa[u] = a, dfn[u] = ++ts;
if(dist[u] > maxd) maxd = dist[u], id = u;
for(int v: graph[u]) {
if(v == a) continue;
dfs(v, u);
}
}
void bfs(int root) {
memset(depth, 0x3f, sizeof(depth));
queue<int> q;
q.push(root);
depth[root] = 1, depth[0] = 0;
while(!q.empty()) {
int ver = q.front();
q.pop();
for(int j: graph[ver]) {
if(depth[j] > depth[ver] + 1) {
depth[j] = depth[ver] + 1;
q.push(j);
f[j][0] = ver;
for(int k = 1; k <= 30; k++) {
f[j][k] = f[f[j][k - 1]][k - 1];
}
}
}
}
}
int lca(int a, int b) {
if(depth[a] < depth[b]) {
swap(a, b);
}
for(int k = 30; k >= 0; k--) {
if(depth[f[a][k]] >= depth[b]) {
a = f[a][k];
}
}
if(a == b) return a;
for(int k = 30; k >= 0; k--) {
if(f[a][k] != f[b][k]) {
a = f[a][k], b = f[b][k];
}
}
return f[a][0];
}
bool cmp(int x, int y) {
// dfn值更大的为叶子
return dfn[x] > dfn[y];
}
int main() {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
graph[i].clear();
p[i] = i;
}
for(int i = 1; i < n; i++) {
int a, b;
scanf("%d%d", &a, &b);
graph[a].push_back(b);
graph[b].push_back(a);
}
// 找到离1最远的点id,赋值给st
dfs(1, 0);
st = id, flag[st] = true, maxd = 0, ts = 0;
// 找到离st最远的点id,赋值给ed
memset(dist, 0, sizeof(dist));
dfs(st, 0);
ed = id, flag[ed] = true;
int len = dist[ed]; // 直径的长度
int cur = ed;
while(fa[cur] != st) {
cur = fa[cur];
flag[cur] = true;
}
// 倍增
bfs(st);
// 删直径之外的点
sort(p + 1, p + n + 1, cmp);
LL ans = 0;
vector<array<int, 3>> op;
for(int i = 1; i <= n; i++) {
int x = p[i];
if(flag[x]) continue;
int a = lca(x, ed);
int d1 = depth[x] - 1, d2 = depth[x] + depth[ed] - 2*depth[a];
ans += max(d1, d2);
int node = d1 >= d2? st: ed;
op.push_back({node, x, x});
}
// 删直径上的点
cur = ed;
len--;
while(fa[cur] != st) {
ans += len--;
op.push_back({st, cur, cur});
cur = fa[cur];
}
ans++;
op.push_back({st, cur, cur});
printf("%lld\n", ans);
for(auto& tup: op) {
printf("%d %d %d\n", tup[0], tup[1], tup[2]);
}
return 0;
}