题目描述
T组测试数据(T<=10)
每组测试数据:
N,M —— 起始的字符串数量,操作数N,M<=1e5
接下来N行,每行一个字符串,表示字典中本来就有的字符串
接下来M行,每行一个数字op和一个字符串s
如果op==1,表示往字典中加入一个字符串s
如果op==2,表示查询有多少字典串在s中出现过,如果一个字典串多次出现,则计算多次.
|s|<=1e5,Σ|s|<=3e6
样例
输入
2
1 3
abc
2 abcabc
1 aba
2 abababc
2 6
abc
bcd
2 abcd
2 bcd
1 abcd
2 abcd
2 abc
2 bcd
输出
2
3
2
1
3
1
1
分析
AC自动机
一.先只考虑查询操作,给我们一个串s,如何求出共有多少字典串出现在s中了呢?
1.要求这个问题我们首先得知道Trie树中哪些节点是字典串的末尾,所有在末尾处打上st标记
2.当串匹配到AC自动机的一个节点u,就说明Trie树中u对应的前缀在串s中出现过,不仅这样,u的fail链上的节点对应的前缀也都在串s中出现过,故我们需要把u的fail链上字典串末尾的节点(即st为1)计算到答案当中.也就是说,我们需要求每个节点的fail链上有多少终止节点,不考虑修改时直接在fail树上dfs一下将结果预处理出来即可即可.
二.现在考虑带修改操作
我们知道AC自动机是一种离线型数据结构,不能够快速的支持添加字符串.
1.我们把所有操作保存下来,将所有串(包括后来添加的串)都加到Trie树中,只将原始的串打上标记,后来添加的串不打标记.
2.然后我们遍历所有操作,遇到添加字符串的操作,直接在对应的节点打上标记.
但是这样的话,询问时,不能够使用上面所说的预处理的结果,因为终止节点在变化,所以一个节点的fail链上的终止节点也是变化的,我们可以用dfs序加树状数组来求.
三.如何用dfs序加树状数组求一个节点的fail链有多少终止节点.
我们在fail树上预处理出每个节点的dfs序(zz[])和每个节点的子树中的dfs序最大值(cc[]).当我们给一个节点打上终止标记时,就是让这个节点的子树中的每一个节点的答案+1.问题就变成了区间修改,单点查询.很自然的想到树状数组.
add(zz[u],1),add(cc[u]+1,-1);
代码
AC自动机
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5+10,M = 1e6+10;//M的大小是试出来的
struct Query{
string s;
int id,type;
int ans;
}query[N];
int n,m;
int zz[M],cc[M],p;
int h[M],e[M],ne[M],idd;
int tr[M][26],st[M],idx;
int fail[M],q[M];
int T;
string ss;
int sum[M];
int insert(string s)
{
int u=0;
for(int i=0;s[i];i++)
{
int t=s[i]-'a';
if(!tr[u][t])tr[u][t]=++idx;
u=tr[u][t];
}
st[u]++;
return u;
}
void add(int a,int b)
{
e[idd]=b,ne[idd]=h[a],h[a]=idd++;
}
void build()
{
int hh=0,tt=0;
for(int i=0;i<26;i++)
if(tr[0][i])
{
q[tt++]=tr[0][i];
add(0,tr[0][i]);
}
while(hh!=tt)
{
int t=q[hh++];
for(int i=0;i<26;i++)
{
int u=tr[t][i];
if(!u)tr[t][i]=tr[fail[t]][i];
else
{
fail[u]=tr[fail[t]][i];
add(fail[u],u);
q[tt++]=u;
}
}
}
}
void getdfs(int u)
{
cc[u]=p;
zz[u]=p++;
for(int i=h[u];~i;i=ne[i])
{
getdfs(e[i]);
cc[u]=cc[e[i]];
}
}
int lowbit(int x){return x&-x;}
void addd(int x,int c)
{
for(int i=x;i<=idx;i+=lowbit(i))sum[i]+=c;
}
int queryy(int x)
{
int res=0;
for(int i=x;i>0;i-=lowbit(i))
res+=sum[i];
return res;
}
int queryy(string s)
{
int res=0;
int u=0;
for(int i=0;s[i];i++)
{
int t=s[i]-'a';
u=tr[u][t];
res+=queryy(zz[u]);
}
return res;
}
int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&m);
idx=0;
idd=0;
p=0;
memset(h,-1,sizeof h);
memset(tr,0,sizeof tr);
memset(st,0,sizeof st);
memset(fail,0,sizeof fail);
memset(zz,0,sizeof zz);
memset(cc,0,sizeof cc);
memset(sum,0,sizeof sum);
for(int i=1;i<=n;i++)
{
cin>>ss;
insert(ss);
}
for(int i=1;i<=m;i++)
{
cin>>query[i].type>>query[i].s;
if(query[i].type==1)
{
query[i].id=insert(query[i].s);st[query[i].id]--;
}
}
build();
getdfs(0);
for(int i=0;i<idx;i++)
{
if(st[q[i]])
{
addd(zz[q[i]],1);
addd(cc[q[i]]+1,-1);
}
}
for(int i=1;i<=m;i++)
{
if(query[i].type==1)
{
st[query[i].id]++;
addd(zz[query[i].id],1);
addd(cc[query[i].id]+1,-1);
}
else printf("%d\n",queryy(query[i].s));
}
}
return 0;
}