题目描述
难度分:1500
输入n(3≤n≤5000)和一棵n个节点的无向树的n−1条边。节点编号从1开始。
一开始每个节点都是白色。你需要把某些(不是全部)节点染色红色或者蓝色。要求:
- 至少有一个点是红色,至少有一个点是蓝色。
- 红色节点不能和蓝色节点相邻。
设(a,b)=(红色节点数,蓝色节点数)。你需要最大化a+b。输出有多少种不同的(a,b),记作k。然后输出这k种(a,b),按照a升序。
输入样例1
5
1 2
2 3
3 4
4 5
输出样例1
3
1 3
2 2
3 1
输入样例2
10
1 2
2 3
3 4
5 6
6 7
7 4
8 9
9 10
10 4
输出样例2
6
1 8
2 7
3 6
6 3
7 2
8 1
算法
01
背包
首先要想到的一点就是a+b的最大值就是n−1,即只有一个节点未被染色,与这个节点相邻的都是些红色或者蓝色的连通块。所以先对整棵树做一遍DFS
,预处理出一个sz数组,sz[u]表示以1为整棵树的根,u为根的情况下,u这棵子树中节点的个数。然后再做一遍DFS
求答案,尝试让各个节点作为唯一的那个白色节点。
对于一个节点u,它的上面有up=n−sz[u]个节点,下面有大小为sz[v1],sz[v2],…,sz[vk]的k个子树。有两种情况u可以作为白色节点:
- up=0且k>1,这样不会所有连通块都是一个颜色。
- up≠0且k>0,这样也可以保证至少有两个连通块,可以染不同的颜色。
在有解的情况下,将sz[v1],sz[v2],…,sz[vk],up中的非零元素放入到一个数组blocks中,在这个数组上做容量为n的01
背包就可以了。
状态定义
dp[v]表示能否从blocks中凑出累加和v,初始化dp[0]=true
。
状态转移
外层循环遍历blocks,内层倒序遍历n到blocks[i]。如果dp[v−blocks[i]]=true
,那么dp[v]=true
。
DP
做完之后遍历i∈[1,n−2],只要dp[i]=true
,就把(i,n−1−i)加入到答案中。由于答案要求按照a排序,所以让答案是一个有序集合省去最后整体排序。
复杂度分析
时间复杂度
对整棵树DFS
的时间复杂度是O(n)。对于每个节点,需要对其邻居节点进行01
背包,宏观上来看其实就是对所有节点进行容量为n的01
背包,DP
的时间复杂度应该是O(n2)。每次将(a,b)加入到答案的时候是O(log2n)的,(a,b)的对数是O(n)的,所以插入答案的时间复杂度为O(nlog2n)。
整个算法的时间复杂度瓶颈还是在01
背包,时间复杂度为O(n2)。
空间复杂度
树的邻接表空间消耗是O(n);存储每个节点作为根时子树大小的sz数组也是O(n)的;答案表中的二元组个数是O(n)级别的,空间复杂度还是O(n)。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef pair<int, int> PII;
const int N = 5010;
vector<int> graph[N];
int n, sz[N];
set<PII> st;
void dfs1(int u, int fa) {
sz[u] = 1;
for(int v: graph[u]) {
if(v == fa) continue;
dfs1(v, u);
sz[u] += sz[v];
}
}
void dfs2(int u, int fa) {
// 将u染成白色,可以得到上面的部分,和u的若干子树这些连通块
int up = n - sz[u];
vector<int> blocks;
for(int v: graph[u]) {
if(v == fa) continue;
dfs2(v, u);
blocks.push_back(sz[v]);
}
if(up == 0) {
// 上面的部分没有,但是有多个子节点
int cnt = blocks.size();
if(cnt > 1) {
vector<bool> dp(n + 1);
dp[0] = true;
for(int i = 1; i <= cnt; i++) {
for(int j = n - 2; j >= blocks[i - 1]; j--) {
dp[j] = dp[j] || dp[j - blocks[i - 1]];
}
}
for(int i = 1; i <= n - 2; i++) {
if(dp[i]) {
st.insert({i, n - 1 - i});
}
}
}
}else {
// 上面的部分存在,也有子节点
if(!blocks.empty()) {
blocks.push_back(up);
int cnt = blocks.size();
vector<bool> dp(n + 1);
dp[0] = true;
for(int i = 1; i <= cnt; i++) {
for(int j = n - 2; j >= blocks[i - 1]; j--) {
dp[j] = dp[j] || dp[j - blocks[i - 1]];
}
}
for(int i = 1; i <= n - 2; i++) {
if(dp[i]) {
st.insert({i, n - 1 - i});
}
}
}
}
}
int main() {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
graph[i].clear();
}
st.clear();
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 0);
printf("%d\n", (int)st.size());
for(auto&pir: st) {
printf("%d %d\n", pir.first, pir.second);
}
return 0;
}