$AC$ 自动机
$AC$ 自动机就是将 $KMP$ 用在 $Trie$ 上, 要理解 $AC$ 自动机则首先要理解 $KMP$.
$KMP$ 的关键在于 $next$ 数组, 也就是当匹配失败了, 我们要从哪里开始匹配.(这里的下标都从 $1$ 开始)
- $next[i]$ 表示以 $i$ 结尾的非平凡后缀(就是说不可以是整个串)和前缀匹配的最大长度.
- 如果当前位置是 $j$ 我们从当前位置出发匹配也就是看 $p[j + 1] == s[i]$ 是否成立, 如果匹配失败就要回退 $j = next[j]$
$KMP$ 的代码
next[0] = next[1] = 0;
for (int i = 2, j = 0; i <= m; ++ i) {
while (j && p[i] != p[j + 1]) j = next[j];
if (p[i] == p[j + 1]) j ++;
next[i] = j;
}
然后再多写一句
next[0] = next[1] = 0;
for (int i = 2; i <= m; ++ i) {
int j = next[i - 1];
while (j && p[i] != p[j + 1]) j = next[j];
if (p[i] == p[j + 1]) j ++;
next[i] = j;
}
多写了这一句并没有产生什么本质的差别, 却让我们对 $KMP$ 算法的理解更进了一步.
我们每次求 $next[i]$ 时, 首先是找到 $j = next[i - 1]$ 这个位置, 然后从 $j$ 这个位置出发往前匹配,如果匹配失败 while (j && p[i] != p[j + 1]) j = next[j];
就一直回退. 如果匹配成功 if (p[i] == p[j + 1]) j ++;
则进行记录.
从上面的分析中注意到一点, 求 $next[i]$ 的过程是从 $next[i - 1]$ 开始的, 也就是从上一层出发求解到下一层的信息. 鉴于此, 我们就可以尝试在 $Trie$ 上模仿 $KMP$ 的过程建立 $next$ 数组, 这就是 $AC$ 自动机.
hh = 0, tt = -1;
for (int i = 0; i < 26; ++ i) {
if (tr[0][i]) q[++tt] = tr[0][i];
}
while (hh <= tt) {
int t = q[hh++];
for (int i = 0; i < 26; ++ i) {
int p = tr[t][i];
if (!p) continue;
int j = next[t];
while (j && !tr[j][i]) j = next[j];
if (tr[j][i]) j = tr[j][i];
next[p] = j;
q[++tt] = p;
}
}
在此框架下的代码为
C++ 代码, 时间复杂度 $O(n)$, 完全按照 $KMP$ 的框架
#include<iostream>
#include<cstring>
using namespace std;
const int N = 10010, S = 55, M = 1000010;
char str[M];
int tr[N * S][26], cnt[N * S], ne[N * S], idx;
int q[N * S], hh, tt = -1;
int n;
void insert() {
int p = 0;
for (int i = 0; str[i]; ++ i) {
int t = str[i] - 'a';
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
}
cnt[p]++;
}
void build() {
hh = 0, tt = -1;
for (int i = 0; i < 26; ++ i)
if (tr[0][i]) q[++tt] = tr[0][i];
while (hh <= tt) {
int t = q[hh++];
for (int i = 0; i < 26; ++ i) {
int p = tr[t][i];
if (!p) continue;
int j = ne[t];
while (j && !tr[j][i]) j = ne[j];
if (tr[j][i]) j = tr[j][i];
ne[p] = j;
q[++tt] = p;
}
}
}
int main(){
int T;
scanf("%d", &T);
while (T--) {
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(ne, 0, sizeof ne);
idx = 0;
scanf("%d", &n);
while (n --) {
scanf("%s", str);
insert();
}
build();
scanf("%s", str);
int res = 0;
for (int i = 0, j = 0; str[i]; ++ i) {
int t = str[i] - 'a';
while(j && !tr[j][t]) j = ne[j];
if (tr[j][t]) j = tr[j][t];
int p = j;
while (p) {
res += cnt[p];
cnt[p] = 0;
p = ne[p];
}
}
printf("%d\n", res);
}
return 0;
}
基于 $Trie$ 图的优化
利用节点不存在的位置存储其(匹配失败)后该去的位置, 就可以在匹配过程中直接把所有节点指向匹配成功或失败后该去的位置.
while (hh <= tt) {
int t = q[hh++];
for (int i = 0; i < 26; ++ i) {
int p = tr[t][i];
if (!p) tr[t][i] = tr[ne[t]][i];
else {
ne[p] = tr[ne[t]][i];
q[++tt] = p;
}
}
}
这样整个过程就只需要在 $Trie$ 图上跳来跳去即可.
C++ 代码
#include<iostream>
#include<cstring>
using namespace std;
const int N = 10010, S = 55, M = 1000010;
char str[M];
int tr[N * S][26], cnt[N * S], ne[N * S], idx;
int q[N * S], hh, tt = -1;
int n;
void insert() {
int p = 0;
for (int i = 0; str[i]; ++ i) {
int t = str[i] - 'a';
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
}
cnt[p]++;
}
void build() {
hh = 0, tt = -1;
for (int i = 0; i < 26; ++ i)
if (tr[0][i]) q[++tt] = tr[0][i];
while (hh <= tt) {
int t = q[hh++];
for (int i = 0; i < 26; ++ i) {
int p = tr[t][i];
if (!p) tr[t][i] = tr[ne[t]][i];
else {
ne[p] = tr[ne[t]][i];
q[++tt] = p;
}
}
}
}
int main(){
int T;
scanf("%d", &T);
while (T--) {
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(ne, 0, sizeof ne);
idx = 0;
scanf("%d", &n);
while (n --) {
scanf("%s", str);
insert();
}
build();
scanf("%s", str);
int res = 0;
for (int i = 0, j = 0; str[i]; ++ i) {
int t = str[i] - 'a';
j = tr[j][t];
int p = j;
while (p) {
res += cnt[p];
cnt[p] = 0;
p = ne[p];
}
}
printf("%d\n", res);
}
return 0;
}
参考文献
算法提高课