算法
(树上dfs、树的中心、容斥原理) $O(N)$
首先,树的直径可以通过跑两遍 dfs
找出来
其次,虽然树的直径可以有无数条,但树的中心
只有唯一一个
树的直径长度可以为奇数或偶数,下面进行分类讨论:
- 若直径上有偶数个点,则树的中心会出现在某一条边上,但为了方便我们不妨令树的直径长度为直径上的点数 $-1$,记为 $d$,再令中心点所对应的边的两端点分别为 $l$ 和 $r$, 然后只需找左边距离点 $l$ 为 $\frac{d}{2}$ 的点的数量 $s_1$,以及右边距离点 $r$ 为 $\frac{d}{2}$ 的点的数量 $s2$,而 $s_1 \times s_2$ 即是答案
这个例子的答案就是 $2 \times 2 = 4$
- 若直径上有奇数个点,则中心点恰好落在某一顶点上。
在每个子树中,从距离中心点为 $\frac{d}{2}$ 的顶点中至多选一个点(每个子树中这样合法的顶点有 $A_i$ 个)的方案数为 $\prod (A_i + 1)$,最后的答案为 $\prod (A_i + 1) - \sum A_i - 1$
其中 $\sum A_i$ 表示在全体中只选一个点的合法方案数,$1$ 表示在全体中只选 $0$ 个点的合法方案数
这个例子的答案为 $(3 + 1) \times (0 + 1) \times (4 + 1) - (3+0+4) -1 = 12$
C++ 代码
#pragma GCC target("avx2")
#pragma GCC optimize("O3")
#pragma GCC optimize("unroll-loops")
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
#if __has_include(<atcoder/all>)
#include <atcoder/all>
using namespace atcoder;
#endif
#define rep(i, n) for (int i = 0; i < (n); ++i)
using std::cin;
using std::cout;
using std::vector;
using mint = modint998244353;
using P = std::pair<int, int>;
int main() {
int n;
cin >> n;
vector<vector<int>> to(n);
rep(i, n - 1) {
int a, b;
cin >> a >> b;
--a; --b;
to[a].push_back(b);
to[b].push_back(a);
}
auto dfs = [&](auto& f, int v, int d = 0, int p=-1) -> P {
P res(d, v);
for (int u : to[v]) {
if (u == p) continue;
res = max(res, f(f, u, d+1, v));
}
return res;
};
int sv = dfs(dfs, 0).second;
int tv = dfs(dfs, sv).second;
vector<int> vs;
auto dfs2 = [&](auto& f, int v, int tv, int p=-1) -> bool {
if (v == tv) {
vs.push_back(v);
return true;
}
for (int u : to[v]) {
if (u == p) continue;
if (f(f, u, tv, v)) {
vs.push_back(v);
return true;
}
}
return false;
};
dfs2(dfs2, sv, tv); // 搜出直径上所有的点
int d = vs.size() - 1;
auto dfs3 = [&](auto& f, int v, int d, int p) -> int { // 搜索距离树的中心中心为 d 的所有点
if (d == 0) return 1;
int res = 0;
for (int u : to[v]) {
if (u == p) continue;
res += f(f, u, d-1, v);
}
return res;
};
mint ans;
if (d % 2 == 1) {
int l = vs[d/2], r = vs[d/2+1];
ans = dfs3(dfs3, l, d/2, r);
ans *= dfs3(dfs3, r, d/2, l);
}
else {
ans = 1;
int c = vs[d/2];
int s1 = 0;
for (int v : to[c]) {
int r = dfs3(dfs3, v, d/2-1, c);
ans *= r + 1;
s1 += r;
}
ans -= s1;
ans -= 1;
}
cout << ans.val() << '\n';
return 0;
}