cf 2400*
树链剖分 + 线段树维护最大(最小)子段和,时间复杂度 O(n log^2 n)
(树上倍增写的不熟之后再补qwq)
最大子段和模板 https://www.acwing.com/problem/content/246/
树链剖分模板 https://www.acwing.com/problem/content/2570/
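思路: 题目问的是路径上是否存在一段连续的子路径,其点权和恰好为 k。由于点权只由 1 和 -1 组成,只需求出路径上的最大子段和 max_sum 与最小子段和 min_sum,
再看 k 是否落在两者之间即可: 把最小子段逐个端点地扩展/收缩成最大子段,每一步和只变化 ±1,所以 [min_sum, max_sum] 内的每个整数都能取到;k = 0 对应空子段,直接为 YES。
(min_sum <= k <= max_sum)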
首先用树链剖分的基本操作将树分成链转变成数组,然后用线段树维护子段和即可,
不过在树上合并子段和时需要考虑两个子段的顺序问题
图1 中的合并方式是正确的。
可以把两边想象成两个数组,合并两个数组时必须按路径顺序首尾相接,即左边数组的右边界与右边数组的左边界相接,
而不是左边界接左边界、右边界接右边界。实现上,从 x 一侧向上跳得到的链是自上而下存储的,拼接前要先翻转(交换前缀与后缀的最值)。
图1
图2
java 代码
import java.io.*;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.StringTokenizer;
public class Main {
    /**
     * Offline solver: the tree grows through "+ v x" events (a new vertex with weight x is hung
     * under v; vertex 1 starts with weight 1) and "? u v k" asks whether some contiguous sub-path
     * of the u..v path has weight sum exactly k. The final tree is decomposed once with heavy-light
     * decomposition; a segment tree over the dfs order stores max/min sub-segment sums, and the
     * weights are written into it as the events are replayed in input order.
     */
    static int N = 200010;
    // Per-vertex HLD data: dfs position, parent, depth, chain head, subtree size, heavy child.
    static int[] id = new int[N], fa = new int[N], deep = new int[N], top = new int[N];
    // b is indexed by dfs position and mirrors w (indexed by vertex); w is never assigned here.
    static int[] size = new int[N], son = new int[N], b = new int[N], w = new int[N];
    // Adjacency lists: head h, edge target e, next edge ne.
    static int[] h = new int[N], e = new int[2 * N], ne = new int[2 * N];
    // Explicit DFS stacks: recursion would need depth ~N on a path-shaped tree.
    static int[] stk = new int[N], stkParent = new int[N], cur = new int[N];
    static Node[] tr = new Node[4 * N];
    static int idx = 0, time = 0;

    /** Appends the directed edge a -> b to the adjacency lists. */
    static void add(int a, int b) {
        e[idx] = b;
        ne[idx] = h[a];
        h[a] = idx++;
    }

    /**
     * HLD pass 1: fills fa, deep, size and the heavy child son for the subtree rooted at u.
     * The heavy child is the first child, in adjacency order, with the largest subtree.
     * Iterative so a long chain cannot overflow the thread stack.
     *
     * @param u      root of the traversal
     * @param dep    depth assigned to u
     * @param father parent of u (-1 for the tree root); never descended into
     */
    static void dfs1(int u, int dep, int father) {
        int sp = 0;
        fa[u] = father;
        deep[u] = dep;
        size[u] = 1;
        son[u] = 0;
        cur[u] = h[u];
        stk[sp++] = u;
        while (sp > 0) {
            int v = stk[sp - 1];
            int i = cur[v];
            if (i != -1) {
                cur[v] = ne[i];
                int j = e[i];
                if (j == fa[v]) continue;
                fa[j] = v;
                deep[j] = deep[v] + 1;
                size[j] = 1;
                son[j] = 0;
                cur[j] = h[j];
                stk[sp++] = j;
            } else {
                // all children of v are finished: fold v into its parent (the stack entry below)
                sp--;
                if (sp > 0) {
                    int p = stk[sp - 1];
                    size[p] += size[v];
                    if (son[p] == 0 || size[v] > size[son[p]]) son[p] = v;
                }
            }
        }
    }

    /**
     * Resets size, deep, fa, son and top for every vertex in the subtree of u so the arrays can
     * be reused by the next test case. Iterative for the same stack-depth reason as dfs1.
     *
     * @param dep    unused; kept for signature compatibility
     * @param father parent of u; never descended into
     */
    static void dfs3(int u, int dep, int father) {
        int sp = 0;
        stk[sp] = u;
        stkParent[sp++] = father;
        while (sp > 0) {
            sp--;
            int v = stk[sp];
            int p = stkParent[sp];
            size[v] = 0;
            deep[v] = 0;
            fa[v] = 0;
            son[v] = 0;
            top[v] = 0;
            for (int i = h[v]; i != -1; i = ne[i]) {
                int j = e[i];
                if (j == p) continue;
                stk[sp] = j;
                stkParent[sp++] = v;
            }
        }
    }

    /**
     * HLD pass 2: assigns dfs positions (id) and chain heads (top). The heavy child is visited
     * right after its parent, so every chain occupies a contiguous, top-down range of positions.
     */
    static void dfs2(int u, int tp) {
        int sp = 0;
        top[u] = tp;
        stk[sp++] = u;
        while (sp > 0) {
            int v = stk[--sp];
            id[v] = ++time;
            b[time] = w[v];
            // light children start their own chains; push them first so the heavy child pops next
            for (int i = h[v]; i != -1; i = ne[i]) {
                int j = e[i];
                if (j != fa[v] && j != son[v]) {
                    top[j] = j;
                    stk[sp++] = j;
                }
            }
            if (son[v] != 0) {
                top[son[v]] = top[v];
                stk[sp++] = son[v];
            }
        }
    }

    /** Returns a detached copy of s, so callers never mutate a live segment-tree node. */
    private static Node copyOf(Node s) {
        return new Node(s.l, s.r, s.sum, s.lmax, s.rmax, s.tmax, s.lmin, s.rmin, s.tmin);
    }

    /** Returns s read in the opposite direction (prefix and suffix stats swap); null stays null. */
    private static Node reversed(Node s) {
        if (s == null) return null;
        return new Node(s.l, s.r, s.sum, s.rmax, s.lmax, s.tmax, s.rmin, s.lmin, s.tmin);
    }

    /** Summary of segment a followed by segment b, where null means empty; result is a fresh node. */
    private static Node concat(Node a, Node b) {
        if (a == null) return b == null ? null : copyOf(b);
        if (b == null) return copyOf(a);
        Node res = new Node(a.l, b.r);
        push(res, a, b);
        return res;
    }

    /**
     * Returns the max/min sub-segment summary of the tree path from x to y, oriented so that x is
     * the left end. The returned node is never a live segment-tree node.
     */
    static Node valroot(int x, int y) {
        Node xPart = null; // climbed part on x's side, stored top-down (ends at the original x)
        Node yPart = null; // climbed part on y's side, stored top-down (ends at the original y)
        while (top[x] != top[y]) {
            if (deep[top[x]] > deep[top[y]]) {
                // the freshly read chain lies above everything collected so far on this side
                xPart = concat(query(1, id[top[x]], id[x]), xPart);
                x = fa[top[x]];
            } else {
                yPart = concat(query(1, id[top[y]], id[y]), yPart);
                y = fa[top[y]];
            }
        }
        // x and y now share a chain, whose positions increase going down the tree.
        Node mid = deep[x] <= deep[y]
                ? query(1, id[x], id[y])              // walk down from x to y
                : reversed(query(1, id[y], id[x]));   // walk up from x to y
        // The x side is walked bottom-up, so its top-down summary is reversed before joining.
        return concat(concat(reversed(xPart), mid), yPart);
    }

    /** Segment summary used both as segment-tree node and as path accumulator. */
    static class Node {
        int l, r;             // covered position range (only meaningful for tree nodes)
        int sum;              // total of the segment
        int lmax, rmax, tmax; // max prefix sum, max suffix sum, max non-empty sub-segment sum
        int lmin, rmin, tmin; // min prefix sum, min suffix sum, min non-empty sub-segment sum

        public Node(int l, int r) {
            this.l = l;
            this.r = r;
        }

        public Node(int l, int r, int s, int l1, int r1, int t1, int l2, int r2, int t2) {
            this.l = l;
            this.r = r;
            sum = s;
            lmax = l1;
            rmax = r1;
            tmax = t1;
            lmin = l2;
            rmin = r2;
            tmin = t2;
        }
    }

    /**
     * Stores into u the summary of segment l followed by segment r (u's l/r are untouched).
     * u must not alias l or r: l.sum is read after u.sum has been written.
     */
    public static void push(Node u, Node l, Node r) {
        u.sum = l.sum + r.sum;
        u.lmax = Math.max(l.lmax, l.sum + r.lmax);
        u.rmax = Math.max(r.rmax, r.sum + l.rmax);
        u.tmax = Math.max(Math.max(l.tmax, r.tmax), l.rmax + r.lmax);
        u.lmin = Math.min(l.lmin, l.sum + r.lmin);
        u.rmin = Math.min(r.rmin, r.sum + l.rmin);
        u.tmin = Math.min(Math.min(l.tmin, r.tmin), l.rmin + r.lmin);
    }

    /** Recomputes internal node u from its two children. */
    public static void pushup(int u) {
        push(tr[u], tr[u << 1], tr[u << 1 | 1]);
    }

    /** Builds node u over positions [l, r]; leaf values come from the position-indexed array b. */
    public static void build(int u, int l, int r) {
        if (l == r) {
            int v = b[l];
            tr[u] = new Node(l, r, v, v, v, v, v, v, v);
        } else {
            int mid = (l + r) >> 1;
            tr[u] = new Node(l, r);
            build(u << 1, l, mid);
            build(u << 1 | 1, mid + 1, r);
            pushup(u);
        }
    }

    /** Sets position x to the single value v and refreshes its ancestors. */
    public static void modify(int u, int x, int v) {
        if (tr[u].l == x && tr[u].r == x) {
            tr[u] = new Node(x, x, v, v, v, v, v, v, v);
        } else {
            int mid = (tr[u].l + tr[u].r) >> 1;
            if (x <= mid) modify(u << 1, x, v);
            else modify(u << 1 | 1, x, v);
            pushup(u);
        }
    }

    /**
     * Returns the summary of positions [l, r] within node u's range. When one node covers the
     * request, that live tree node itself is returned, so callers must not mutate the result.
     */
    public static Node query(int u, int l, int r) {
        if (tr[u].l >= l && tr[u].r <= r) return tr[u];
        int mid = (tr[u].l + tr[u].r) >> 1;
        if (r <= mid) return query(u << 1, l, r);
        if (l > mid) return query(u << 1 | 1, l, r);
        Node left = query(u << 1, l, r);
        Node right = query(u << 1 | 1, l, r);
        Node res = new Node(left.l, right.r);
        push(res, left, right);
        return res;
    }

    public static void main(String[] args) throws IOException {
        FastScanner f = new FastScanner();
        PrintWriter w1 = new PrintWriter(System.out);
        int T = f.nextInt();
        while (T-- > 0) {
            int n = f.nextInt();
            // At most n "+" events, so vertex labels stay within 1..n+1; clearing only that
            // prefix keeps the per-test reset O(n) instead of O(N).
            Arrays.fill(h, 0, n + 2, -1);
            idx = time = 0;
            List<int[]> events = new ArrayList<>();
            int nextVertex = 2; // vertex 1 exists from the start
            for (int i = 0; i < n; i++) {
                String op = f.nextString();
                if (op.equals("+")) {
                    int parent = f.nextInt();
                    int weight = f.nextInt();
                    add(parent, nextVertex);
                    add(nextVertex, parent);
                    events.add(new int[]{0, nextVertex, weight});
                    nextVertex++;
                } else if (op.equals("?")) {
                    int u = f.nextInt();
                    int v = f.nextInt();
                    int k = f.nextInt();
                    events.add(new int[]{1, u, v, k});
                }
            }
            // Decompose the final tree once. Leaves start at 0 (b is all zeros because w is never
            // assigned); weights are written in as the events are replayed in input order.
            dfs1(1, 1, -1);
            dfs2(1, 1);
            build(1, 1, nextVertex - 1);
            modify(1, id[1], 1);
            for (int[] ev : events) {
                if (ev[0] == 0) {
                    modify(1, id[ev[1]], ev[2]);
                } else {
                    int k = ev[3];
                    Node path = valroot(ev[1], ev[2]);
                    // k == 0 is the empty sub-path; otherwise k must lie in [tmin, tmax]
                    boolean ok = k == 0 || (path.tmin <= k && k <= path.tmax);
                    w1.println(ok ? "YES" : "NO");
                }
            }
            dfs3(1, 1, -1);
        }
        w1.flush();
    }

    /**
     * Buffered byte reader over System.in. Tokens are assumed to be present; end of input is
     * only signalled by a single -1 sentinel byte written by fillBuffer.
     */
    private static class FastScanner {
        final private int BUFFER_SIZE = 1 << 16;
        private DataInputStream din;
        private byte[] buffer;
        private int bufferPointer, bytesRead;

        private FastScanner() throws IOException {
            din = new DataInputStream(System.in);
            buffer = new byte[BUFFER_SIZE];
            bufferPointer = bytesRead = 0;
        }

        private short nextShort() throws IOException {
            short ret = 0;
            byte c = read();
            while (c <= ' ') c = read();
            boolean neg = (c == '-');
            if (neg) c = read();
            do ret = (short) (ret * 10 + c - '0');
            while ((c = read()) >= '0' && c <= '9');
            if (neg) return (short) -ret;
            return ret;
        }

        private int nextInt() throws IOException {
            int ret = 0;
            byte c = read();
            while (c <= ' ') c = read();
            boolean neg = (c == '-');
            if (neg) c = read();
            do ret = ret * 10 + c - '0';
            while ((c = read()) >= '0' && c <= '9');
            if (neg) return -ret;
            return ret;
        }

        public long nextLong() throws IOException {
            long ret = 0;
            byte c = read();
            while (c <= ' ') c = read();
            boolean neg = (c == '-');
            if (neg) c = read();
            do ret = ret * 10 + c - '0';
            while ((c = read()) >= '0' && c <= '9');
            if (neg) return -ret;
            return ret;
        }

        private char nextChar() throws IOException {
            byte c = read();
            while (c <= ' ') c = read();
            return (char) c;
        }

        private String nextString() throws IOException {
            StringBuilder ret = new StringBuilder();
            byte c = read();
            while (c <= ' ') c = read();
            do {
                ret.append((char) c);
            } while ((c = read()) > ' ');
            return ret.toString();
        }

        private void fillBuffer() throws IOException {
            bytesRead = din.read(buffer, bufferPointer = 0, BUFFER_SIZE);
            if (bytesRead == -1) buffer[0] = -1;
        }

        private byte read() throws IOException {
            if (bufferPointer == bytesRead) fillBuffer();
            return buffer[bufferPointer++];
        }
    }
}