生成魔咒问题
解题思路
本题每次在字符串末尾加上一个字符,然后统计一下字符串中有多少不同的子串。
假设不看修改操作,对于一个字符串,如果去统计不同子串的数量,首先枚举不同的起点,然后再枚举不同的终点。这里可以发现对于每个起点,枚举不同的终点,其实就是枚举这个后缀的所有的前缀。因此所有后缀的所有前缀的集合,就是所有子串的集合。总结到后缀的所有前缀,就可以用后缀数组了。
首先后缀数组可以帮我们将所有后缀排序,同时我们还能知道所有排名相邻的两个后缀的最大前缀是多少。
设第 $i$ 个后缀的长度为 $len_i$
我们从第一个后缀开始计数,统计第一个后缀有多少个前面没有出现过的不同前缀,由于这是枚举的第一个后缀,显然是有 $len_1$ 个没出现过的前缀。然后看第二个后缀,这就取决于 $height_2$,在 $height_2$ 以内的前缀都一定在前面出现过,而在 $height_2$ 之外的所有前缀一定没有出现过。而这个特性是一个一般性的特性。对于第 $i$ 个后缀,他在 $height_i$ 以外的后缀都一定都没有在前面出现过,假如 $height_i$ 以外的某个前缀在前面出现过,则通过夹逼准则,很容易就能证明出 $height_i$ 一定 $\geq$ 这个前缀,这就矛盾了,因此一定没有出现过。
因此对于第 $i$ 个后缀,新出现的子串数量就是没有出现过的前缀数量,也就是 $len_i-height_i$。枚举所有的后缀累加所有新出现子串的数量就能得到不同子串的数量。
如果再加上修改操作,如果我们在后面加一个字符,由于这个字符会出现在前面每一个字符的最后,因此他会影响所有前面的后缀,这样考虑非常的不方便,因此我们可以考虑将整个字符串反转过来,因为原串的不同子串的数量是等于反转之后的,所有答案不会发生改变,此时就相当于是每次往前面加一个字符,那么相当于每次只加了一个新的后缀,对于后面已有的后缀没有任何影响,有了这个思路之后,我们再改变思路。
我们将字符串反转之后,再倒着去做所有问题,我们每次从前面将字符一个一个删去,而每次删去前面一个字符,相当于是删去了一个后缀,此时如果我们能动态的维护 $height$ 数组,那么每次的 $\sum_{i=1}^n len_i-height_i$ 就是答案。
因此接下来的问题就是每次删掉一个后缀然后动态维护 $height$ 数组。最开始我们有 $n$ 个后缀,有一个长度为 $n$ 的 $height$ 数组,此时我们删去一个后缀其实就是将后缀数组中的某一个后缀删掉,此时整个后缀数组的顺序是没有变化的,因此我们只需要动态的去维护删除一个后缀即可,要想快速的删除一个元素,可以用双链表来维护,因此 $sa$ 数组是用双链表来维护的,然后我们还要考虑如何维护 $height$ 数组,假设现在要删除排名是 $i$ 的后缀,那么影响的其实只有 $height_i$ 和 $height_{i+1}$,原本维护的是 $i$ 和 $i-1$ 的最长公共前缀以及 $i+1$ 和 $i$ 的最长公共前缀,删掉之后我们需要维护 $i+1$ 和 $i-1$ 的最长公共前缀,而最长公共前缀的一个性质就是 $lcp(i-1,i+1) = min \big( lcp(i-1,i), lcp(i,i+1) \big)$,因此我们只需要在删之前用原有的信息更新一下新的 $height$ 即可。就可以用 $O(1)$ 的时间动态维护。然后每次 $\sum_{i=1}^n len_i-height_i$ 的值我们也可以动态用一个变量来维护,每次删后缀的同时更新信息即可。这样就可以做了。
可以发现本题通过两次翻转。将一个复杂的问题变成了一个比较好维护的问题,非常巧妙。
另外,本题的数字范围非常大,因此需要离散化。
C++ 代码
#include <iostream>
#include <cstring>
#include <unordered_map>
using namespace std;
typedef long long LL;
const int N = 100010;
int n, m;
int s[N];
int sa[N], x[N], y[N], c[N], rk[N], height[N]; //后缀数组
int u[N], d[N]; //上、下方向的双链表
LL res[N]; //记录答案
int get(int x) //返回离散化后的值
{
static unordered_map<int, int> hash;
if(!hash.count(x)) hash[x] = ++m;
return hash[x];
}
void get_sa() //预处理 sa
{
for(int i = 1; i <= n; i++) c[x[i] = s[i]]++;
for(int i = 2; i <= m; i++) c[i] += c[i - 1];
for(int i = n; i >= 1; i--) sa[c[x[i]]--] = i;
for(int k = 1; k <= n; k <<= 1)
{
int num = 0;
for(int i = n - k + 1; i <= n; i++) y[++num] = i;
for(int i = 1; i <= n; i++)
if(sa[i] > k)
y[++num] = sa[i] - k;
for(int i = 1; i <= m; i++) c[i] = 0;
for(int i = 1; i <= n; i++) c[x[i]]++;
for(int i = 2; i <= m; i++) c[i] += c[i - 1];
for(int i = n; i >= 1; i--) sa[c[x[y[i]]]--] = y[i], y[i] = 0;
swap(x, y);
x[sa[1]] = 1, num = 1;
for(int i = 2; i <= n; i++)
x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++num;
if(num == n) return;
m = num;
}
}
void get_height() //预处理 height
{
for(int i = 1; i <= n; i++) rk[sa[i]] = i;
for(int i = 1, k = 0; i <= n; i++)
{
if(rk[i] == 1) continue;
if(k) k--;
int j = sa[rk[i] - 1];
while(i + k <= n && j + k <= n && s[i + k] == s[j + k]) k++;
height[rk[i]] = k;
}
}
int main()
{
scanf("%d", &n);
for(int i = n; i >= 1; i--) scanf("%d", &s[i]), s[i] = get(s[i]); //要反转序列,倒着读入
get_sa(); //预处理 sa
get_height(); //预处理 height
LL ans = 0; //记录答案
for(int i = 1; i <= n; i++)
{
ans += n - sa[i] + 1 - height[i]; //累积最开始的答案
u[i] = i - 1, d[i] = i + 1; //初始化双链表
}
d[0] = 1, u[n + 1] = n; //设定边界
for(int i = 1; i <= n; i++) //从前往后删字符
{
res[i] = ans; //记录当前的答案
int k = rk[i], j = d[k]; //k 表示第 i 个后缀的排名,j 表示第 i 个后缀的后一个后缀的排名
ans -= n - sa[k] + 1 - height[k]; //减去排名第 k 的后缀的贡献
ans -= n - sa[j] + 1 - height[j]; //减去排名第 j 的后缀的贡献
height[j] = min(height[j], height[k]); //更新删除排名第 k 的后缀后 height[j] 的值
ans += n - sa[j] + 1 - height[j]; //加上排名第 j 的后缀的新贡献
d[u[k]] = d[k], u[d[k]] = u[k]; //将排名第 k 的后缀从双联比奥中删去
}
for(int i = n; i >= 1; i--) printf("%lld\n", res[i]); //从后往前输出答案
return 0;
}