RMJ 为什么炸了RMJ 为什么炸了RMJ 为什么炸了RMJ 为什么炸了RMJ 为什么炸了
非常好的一道树上背包。
题目简述
给定一棵 $n$ 个结点的树,你需要给每个点赋点权为 $0$ 或 $1$,使得每一条路径的点权 $\text{mex}$ 之和最大。
解法
贪心地想可以给树做二分图染色,然而会在一些神秘形态树上寄掉。
于是观察性质,一条路径上出现 $0,1$ 时贡献为 $2$,全 $0$ 贡献为 $1$,全 $1$ 贡献为 $0$。
不妨设所有路径的贡献都为 $2$,此时总贡献为 $n \times (n+1)$,再考虑什么时候会有贡献损失。此时一条全 $0$ 的路径损失为 $2-1=1$,全 $1$ 的路径损失为 $2-0=2$。
发现同色路径一定在同色连通块中,这启发我们树上背包。
设 $dp_{u,i,j}$ 表示以 $u$ 为根的子树内,包含 $u$ 的连通块大小为 $i$,颜色为 $j$ 的最小损失。枚举 $u$ 的儿子 $v$ 转移。
- $u,v$ 异色
- 此时不会有损失贡献,因为跨两个块的路径 $\text{mex}$ 一定是 $2$。
- $dp_{u,i,0} = dp_{u,i,0} + dp_{v,i’,1}$
- $dp_{u,i,1} = dp_{u,i,1} + dp_{v,i’,0}$
- $u,v$ 同色
- 此时贡献与颜色种类和连通块大小有关,颜色是 $0$ 损失是 $1$,颜色是 $1$ 损失是 $2$。
- $dp_{u,i+i’,0} = dp_{u,i,0} + dp_{v,i’,0} + i \times i’$
- $dp_{u,i+i’,1} = dp_{u,i,1} + dp_{v,i’,1} + i \times i’ \times 2$
问题来了,这样的状态设计是 $O(n^2)$ 的。
观察到一个同色连通块的损失是 $O(n^2)$,这肯定不优,所以 $i$ 这一维只要开到 $\sqrt{n}$ 即可。
为什么卡空间为什么卡空间为什么卡空间为什么卡空间为什么卡空间
注意力惊人,注意到一个节点 $u$ 只会从它孩子计算 $dp$ 值,所以使用 vector
,一个节点的父节点计算完之后可以直接释放内存。
时间复杂度 $O(n \sqrt{n})$。
代码
学到了一种释放空间的方式:vector<int>().swap(a);
貌似 clear
并不能释放空间,只能清空。
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N = 2e5 + 15, M = N << 1;
const int INF = 0x3f3f3f3f3f3f3f3f;
int T, n;
int h[N], e[M], ne[M], idx = 0;
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
int sz[N];
vector<int> dp[N][2];
void DP(int u, int father) {
sz[u] = 1;
dp[u][0].push_back(INF), dp[u][1].push_back(INF);
dp[u][0].push_back(1), dp[u][1].push_back(2);
vector<int> g[2];
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v == father) continue;
DP(v, u);
g[0].clear(), g[1].clear();
int limu = ceil(sqrt(sz[u])), limv = ceil(sqrt(sz[v])), lim = ceil(sqrt(sz[u] + sz[v]));
for (int j = 0; j <= lim; j++) g[0].push_back(INF), g[1].push_back(INF);
for (int j = limu; j >= 1; j--)
for (int k = limv; k >= 1; k--) {
g[0][j] = min(g[0][j], dp[u][0][j] + dp[v][1][k]);
g[1][j] = min(g[1][j], dp[u][1][j] + dp[v][0][k]);
if (j + k > lim) continue;
g[0][j + k] = min(g[0][j + k], dp[u][0][j] + dp[v][0][k] + j * k);
g[1][j + k] = min(g[1][j + k], dp[u][1][j] + dp[v][1][k] + j * k * 2);
}
dp[u][0] = g[0], dp[u][1] = g[1];
vector<int>().swap(dp[v][0]), vector<int>().swap(dp[v][1]);
sz[u] += sz[v];
}
}
signed main() {
scanf("%lld", &T);
while (T--) {
scanf("%lld", &n);
for (int i = 1; i <= n; i++) h[i] = -1; idx = 0;
for (int i = 1, u, v; i < n; i++) {
scanf("%lld%lld", &u, &v);
add(u, v), add(v, u);
}
vector<int>().swap(dp[1][0]), vector<int>().swap(dp[1][1]);
DP(1, 0);
int ans = INF;
for (int i = 1; i <= ceil(sqrt(n)); i++) ans = min({ans, dp[1][0][i], dp[1][1][i]});
printf("%lld\n", n * (n + 1) - ans);
}
return 0;
}