思路:
利用$PAM$能在$O(n)$的时间复杂度之内求出字符串$S$的所有不同回文子串的数量.
因此只需要将所有字符串依次插入回文树,对于每次插入的过程中,每个回文树节点只计算一次.
注意在插入不同字符串的时候需要将$Fail$指针指向0.同时清空辅助数组$S$.
用来避免后一个字符串和前一个字符串形成非法回文的清空.
AC Code:
#include<bits/stdc++.h>
#pragma optimize(2)
#define endl '\n'
#define ll() to_ullong()
#define string() to_string()
#define Endl endl
using namespace std;
typedef long long ll;
typedef pair<int,int>PII;
typedef unsigned long long ull;
const int M=2010;
const int P=13331;
const ll llinf=0x3f3f3f3f3f3f3f3f;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
const int N=2e6+10;
int dx[4]={0,1,0,-1};
int dy[4]={-1,0,1,0};
int k,last;
unordered_map<int,bool>mp;
ll ans;
struct PAM{
int n,last,tot;
int len[N],trie[N][26],fail[N],cnt[N],S[N],num[N];
//len[i]: 节点i所代表的回文串长度, fail[i]: 当前回文串的最长回文后缀(不包括自身)
//cnt[i]: 节点i所代表的回文串的个数, S[i]: 第i次添加的字符, num[i]: 以第i个字符为结尾的回文串个数
//last: 上一个字符构成最长回文串的位置,方便下一个字符的插入
//tot: 总结点个数 = 本质不同的回文串的个数+2, n: 插入字符的个数
int newnode(int l){
for(int i=0;i<26;i++)trie[tot][i] = 0;
cnt[tot] = 0, len[tot] = l, num[tot] = 0;
return tot++;
}
inline void init(){
tot = n = last = 0, newnode(0), newnode(-1);
S[0] = -1, fail[0] = 1;
}
int get_fail(int x){ //获取fail指针
while(S[n-len[x]-1] != S[n]) x = fail[x];
return x;
}
inline int insert(int c){ //插入字符
S[++n] = c;
int cur = get_fail(last);
//在节点cur前的字符与当前字符相同,即构成一个回文串
if(!trie[cur][c]){ //这个回文串没有出现过
int now = newnode(len[cur]+2);
fail[now] = trie[get_fail(fail[cur])][c];
trie[cur][c] = now;
num[now] = num[fail[now]]+1; //更新以当前字符为结尾的回文串的个数
}
last = trie[cur][c];
if(!mp[last])
{
mp[last]=true;
cnt[last]++; //更新当前回文串的个数
if(cnt[last]==k)ans++;
}
return num[last]; //返回以当前字符结尾的回文串的个数
}
void count(){ //统计每个本质不同回文串的个数
for(int i=tot-1;i>=0;i--)cnt[fail[i]]+=cnt[i];
}
}pam;
void solve()
{
cin>>k;
pam.init();
for(int i=1;i<=k;i++)
{
string s;cin>>s;
pam.n=pam.last=0;//清空避免和前面发生冲突
mp.clear();
for(auto c:s)pam.insert(c-'a');
}
cout<<ans<<endl;
return ;
}
int main()
{
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
// freopen("test.in","r",stdin);
solve();
return 0;
}
模板:
struct PAM{
int n,last,tot;
int len[N],trie[N][26],fail[N],cnt[N],S[N],num[N];
//len[i]: 节点i所代表的回文串长度, fail[i]: 当前回文串的最长回文后缀(不包括自身)
//cnt[i]: 节点i所代表的回文串的个数, S[i]: 第i次添加的字符, num[i]: 以第i个字符为结尾的回文串个数
//last: 上一个字符构成最长回文串的位置,方便下一个字符的插入
//tot: 总结点个数 = 本质不同的回文串的个数+2, n: 插入字符的个数
int newnode(int l){
for(int i=0;i<26;i++)trie[tot][i] = 0;
cnt[tot] = 0, len[tot] = l, num[tot] = 0;
return tot++;
}
inline void init(){
tot = n = last = 0, newnode(0), newnode(-1);
S[0] = -1, fail[0] = 1;
}
int get_fail(int x){ //获取fail指针
while(S[n-len[x]-1] != S[n]) x = fail[x];
return x;
}
inline int insert(int c){ //插入字符
S[++n] = c;
int cur = get_fail(last);
//在节点cur前的字符与当前字符相同,即构成一个回文串
if(!trie[cur][c]){ //这个回文串没有出现过
int now = newnode(len[cur]+2);
fail[now] = trie[get_fail(fail[cur])][c];
trie[cur][c] = now;
num[now] = num[fail[now]]+1; //更新以当前字符为结尾的回文串的个数
}
last = trie[cur][c];
if(!mp[last])
{
mp[last]=true;
cnt[last]++; //更新当前回文串的个数
if(cnt[last]==k)ans++;
}
return num[last]; //返回以当前字符结尾的回文串的个数
}
void count(){ //统计每个本质不同回文串的个数
for(int i=tot-1;i>=0;i--)cnt[fail[i]]+=cnt[i];
}
}pam;