终于遇到了一道有资格评为困难的题。
第一眼就知道肯定要枚举这两个区间是什么,然后发现非常难优化。
所以当我没说。
直接重开,改成枚举连续的值域段是什么。
于是就变成了求多少个值域区间,使得它们在 $a$ 序列中出现的区间可以表示成两个不相交的区间的并。
即区间个数 $\leq 2$,并且区间长度 $\gt 1$。
考虑从小到大扫描右端点,并通过某种方法维护 $[1,r)$ 作为左端点与 $r$ 形成的值域区间在 $a$ 序列中对应多少个分开的区间。
这种问题一般有两种思考方向:“点减边”或者 $\max - \min = r - l$。
考虑前者。
即插入 $r$ 的时候,假设它是一个单独的连通块,那么给 $l \in [1,r]$ 全部 $+1$。
然后再分别看它在 $a$ 序列中左右的元素,如果已经插入过,说明可以合并成一个连通块,那么就把对应的 $l$ 前缀 $-1$。
询问 $r$ 为右端点即查询 $l \in [1,r)$ 的区间个数 $\leq 2$ 的数量,也可以是 $l \in [1,r]$ 的答案减一(减去 $r$ 自己单独一个区间的)。
这个东西显然可以线段树维护区间最小值,以及最小值和最小值加一出现次数。
这里并不需要维护次小值,因为只需要查询 $\leq 2$,所以仅存储最小值加一的出现次数即可。
#include <bits/stdc++.h>
using namespace std;
const int N = 3e5 + 15;
int n, a[N], pos[N];
bool st[N];
long long ans = 0;
struct Tree {
int l, r;
int Min; //最小;最小值+1
int fi, se;
int flag;
} tr[N << 2];
void pushup(int u) {
tr[u].Min = min(tr[u << 1].Min, tr[u << 1 | 1].Min);
tr[u].fi = tr[u].se = 0;
if (tr[u << 1].Min == tr[u].Min) tr[u].fi += tr[u << 1].fi;
if (tr[u << 1 | 1].Min == tr[u].Min) tr[u].fi += tr[u << 1 | 1].fi;
if (tr[u << 1].Min == tr[u].Min) tr[u].se += tr[u << 1].se;
if (tr[u << 1 | 1].Min == tr[u].Min) tr[u].se += tr[u << 1 | 1].se;
if (tr[u << 1].Min == tr[u].Min + 1) tr[u].se += tr[u << 1].fi;
if (tr[u << 1 | 1].Min == tr[u].Min + 1) tr[u].se += tr[u << 1 | 1].fi;
}
void pushdown(int u) {
if (tr[u].flag) {
tr[u << 1].flag += tr[u].flag;
tr[u << 1 | 1].flag += tr[u].flag;
tr[u << 1].Min += tr[u].flag;
tr[u << 1 | 1].Min += tr[u].flag;
tr[u].flag = 0;
}
}
void build(int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
if (l == r) {
tr[u].Min = 0;
tr[u].fi = 1, tr[u].se = 0;
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void change(int u, int l, int r, int d) {
if (r < tr[u].l || l > tr[u].r) return;
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].Min += d, tr[u].flag += d;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) change(u << 1, l, r, d);
if (r > mid) change(u << 1 | 1, l, r, d);
pushup(u);
}
int query(int u, int l, int r) {
if (r < tr[u].l || l > tr[u].r) return 0;
if (tr[u].l >= l && tr[u].r <= r) {
if (tr[u].Min <= 1) return tr[u].fi + tr[u].se;
else if (tr[u].Min <= 2) return tr[u].fi;
else return 0;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1, res = 0;
if (l <= mid) res += query(u << 1, l, r);
if (r > mid) res += query(u << 1 | 1, l, r);
return res;
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]), pos[a[i]] = i;
build(1, 1, n);
for (int i = 1; i <= n; i++) {
change(1, 1, i, 1);
if (st[pos[i] - 1]) change(1, 1, a[pos[i] - 1], -1);
if (st[pos[i] + 1]) change(1, 1, a[pos[i] + 1], -1);
st[pos[i]] = 1;
ans += query(1, 1, i) - 1; //自己单独一段
}
printf("%lld\n", ans);
return 0;
}