题目描述
难度分:2900
输入n(1≤n≤2×106)和长度为n的字符串s,只包含小写英文字母。
从s中选两个重叠(有公共部分)的非空回文子串,有多少种选法?答案模51123987。
输入样例1
4
babb
输出样例1
6
输入样例2
2
aa
输出样例2
2
算法
Manacher算法
难得又碰到一道会做的周五茶,达成了一周5道茶AK的成就。
首先可以想到,要找两个重叠的回文串不好找,但是要找两个不重叠的回文串非常好找,所以我们求对立事件的方案数,用总方案数减去对立事件的方案数就能得到答案。
假设我们可以得到一个数组pre,pre[i]表示以s[i]为右端点的回文子串数量,这样整个串的回文子串数量即为tot=Σn−1i=0pre[i]。还可以得到一个数组suf,suf[i]表示以s[i]为左端点的回文子串数量,对suf求一个后缀和,suf[i]就表示s的后缀[i,n)上的回文子串数量。
然后就可以枚举左边那个回文串的右端点i∈[0,n−2],但每个选择s[i]结尾的回文串时,另一个串可以在后缀[i+1,n−1]上选,方案数为pre[i]×suf[i+1]。所以最终答案就是tot−Σn−2i=0pre[i]×suf[i+1]。此时关键就在于怎么求pre,求suf只需要把s反转一下跑一遍相同的流程再把suf反转回来即可。
通过Manacher算法,可以知道在某个回文中心的最长回文子串的下标区间[l,r]。那么[l+1,r−1],[l+2,r−2],… 这些也是回文子串。然后统计以s[i]结尾的回文子串个数。对于上述[l,r]及其内部的回文子串,右端点最小是⌈l+r2⌉,最大是r,所以右端点在 [⌈l+r2⌉,r]中的回文子串个数都要加一。这可以用差分数组维护,最后对这个差分数组求一下前缀和就能得到pre。
复杂度分析
时间复杂度
运行两次Manacher算法获得pre和suf两个数组,时间复杂度为O(n)。接下来O(n)把所有pre[i]累加起来得到整个s串中回文子串的数量,并对suf求后缀和,时间复杂度为O(n)。最后再遍历s串,枚举其中一个串的右端点求“选两个互不相交的回文串方案数”时间复杂度为O(n)。因此,整个算法的时间复杂度为O(n)。
空间复杂度
Manacher算法需要构建一个在s的每个相邻字符之间插入特殊字符的新串,空间复杂度为O(n)。pre和suf两个数字的空间复杂度也为O(n)。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int MOD = 51123987;
vector<LL> manacher(const std::string& s) {
int n = s.length();
string T = "#";
for (char c : s) {
T += c;
T += '#';
}
int m = T.length();
vector<int> p(m, 0);
int C = 0, R = 0;
for(int i = 0; i < m; ++i) {
int mirror = 2 * C - i;
if(i < R) {
p[i] = min(R - i, p[mirror]);
}
int a = i + 1 + p[i];
int b = i - 1 - p[i];
while(a < m && b >= 0 && T[a] == T[b]) {
p[i]++, a++, b--;
}
if(i + p[i] > R) {
C = i;
R = i + p[i];
}
}
vector<int> diff(n, 0);
for(int i = 0; i < m; i++) {
int r = p[i];
if (r == 0) continue;
int j_start = i;
int j_end = i + r - 1;
int first_j = (j_start % 2 == 1) ? j_start : (j_start + 1 <= j_end) ? j_start + 1 : j_end + 1;
int last_j = (j_end % 2 == 1) ? j_end : (j_end - 1 >= j_start) ? j_end - 1 : j_start - 1;
if (first_j > j_end || last_j < j_start) continue;
int start_k = (first_j - 1) / 2;
int end_k = (last_j - 1) / 2;
start_k = max(start_k, 0);
end_k = min(end_k, n - 1);
if(start_k > end_k) continue;
diff[start_k]++;
if(end_k + 1 < n) {
diff[end_k + 1]--;
}
}
vector<LL> dp(n, 0);
dp[0] = diff[0];
for(int i = 1; i < n; ++i) {
dp[i] = dp[i - 1] + diff[i];
}
return dp;
}
int main() {
int n;
string s;
cin >> n >> s;
vector<LL> pre = manacher(s);
reverse(s.begin(), s.end());
vector<LL> suf = manacher(s);
__int128 tot = pre[0];
for(int i = 1; i < n; i++) {
tot += pre[i];
}
reverse(suf.begin(), suf.end());
for(int i = n - 2; i >= 0; i--) {
suf[i] += suf[i + 1];
suf[i] %= MOD;
}
__int128 ans = tot * (tot - 1) / 2 % MOD;
for(int i = 0; i < n - 1; i++) {
ans = (ans - (pre[i] * suf[i + 1] % MOD) + MOD) % MOD;
}
cout << (int)ans << endl;
return 0;
}