题意
给定一个长度为 n 的字符串 s,定义以 i 开头的后缀为 Ti,求 ∑1≤i<j≤n|Ti|+|Tj|−2lcp(Ti,Tj) 的值。
1≤n≤5×105。
分析
首先把式子拆开得 ∑1≤i<j≤ni+j+2∑1≤i<j≤nlcp(Ti,Tj)。主要就是求后面的值。
我们可以先 SA 求出 height 数组。我们知道 lcp(Ti,Tj)=min。考虑对于每个 k 会用几次。这个就是找到左边最后一个小于 height_k 的位置,右边第一个大于 height_k 的位置。用单调栈即可。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define il inline
#define N 500005
#define int ll
il int rd(){
int s = 0, w = 1;
char ch = getchar();
for (;ch < '0' || ch > '9'; ch = getchar()) if (ch == '-') w = -1;
for (;ch >= '0' && ch <= '9'; ch = getchar()) s = ((s << 1) + (s << 3) + ch - '0');
return s * w;
}
int n, w, sa[N], oldsa[N], rk[N], oldrk[N], t[N];
int height[N], a[N], st[N], top, L[N], R[N];
char s[N];
ll ans = 0;
il void Sort(){
memcpy(oldrk, rk, sizeof rk);
for (int p = 0, i = 1; i <= n; i++)
if (oldrk[sa[i]] == oldrk[sa[i - 1]] && oldrk[sa[i] + w] == oldrk[sa[i - 1] + w]) rk[sa[i]] = p;
else rk[sa[i]] = ++p;
}
il void SA(){
for (int i = 1; i <= n; i++) t[rk[i] = s[i]]++;
for (int i = 1; i <= 127ll; i++) t[i] += t[i - 1];
for (int i = n; i >= 1; i--) sa[t[rk[i]]--] = i;
Sort();
for (w = 1; w < n; w <<= 1){
memcpy(oldsa, sa, sizeof sa);
memset (t, 0, sizeof t);
for (int i = 1; i <= n; i++) t[rk[oldsa[i] + w]]++;
for (int i = 1; i <= max(127ll, n); i++) t[i] += t[i - 1];
for (int i = n; i >= 1; i--) sa[t[rk[oldsa[i] + w]]--] = oldsa[i];
memcpy(oldsa, sa, sizeof sa);
memset (t, 0, sizeof t);
for (int i = 1; i <= n; i++) t[rk[oldsa[i]]]++;
for (int i = 1; i <= n; i++) t[i] += t[i - 1];
for (int i = n; i >= 1; i--) sa[t[rk[oldsa[i]]]--] = oldsa[i];
Sort();
}
}
signed main(){
scanf ("%s", s + 1), n = strlen(s + 1);
SA(), a[0] = a[n] = -1;
for (int i = 1, k = 0; i <= n; i++){
if (!rk[i]) continue;
k = max(k - 1, 0ll);
while (s[i + k] == s[sa[rk[i] - 1] + k]) k++;
height[rk[i]] = k;
}
for (int i = 1; i < n; i++) a[i] = height[i + 1];
st[++top] = 0;
for (int i = 1; i < n; i++){
while (top && a[st[top]] > a[i]) top--;
L[i] = st[top] + 1, st[++top] = i;
}
top = 0, st[++top] = n;
for (int i = n - 1; i >= 1; i--){
while (top && a[st[top]] >= a[i]) top--;
R[i] = st[top] - 1, st[++top] = i;
}
for (int i = 1; i < n; i++) ans += 1ll * (R[i] - i + 1) * (i - L[i] + 1) * a[i];
printf ("%lld\n", 1ll * n * (n - 1) / 2 * (n + 1) - 2 * ans);
return 0;
}