题目描述
难度分:2100
输入n(1≤n≤4×105),m(1≤m≤4×105)和长为n的数组 a(1≤a[i]≤60),表示一棵n个节点的无向树,以及每个节点的颜色a[i]。
然后输入这棵树的n−1条边(编号从1到n)。根节点的编号是1。
然后输入m个操作:
1 v c
:把以v为根的子树中的所有节点的颜色改成c(1≤c≤60)。2 v
:输出以v为根的子树中,有多少种不同的颜色。
输入样例1
7 10
1 1 1 1 1 1 1
1 2
1 3
1 4
3 5
3 6
3 7
1 3 2
2 1
1 4 3
2 1
1 2 5
2 1
1 6 4
2 1
2 2
2 3
输出样例1
2
3
4
5
1
2
输入样例2
23 30
1 2 2 6 5 3 2 1 1 1 2 4 5 3 4 4 3 3 3 3 3 4 6
1 2
1 3
1 4
2 5
2 6
3 7
3 8
4 9
4 10
4 11
6 12
6 13
7 14
7 15
7 16
8 17
8 18
10 19
10 20
10 21
11 22
11 23
2 1
2 5
2 6
2 7
2 8
2 9
2 10
2 11
2 4
1 12 1
1 13 1
1 14 1
1 15 1
1 16 1
1 17 1
1 18 1
1 19 1
1 20 1
1 21 1
1 22 1
1 23 1
2 1
2 5
2 6
2 7
2 8
2 9
2 10
2 11
2 4
输出样例2
6
1
3
3
2
1
2
3
5
5
1
2
2
1
1
1
2
3
算法
DFS序+线段树
这个题是对子树操作,比较容易想到利用树的DFS序将子树操作转化为区间操作,因此先通过DFS做这个预处理。
然后利用线段树来进行区间染色就可以了,需要注意的是空间消耗,本来我看最多只有60种颜色,对每个节点都开一个哈希表,但是样例就MLE
了。因此,考虑状态压缩,每个区间的颜色用一个60位二进制数mask来表示,一旦mask的第i位为1,说明这个区间对应的子树有i这个颜色。查询时返回区间对应的状态mask,里面有多少个1就有多少种颜色。
复杂度分析
时间复杂度
求DFS序需要遍历整棵树一遍,时间复杂度为O(n)。后续的m个操作每次都要对O(n)级别的线段树进行区间推平和区间查询操作,每个操作的时间复杂度均为O(log2n),因此时间复杂度为O(mlog2n)。
综上,整个算法的时间复杂度为O(n+mlog2n)。
空间复杂度
以节点u为根的子树对应的区间[l,r]用low[u]=l和high[u]=r两个数组存储,n个节点的空间消耗是O(n)的。树的邻接表、线段树的空间复杂度也是O(n)。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long LL;
const int N = 400010;
int n, m, ts, c[N], w[N], low[N], high[N];
vector<int> graph[N];
class SegmentTree {
public:
struct Tag {
LL add;
Tag() {
add = 0;
}
};
struct Info {
int l, r;
LL mask;
Tag lazy;
Info() {}
Info(int left, int right): l(left), r(right) {}
} tr[N<<2];
explicit SegmentTree() {}
void build(int u, int l, int r) {
if(l == r) {
tr[u] = Info(l, r);
tr[u].mask = 1LL<<w[l];
return;
}
int mid = (l + r) >> 1;
build(lc(u), l, mid);
build(rc(u), mid + 1, r);
pushup(u);
}
void modify(int l, int r, LL d) {
modify(1, l, r, d);
}
Info query(int l, int r) {
return query(1, l, r);
}
private:
int lc(int u) {
return u<<1;
}
int rc(int u) {
return u<<1|1;
}
void pushup(int u) {
tr[u] = merge(tr[lc(u)], tr[rc(u)]);
}
void pushdown(int u) {
if(not_null(tr[u].lazy)) {
down(u);
clear_lazy(tr[u].lazy); // 标记下传后要清空
}
}
void modify(int u, int l, int r, LL d) {
if(l <= tr[u].l && tr[u].r <= r) {
set(u, d);
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if(mid >= l) modify(lc(u), l, r, d);
if(mid < r) modify(rc(u), l, r, d);
pushup(u);
}
Info query(int u, int l, int r) {
if(l <= tr[u].l && tr[u].r <= r) return tr[u];
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if(r <= mid) {
return query(u<<1, l, r);
}else if(mid < l) {
return query(u<<1|1, l, r);
}else {
return merge(query(u<<1, l, r), query(u<<1|1, l, r));
}
}
Info merge(const Info& lchild, const Info& rchild) {
Info info;
info.l = lchild.l, info.r = rchild.r;
info.mask = lchild.mask|rchild.mask;
return info;
}
// modify操作到不能递归时,设置节点的属性值
void set(int u, int d) {
tr[u].mask = 1LL<<d;
tr[u].lazy.add = 1LL<<d;
}
// 下传标记的规则
void down(int u) {
int mid = (tr[u].l + tr[u].r) >> 1;
tr[lc(u)].lazy.add = tr[u].lazy.add;
tr[rc(u)].lazy.add = tr[u].lazy.add;
tr[lc(u)].mask = tr[u].lazy.add;
tr[rc(u)].mask = tr[u].lazy.add;
}
// 判断标记是否为空的规则
bool not_null(Tag& lazy) {
return lazy.add != 0;
}
// 清空标记的规则
void clear_lazy(Tag& lazy) {
lazy.add = 0;
}
};
void dfs(int u, int fa) {
low[u] = ++ts;
w[ts] = c[u];
for(int v: graph[u]) {
if(fa == v) continue;
dfs(v, u);
}
high[u] = ts;
}
int main() {
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) {
graph[i].clear();
scanf("%d", &c[i]);
}
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
ts = 0;
dfs(1, 0);
SegmentTree seg;
seg.build(1, 1, n);
for(int i = 1; i <= m; i++) {
int tk;
scanf("%d", &tk);
if(tk == 1) {
int u, ck;
scanf("%d%d", &u, &ck);
int l = low[u], r = high[u];
seg.modify(l, r, ck);
}else {
int u;
scanf("%d", &u);
int l = low[u], r = high[u];
LL res = seg.query(l, r).mask;
printf("%d\n", __builtin_popcountll(res));
}
}
return 0;
}