事实上这题和 差异 几乎是一样的。
他让我们求两个字符串各取子串相同的方案数,这很麻烦,我们现在只知道一个子串中的方案数。
但是可以把两个字符串合并,为了防止算重中间加一个特殊符号例如 #
,再用容斥减去两个字符串内部的情况即可。
#include <bits/stdc++.h>
using namespace std;
const int N = 4e5 + 5;
int n, m;
char a[N], b[N];
struct SA {
int n;
char s[N];
int sa[N], rk[N], x[N], y[N];
int cnt[N];
void add(char ch) { s[++n] = ch; }
void init_SA() {
int v = 128;
for (int i = 0; i <= v; i++) cnt[i] = 0;
for (int i = 1; i <= n; i++) ++cnt[x[i] = s[i]];
for (int i = 1; i <= v; i++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i--) sa[ cnt[x[i]]-- ] = i;
for (int len = 1; ; len <<= 1) {
int tot = 0;
for (int i = n - len + 1; i <= n; i++) y[++tot] = i;
for (int i = 1; i <= n; i++)
if (sa[i] > len) y[++tot] = sa[i] - len;
for (int i = 0; i <= v; i++) cnt[i] = 0;
for (int i = 1; i <= n; i++) ++cnt[x[i]];
for (int i = 1; i <= v; i++) cnt[i] += cnt[i - 1];
for (int i = n; i >= 1; i--) sa[ cnt[x[y[i]]]-- ] = y[i], y[i] = 0;
swap(x, y), tot = 0;
for (int i = 1; i <= n; i++) {
if (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + len] == y[sa[i - 1] + len]) x[sa[i]] = tot;
else x[sa[i]] = ++tot;
}
v = tot;
if (v == n) break;
}
}
int ht[N];
void init_height() {
for (int i = 1; i <= n; i++) rk[sa[i]] = i;
for (int i = 1, j = 0, now = 0; i <= n; i++) {
if (rk[i] == 1) { ht[rk[i]] = 0; continue; }
j = sa[rk[i] - 1];
if (now) now--;
while (i + now <= n && j + now <= n && s[i + now] == s[j + now]) now++;
ht[rk[i]] = now;
}
}
long long sum;
int stk[N], top = 0;
int l[N], r[N];
long long solve() {
init_SA(), init_height();
sum = 0;
for (int i = 2; i <= n; i++) {
while (top && ht[i] <= ht[stk[top]]) r[stk[top]] = i, top--; //注意取等和不取等条件
stk[++top] = i;
}
while (top) r[stk[top--]] = n + 1;
for (int i = n; i >= 2; i--) {
while (top && ht[i] < ht[stk[top]]) l[stk[top]] = i, top--;
stk[++top] = i;
}
while (top) l[stk[top--]] = 1;
for (int i = 2; i <= n; i++) sum += ht[i] * 1ll * (r[i] - i) * 1ll * (i - l[i]);
return sum;
}
} t, ta, tb;
int main() {
scanf("%s\n %s", a + 1, b + 1), n = strlen(a + 1), m = strlen(b + 1);
for (int i = 1; i <= n; i++) t.add(a[i]), ta.add(a[i]);
t.add('#');
for (int i = 1; i <= m; i++) t.add(b[i]), tb.add(b[i]);
printf("%lld\n", t.solve() - ta.solve() - tb.solve() );
return 0;
}