维护数列
解题思路
本题有 $6$ 个操作。
操作 $1$,在某一个位置插入一段数。
操作 $2$,将某一段数删掉
操作 $3$,将某一段数全部变成一个数
操作 $4$,翻转某一段数
操作 $5$,求某一段数的和
操作 $6$,求某个区间的最大子段和
可以发现,操作 $1$ 和操作 $2$ 能用 Splay 轻松实现,而操作 $3$ 和操作 $4$ 实际上都是区间修改,可以搭配懒标记实现,而操作 $5$ 我们只需要维护一个区间和的信息,在 $pushup$ 的时候用两个子节点的区间和进行更新即可。
操作 $6$ 则有点复杂,首先需要维护一个信息就是每个区间的最大子段和,然后我们要考虑怎么从两个子节点得出当前区间的最大子段和,首先考虑当前区间的最大子段和有哪几种情况,一种情况是最大子段和在左子区间中,此时当前区间的最大子段和就是左子区间的最大子段和,第二种情况是最大子段和在右子区间中,此时当前区间的最大子段和就是右子区间的最大子段和,第三种情况是最大子段和同时穿过左子区间和右子区间,也就是在中间部分,此时当前区间的最大子段和应该是左子区间的最大后缀和加上右子区间的最大前缀和。
综上所述,当前区间的最大子段和要从三种情况中取最大值,而在计算的过程中我们还要知道子区间的最大前缀和和最大后缀和,因此我们还要维护两个信息,就是每个区间的最大前缀和和最大后缀和。
我们既然维护了最大前缀和和最大后缀和,我们就也需要从子节点中得到当前节点的最大前缀和和最大后缀和。
考虑最大前缀和,当前区间的最大前缀和从两种情况得出,一种情况就是最大前缀和只在左子区间中,那么当前区间的最大前缀和就是左子区间的最大前缀和。另一种情况就是最大前缀和穿过左子区间和右子区间,那么此时当前区间的最大前缀和应该是左子区间的区间总和再加上右子区间的最大前缀和。而最大前缀和就是从这两种情况中取最大值即可。最大后缀和则同理。
可以发现我们要想得到最大前缀和,就还需要维护一个区间和的信息,而当前区间的区间和我们可以直接通过两个子区间的区间和相加得到。
分析到一步,我们就可以维护每个区间的最大子段和,到此以上六个操作都能够用 Splay 实现。要想实现删除在排名第 $k$ 的点后面插入一段数和删除一段数,我们需要能够查询排名第 $k$ 的点的位置,为此我们还需要维护每个节点所在子树中的节点数量 $cnt$,而要想实现区间修改和区间翻转,因此我们还需要额外维护两个懒标记。我们规定,当前节点拥有一个懒标记,则它当前的信息都是已经执行完这个懒标记之后的状态。因此我们也可以认为当前节点的懒标记为是否对两个子节点进行操作。
本题最开始给了我们一个序列,意味着最开始 Splay 中是有元素的,因此我们最开始需要将序列建到 Splay 中去,这里最正宗的建 Splay 的方式就是递归着建,我们取序列的中间点为根节点,中间点左边的数沿着根节点的左子树递归去建,中间点右边的数沿着根节点的右子树递归去建,这样最终建出来的 Splay 能保证是一个完全平衡的二叉树,运行效率最高。
注意,本题虽然保证了 Splay 中节点个数不会超过 $500000$,但是由于插入删除的操作次数非常的多,所以可能会用到两三百万个节点,非常大,因此这里还需要用到 Splay 的一个内存回收的机制。
这个内存回收的机制就是,由于我们可能执行很多次删除操作,每一次都删除其中连续的一段,那么这一段的下标在后面都不会被使用,非常的浪费,我们可以用数组来存储当前可用的下标有哪些,最开始数组中存储 $1 \sim n$,表示最开始所有下标都可以用,每当我们删除一段后,我们将这一段删除的下标重新插入这个垃圾回收站中,每次我们要给一个新节点分配一个下标时,则可以直接从垃圾回收站中取出一个下标拿来使用。
综上所述,就是本题的整个思路,本题基本上实现了大部分 Splay 的常用操作,是一个非常锻炼对 Splay 的理解程度的题,解决本题,基本上对 Splay 的理解就过关了
C++ 代码
#include <iostream>
#include <cstring>
using namespace std;
const int N = 500010, INF = 0x3f3f3f3f;
int n, m;
struct Node
{
int s[2], p, v; //子节点下标、父节点下标、数值
int rev, same; //子节点是否需要翻转、子节点的所有数值是否和根节点相同
int cnt, sum; //节点个数、节点的数值和
int ms, ls, rs; //最大子段和、最大前缀和、最大后缀和
void init(int _v, int _p) //初始化节点
{
s[0] = s[1] = 0;
v = _v, p = _p;
rev = same = 0;
cnt = 1, sum = v;
ms = v, ls = rs = max(0, v); //最大子段和至少有一个数
}
}tr[N]; //Splay
int root; //Splay 的根节点
int node[N], tt; //垃圾回收站(栈),存储所有空的下标
int a[N];
void pushup(int x) //通过子节点的信息计算当前节点的信息
{
Node &u = tr[x], &l = tr[u.s[0]], &r = tr[u.s[1]];
u.cnt = l.cnt + r.cnt + 1;
u.sum = l.sum + r.sum + u.v;
u.ls = max(l.ls, l.sum + u.v + r.ls);
u.rs = max(r.rs, r.sum + u.v + l.rs);
u.ms = max(max(l.ms, r.ms), l.rs + u.v + r.ls);
}
void pushdown(int x) //将当前节点的懒标记下传给子节点
{
Node &u = tr[x], &l = tr[u.s[0]], &r = tr[u.s[1]];
if(u.same)
{
u.same = u.rev = 0;
if(u.s[0]) l.same = 1, l.v = u.v, l.sum = l.v * l.cnt;
if(u.s[1]) r.same = 1, r.v = u.v, r.sum = r.v * r.cnt;
if(u.v > 0)
{
if(u.s[0]) l.ms = l.ls = l.rs = l.sum;
if(u.s[1]) r.ms = r.ls = r.rs = r.sum;
}
else
{
if(u.s[0]) l.ms = l.v, l.ls = l.rs = 0;
if(u.s[1]) r.ms = r.v, r.ls = r.rs = 0;
}
}
else if(u.rev)
{
u.rev = 0, l.rev ^= 1, r.rev ^= 1;
swap(l.ls, l.rs), swap(r.ls, r.rs);
swap(l.s[0], l.s[1]), swap(r.s[0], r.s[1]);
}
}
void rotate(int x) //左旋、右旋
{
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y, tr[y].p = x;
pushup(y), pushup(x);
}
void splay(int x, int k) //将节点 x 转到节点 k 的下面
{
while(tr[x].p != k)
{
int y = tr[x].p, z = tr[y].p;
if(z != k)
{
if((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
else rotate(y);
}
rotate(x);
}
if(!k) root = x;
}
int get_k(int k) //返回数值第 k 大的节点的下标
{
int u = root;
while(u)
{
pushdown(u);
if(tr[tr[u].s[0]].cnt >= k) u = tr[u].s[0];
else if(tr[tr[u].s[0]].cnt + 1 == k) return u;
else k -= tr[tr[u].s[0]].cnt + 1, u = tr[u].s[1];
}
return -1;
}
int build(int l, int r, int p) //将 [l ~ r] 建成 Splay,接在 p 的下面,并返回该 Splay 的根节点下标
{
int mid = l + r >> 1;
int u = node[tt--]; //当前 Splay 的根节点
tr[u].init(a[mid], p); //初始化当前 Splay 的根节点
if(l < mid) tr[u].s[0] = build(l, mid - 1, u); //递归建立左子树
if(r > mid) tr[u].s[1] = build(mid + 1, r, u); //递归建立右子树
pushup(u);
return u;
}
void dfs(int u) //将当前节点为根节点的子树删掉,并将删掉的下标放回垃圾回收站
{
if(tr[u].s[0]) dfs(tr[u].s[0]); //如果当前节点有左子节点,递归删除左子树
if(tr[u].s[1]) dfs(tr[u].s[1]); //如果当前节点有右子节点,递归删除右子树
node[++tt] = u; //将当前节点放入垃圾回收站
}
int main()
{
for(int i = 1; i < N; i++) node[++tt] = i; //最开始将所有空的下标存入垃圾回收站
scanf("%d%d", &n, &m);
/*
如果某些节点没有子节点,此时子节点的下标为 0,为了避免从 0 号节点更新错误的信息,
将 0 号节点的最大子段和设为 -INF
*/
tr[0].ms = -INF;
//插入删除一段序列时需要查询前驱和后继,在开头和结尾设置两个哨兵防止边界问题
a[0] = a[n + 1] = -INF;
for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
root = build(0, n + 1, 0); //初始化 Splay
char op[20];
while(m--)
{
scanf("%s", op);
if(!strcmp(op, "INSERT")) //插入一段序列
{
int pos, tot;
scanf("%d%d", &pos, &tot);
for(int i = 0; i < tot; i++) scanf("%d", &a[i]); //接收要插入的序列
int l = get_k(pos + 1), r = get_k(pos + 2); //找到要插入序列的前驱和后继
splay(l, 0), splay(r, l); //将 l 转到根节点,将 r 转到 l 的下面
int u = build(0, tot - 1, r); //将要插入序列建成 Splay,并接在 r 下面
tr[r].s[0] = u;
pushup(r), pushup(l);
}
else if(!strcmp(op, "DELETE")) //删除一段序列
{
int pos, tot;
scanf("%d%d", &pos, &tot);
int l = get_k(pos), r = get_k(pos + 1 + tot); //找到要删除序列的前驱和后继
splay(l, 0), splay(r, l); //将 l 转到根节点,将 r 转到 l 的下面
dfs(tr[r].s[0]); //将 r 的左子树删掉
tr[r].s[0] = 0;
pushup(r), pushup(l);
}
else if(!strcmp(op, "MAKE-SAME")) //将某段序列修改成同一个数
{
int pos, tot, c;
scanf("%d%d%d", &pos, &tot, &c);
int l = get_k(pos), r = get_k(pos + tot + 1); //找到要修改序列的前驱和后缀
splay(l, 0), splay(r, l); //将 l 转到根节点,将 r 转到 l 的下面
Node &u = tr[tr[r].s[0]]; //找到要修改序列所在子树的根节点
//修改当前序列的信息
u.same = 1, u.v = c, u.sum = tot * c;
if(c > 0) u.ms = u.ls = u.rs = u.sum;
else u.ms = c, u.ls = u.rs = 0;
pushup(r), pushup(l);
}
else if(!strcmp(op, "REVERSE")) //将某段序列翻转
{
int pos, tot;
scanf("%d%d", &pos, &tot);
int l = get_k(pos), r = get_k(pos + tot + 1); //找到要翻转序列的前驱和后继
splay(l, 0), splay(r, l); //将 l 转到根节点,将 r 转到 l 的下面
Node &u = tr[tr[r].s[0]]; //找到要翻转序列所在子树的根节点
//修改当前序列的信息
u.rev ^= 1;
swap(u.ls, u.rs);
swap(u.s[0], u.s[1]);
pushup(r), pushup(l);
}
else if(!strcmp(op, "GET-SUM")) //求某段序列的和
{
int pos, tot;
scanf("%d%d", &pos, &tot);
int l = get_k(pos), r = get_k(pos + tot + 1); //找到要求和序列的前驱和后继
splay(l, 0), splay(r, l); //将 l 转到根节点,将 r 转到 l 的下面
printf("%d\n", tr[tr[r].s[0]].sum);
}
else //求整个序列的最大子段和
printf("%d\n", tr[root].ms);
}
return 0;
}