题目描述
给定两棵树:树 1 有 N1 个顶点(编号 1 到 N1),树 2 有 N2 个顶点(编号 1 到 N2)。
可以在树 1 的顶点 i 和树 2 的顶点 j 之间添加一条双向边,从而将两棵树合并成一棵大树。
令 f(i,j) 表示这棵合并后的大树的 直径。
需要计算所有可能的连接方式对应的直径之和:
N1∑i=1N2∑j=1f(i,j)
其中,树中两点间的 距离 定义为它们之间唯一简单路径上的边数。树的 直径 定义为树中任意两点间距离的最大值。
样例
输入:
3
1 3
1 2
3
1 2
3 1
输出:
39
说明:
例如,连接树 1 的顶点 2 和树 2 的顶点 3。树 1 的边为 (1,3), (1,2)。树 2 的边为 (1,2), (3,1)。
合并后的树有顶点 {1₁, 2₁, 3₁, 1₂, 2₂, 3₂} 和边 {(1₁,3₁), (1₁,2₁), (1₂,2₂), (3₂,1₂), (2₁,3₂)}。
这棵树的直径是 5(例如,从 3₁ 到 2₂ 的路径:3₁-1₁-2₁-3₂-1₂-2₂)。所以 f(2,3)=5。
计算所有 3×3=9 种连接方式的 f(i,j) 并求和,得到 39。
算法 (树的直径 + 排序 + 双指针)
O(N1logN1+N2logN2)
思路分析
-
计算单棵树的直径和最远距离:
- 对于一棵树,计算其直径 D 和每个节点 v 到树中其他节点的最大距离 max\_dist(v) 是解决问题的基础。
- 计算直径 D: 可以通过两次 BFS/DFS 实现。
- 从任意节点 s 开始 BFS/DFS,找到距离 s 最远的节点 a。
- 从节点 a 开始 BFS/DFS,找到距离 a 最远的节点 b。
- 节点 a 和 b 之间的距离就是树的直径 D。节点 a 和 b 是直径的两个端点。
- 计算 max\_dist(v): 对于树中的任意节点 v,其距离最远的节点一定是直径的两个端点之一(a 或 b)。因此,max\_dist(v)=max。
- 我们可以在第二次 BFS/DFS(从 a 开始)时记录下所有节点到 a 的距离 \text{dist}(\cdot, a)。
- 再进行一次 BFS/DFS,从 b 开始,记录下所有节点到 b 的距离 \text{dist}(\cdot, b)。
- 然后对每个节点 v,计算 \max(\text{dist}(v, a), \text{dist}(v, b)) 即可得到 \text{max\_dist}(v)。
- 我们分别对树 1 和树 2 执行以上操作,得到它们的直径 D_1, D_2 以及各自节点的最远距离数组 M_1[1..N_1] 和 M_2[1..N_2]。
-
分析合并后树的直径 f(i, j):
当树 1 的顶点 i 和树 2 的顶点 j 通过一条新边连接后,形成的新树 T_{i,j} 的直径可能是以下三种情况中的最大值:- Case 1: 直径完全位于原树 1 内部。此时直径为 D_1。
- Case 2: 直径完全位于原树 2 内部。此时直径为 D_2。
- Case 3: 直径经过新添加的边 (i, j)。这种路径的形式是从树 1 的某个节点 u 出发,到达 i,经过新边到 j,再到达树 2 的某个节点 v。其长度为 \text{dist}_1(u, i) + 1 + \text{dist}_2(j, v)。要使这个长度最大,应该选择 u 为距离 i 最远的点(距离为 M_1[i]),选择 v 为距离 j 最远的点(距离为 M_2[j])。所以,这种情况下的最大路径长度为 M_1[i] + M_2[j] + 1。
综上所述,合并后的直径为:
f(i, j) = \max(D_1, D_2, M_1[i] + M_2[j] + 1)
令 D_{\max} = \max(D_1, D_2)。则:
f(i, j) = \max(D_{\max}, M_1[i] + M_2[j] + 1) -
计算总和 \sum \sum f(i, j):
直接计算 N_1 \times N_2 对 (i, j) 的 f(i, j) 值并求和,复杂度为 O(N_1 N_2),对于 N_1, N_2 \le 2 \times 10^5 来说太慢了。我们需要更快的算法。
目标是计算:
\text{TotalSum} = \sum_{i=1}^{N_1} \sum_{j=1}^{N_2} \max(D_{\max}, M_1[i] + M_2[j] + 1)
我们可以将 \max 函数拆开:
\max(A, B) = \begin{cases} A & \text{if } B \le A \\ B & \text{if } B > A \end{cases}
令 A = D_{\max},B = M_1[i] + M_2[j] + 1。
\text{TotalSum} = \sum_{i=1}^{N_1} \left( \sum_{j \text{ s.t. } M_1[i] + M_2[j] + 1 \le D_{\max}} D_{\max} + \sum_{j \text{ s.t. } M_1[i] + M_2[j] + 1 > D_{\max}} (M_1[i] + M_2[j] + 1) \right) -
使用排序和双指针优化:
- 为了快速处理条件 M_1[i] + M_2[j] + 1 > D_{\max} (等价于 M_2[j] > D_{\max} - 1 - M_1[i]),我们可以先将数组 M_1 和 M_2 分别排序。设排序后的数组为 f_1 和 f_2。
- 现在,我们遍历排序后的 f_1 中的每个元素 f_1[i’] (对应原树中的某个 i)。对于固定的 f_1[i’],我们需要找到 f_2 中满足 f_1[i’] + f_2[j’] + 1 > D_{\max} 的元素个数和它们的和。
- 由于 f_1 和 f_2 都是排序的,我们可以使用 双指针 方法。
- 对 f_1 使用一个指针
i
从 0 遍历到 N_1 - 1。 - 对 f_2 使用一个指针
j
从 N_2 开始(表示还没考虑 f_2 中的元素)。 - 对于当前的 f_1[i],我们找到一个临界点
j
,使得对于所有 k < j,满足 f_1[i] + f_2[k] + 1 \le D_{\max},而对于所有 k \ge j,满足 f_1[i] + f_2[k] + 1 > D_{\max}。 - 由于 f_1 是递增的,当
i
增加时,f_1[i] 增加,使得满足 f_1[i] + f_2[k] + 1 \le D_{\max} 的 k 的范围会缩小。这意味着临界点j
是单调不增的。 - 我们可以维护指针
j
。对于每个i
,我们将j
向左移动(减小),直到找到第一个满足 f_1[i] + f_2[j-1] + 1 \le D_{\max} 的位置(或者说,j
是第一个不满足 f_1[i] + f_2[k] + 1 > D_{\max} 的下标)。
- 对 f_1 使用一个指针
- 计算贡献:
- 对于固定的
i
和找到的临界指针j
:- 有
j
个 f_2 元素 (f_2[0] 到 f_2[j-1]) 满足 f_1[i] + f_2[k] + 1 \le D_{\max}。它们的贡献是 j \times D_{\max}。 - 有 N_2 - j 个 f_2 元素 (f_2[j] 到 f_2[N_2-1]) 满足 f_1[i] + f_2[k] + 1 > D_{\max}。它们的贡献是 \sum_{k=j}^{N_2-1} (f_1[i] + f_2[k] + 1)。
- 这个和可以拆分为 (N_2 - j) \times (f_1[i] + 1) + \sum_{k=j}^{N_2-1} f_2[k]。
- 为了快速计算 \sum_{k=j}^{N_2-1} f_2[k],我们可以预先计算 f_2 的后缀和,或者在移动指针
j
时动态维护 f_2 的后半部分的和。代码中使用的是后者:在while
循环将j
左移时,累加 f_2[j] 到变量sum
中。此时sum
就是 \sum_{k=j}^{N_2-1} f_2[k]。
- 有
- 对于固定的
- 算法流程:
- 对树 1 计算直径 D_1 和最远距离数组 M_1。
- 对树 2 计算直径 D_2 和最远距离数组 M_2。
- 令 D_{\max} = \max(D_1, D_2)。
- 将 M_1 排序得到 f_1。
- 将 M_2 排序得到 f_2。
- 初始化总和
ans = 0
。 - 初始化指针
j = N_2
和sum = 0
(表示 f_2 中 k \ge j 的元素和)。 - 遍历
i
从 0 到 N_1 - 1:
a.while (j > 0 && f_1[i] + f_2[j - 1] + 1 > D_{\max})
:
i.j--
ii.sum += f_2[j]
b. 计算当前i
的贡献:contribution = (N_2 - j) * (f_1[i] + 1) + sum + j * D_{\max}
。
c.ans += contribution
。 - 输出
ans
。
时间复杂度
- 计算树 1 的直径和最远距离:O(N_1) (两次 BFS/DFS)。
- 计算树 2 的直径和最远距离:O(N_2)。
- 排序 M_1:O(N_1 \log N_1)。
- 排序 M_2:O(N_2 \log N_2)。
- 双指针计算总和:指针 i 移动 N_1 次,指针 j 总共最多移动 N_2 次。复杂度为 O(N_1 + N_2)。
- 整体时间复杂度:O(N_1 \log N_1 + N_2 \log N_2)。
C++ 代码
#include <bits/stdc++.h> // 引入所有标准库
using i64 = long long; // 定义 i64 为 long long 的别名
// using u64 = unsigned long long; // 未使用
// using u32 = unsigned; // 未使用
// using u128 = unsigned __int128; // 未使用
// 函数 solve: 处理一棵树,计算其直径和每个节点的最远距离数组
// 返回值:pair<直径长度, 最远距离数组>
auto solve() {
int N; // 树的顶点数
std::cin >> N;
std::vector<std::vector<int>> adj(N); // 邻接表
// 读入 N-1 条边
for (int i = 1; i < N; i++) {
int u, v;
std::cin >> u >> v;
u--; // 转换为 0-based 索引
v--;
adj[u].push_back(v);
adj[v].push_back(u);
}
// f[i] 将用于存储节点 i 的最远距离
std::vector<int> f(N);
// dis[i] 用于 BFS 中存储距离
std::vector<int> dis(N, -1);
// bfs 函数:从源点 s 开始 BFS,计算所有点到 s 的距离,并返回距离 s 最远的点
auto bfs = [&](int s) {
std::queue<int> q;
q.push(s);
dis.assign(N, -1); // 重置距离数组
dis[s] = 0; // 源点距离为 0
while (!q.empty()) {
int x = q.front();
q.pop();
for (auto y : adj[x]) { // 遍历邻居
if (dis[y] == -1) { // 如果邻居未访问
dis[y] = dis[x] + 1; // 更新距离
q.push(y); // 加入队列
}
}
}
// 找到距离最远的点并返回其索引
return std::max_element(dis.begin(), dis.end()) - dis.begin();
};
// 1. 第一次 BFS:从任意点 (0) 开始,找到最远点 a
auto a = bfs(0);
// 2. 第二次 BFS:从点 a 开始,找到最远点 b,并记录下所有点到 a 的距离到 f 数组
auto b = bfs(a);
f = dis; // f[i] = dist(i, a)
// 3. 第三次 BFS:从点 b 开始,计算所有点到 b 的距离
bfs(b); // dis[i] = dist(i, b)
// 4. 计算每个节点的最远距离:f[i] = max(dist(i, a), dist(i, b))
for (int i = 0; i < N; i++) {
f[i] = std::max(f[i], dis[i]);
}
// 返回直径长度 (f[b] 即 dist(a, b)) 和最远距离数组 f
return std::pair(f[b], f);
}
int main() {
// 加速 IO
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
// 处理第一棵树
auto [dia1, f1] = solve(); // dia1: 直径, f1: 最远距离数组
// 处理第二棵树
auto [dia2, f2] = solve(); // dia2: 直径, f2: 最远距离数组
i64 ans = 0; // 最终答案,使用 long long
int n1 = f1.size(); // 树 1 的顶点数
int n2 = f2.size(); // 树 2 的顶点数
// 对最远距离数组进行排序
std::sort(f1.begin(), f1.end());
std::sort(f2.begin(), f2.end());
// 计算两棵树直径的最大值
int dia = std::max(dia1, dia2);
// 双指针算法计算总和
i64 sum = 0; // 用于累加 f2 中满足条件的元素之和 (f1[i] + f2[k] + 1 > dia)
// 指针 j 从 n2 开始 (指向 f2 数组末尾之后)
for (int i = 0, j = n2; i < n1; i++) { // 指针 i 遍历排序后的 f1
// 移动指针 j 向左,直到找到临界点
// 使得 f2[j-1] 是最后一个满足 f1[i] + f2[j-1] + 1 > dia 的元素 (或者 j=0)
while (j > 0 && f1[i] + f2[j - 1] + 1 > dia) {
j--; // j 左移
sum += f2[j]; // 将 f2[j] 加入 sum (这部分对应 B > A 的情况)
}
// 计算当前 f1[i] 的贡献
// 1. 对于 f2[j] 到 f2[n2-1] (共 n2 - j 个元素), f(i, k) = f1[i] + f2[k] + 1
// 这部分的和为 (f1[i] + 1) * (n2 - j) + sum_of_f2[j..n2-1]
// sum 变量已经计算了 sum_of_f2[j..n2-1]
ans += 1LL * (f1[i] + 1) * (n2 - j) + sum;
// 2. 对于 f2[0] 到 f2[j-1] (共 j 个元素), f(i, k) = dia
// 这部分的和为 j * dia
ans += 1LL * j * dia;
}
// 输出最终答案
std::cout << ans << "\n";
return 0; // 程序正常结束
}