题目描述
难度分:2000
输入n(1≤n≤2×105),m(1≤m≤2×105)和长为n的数组 a(1≤a[i]≤103),表示一棵树每个节点的初始点权。
然后输入一棵树的n−1条边,节点编号从1到n,根节点为1。
然后输入m个询问,格式如下:
-
1 x val
:对于以x为根的子树,把根节点的点权增加val,根节点儿子的点权增加 −val,根节点儿子的儿子的点权增加val,依此类推,val和−val交替,直到叶子。其中1≤val≤103。 -
2 x
:输出节点x的点权。
输入样例
5 5
1 2 1 1 2
1 2
1 3
2 4
2 5
1 2 3
1 1 2
2 1
2 2
2 4
输出样例
3
3
0
算法
差分数组+树状数组
比较显然的一点是在以x为根节点的子树中,如果节点深度与x的深度奇偶性相同,操作1就是加val,否则就是减val。层的奇偶性比较好获得,至于要定位到以某个节点x为根的子树所包含的所有节点,可以先对整棵树进行一次DFS
,求一个DFS序(DFS序可以使属于某一个子树的所有节点在一个连续的子数组内,可以将子树上的操作转化为在区间上的操作),将节点x的时间戳,以及以x为根的子树中最后一个节点的时间戳,还有x节点的符号分别存入结构体node[i]的l、r、sgn三个属性中。为了方便,根节点的深度为1,规定深度为奇数的节点对应区间加val,深度为偶数的节点对应区间减val,也就是说偶数深度的节点x满足node[x].sgn=−1,奇数深度的节点x满足node[x].sgn=1。
-
对于操作1,先找到节点x在DFS序上的范围[l,r],然后在这个区间上进行加val×node[x].sgn的操作,可以用树状数组维护一个差分数组。
-
对于操作2,直接用树状数组求前缀和,可以得到节点x上操作的增量gain,注意这个增量要乘上x节点的符号node[x].sgn。此时节点x的点权就应该是a[x]+gain×node[x].sgn。
这个分奇偶操作的地方还是很抽象的,思维跳跃性比较强,纯文字不太容易说明白。但是画个图模拟一下发现确实能凑出正确答案,只能说算法实在太妙了。
复杂度分析
时间复杂度
DFS
求DFS序本质上就是遍历一遍树,时间复杂度为O(n)。对于每个询问,都要在O(log2n)的时间复杂度下操作树状数组,m个询问的时间复杂度为O(mlog2n)。因此,整个算法的时间复杂度为O(n+mlog2n)。
空间复杂度
树状数组的空间开销为O(n);求DFS序需要递归,空间开销也是O(n);由于整张图是一个无向树,因此边的数量和节点的数量是同级别的,邻接表的空间开销为O(n)。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 200010;
vector<int> graph[N];
int n, m, ts, a[N];
struct Node {
int l, r, sgn;
} node[N];
int tree[N];
void dfs(int u, int fa, int sgn) {
ts++;
node[u].sgn = sgn;
node[u].l = ts;
for(int v: graph[u]) {
if(v == fa) continue;
dfs(v, u, -sgn);
}
node[u].r = ts;
}
int lowbit(int x) {
return x & -x;
}
void add(int x, int val) {
while(x <= n) {
tree[x] += val;
x += lowbit(x);
}
}
int query(int x) {
int res = 0;
while(x > 0) {
res += tree[x];
x -= lowbit(x);
}
return res;
}
void init() {
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
graph[i].clear();
tree[i] = 0;
}
}
int main() {
scanf("%d%d", &n, &m);
init();
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
graph[u].push_back(v);
graph[v].push_back(u);
}
ts = 0;
dfs(1, 0, 1);
for(int i = 1; i <= m; i++) {
int op, x, val;
scanf("%d%d", &op, &x);
if(op == 1) {
scanf("%d", &val);
add(node[x].l, val*node[x].sgn);
add(node[x].r + 1, -val*node[x].sgn);
}else {
printf("%d\n", a[x] + query(node[x].l)*node[x].sgn);
}
}
return 0;
}