0、前言
今天专业课的数据结构复习到$KMP$,想想以前学的虽然知其然,但是不知其所以然,于是决定再学一遍,顺便详细记录一下(怕忘了),第二次学的时候总是有新收获(其实当初不知道学了多少遍才学会)
1、为什么朴素的匹配方式低效
bool check(string p, string s) //检查s中内否找到p
int n = s.size(), m = p.size();
for (int i = 1; i + m - 1 <= n; i++) {
int j = 1;
while (j <= m && p[j] == s[i + j - 1]) j++;
if (j == m + 1) return true;
}
return false;
其实传统的匹配算法之所以低效是因为会出现太多无意义的比较,即一些我们可以预知的必然无法匹配的情况。
例如$P=abab, S=abaaabab$
若进行一般的匹配前三次匹配应该是这样:
- 第一次匹配:
$P=abab$
$S=abaaabab$
$P[0]=S[0],P[1]=S[1],P[2]=S[2],P[3]!=S[3]$,失败 - 第二次匹配:
$\ \ P=abab$
$S=abaaabab$
$P[0]!=S[1]$,失败 - 第三次匹配:
$\ \ \ P=abab$
$S=abaaabab$
$P[0]=S[2],P[1]!=S[3]$,失败
其实,通过第一次匹配我们可以发现$P[0]!=P[1]=S[1]$,即$P[0]!=S[1]$,因此就算不进行第二次匹配我们也能预知第二次匹配会失败。
同样,通过第一次匹配我们可以发现$P[1]=P[3]!=S[3]$,即$P[1]!=S[3]$,即第三次匹配必然会失败。
2、如何进行高效的匹配
不难发现在某次匹配失败之后,其实我们已经知道了$S[i-j,i-1]$与$P[1,j]$是相同的
这时问题就在于在匹配失败后,如何不去进行那些必然会失败的操作,也就是P串至少要向后移动多少位才有可能成功匹配。
假设移动若干位后可以继续匹配,我们发现红色这段是相同,记红色长度为$k$,则$P[1,k]=P[j-k-1,j]$,
也就是对于$P[1,j]$这段来说相同的前后缀子串.
因此如果想知道$P$最少需要后移多少位,只需要最大化$k$,也就是要求出$P[1,j]$这段的最长相同前后缀长度。
到现在我们其实发现了,$S$串是什么其实并不重要,我们真正需要知道的是$P$在第$j+1$位匹配失败后,需要后移多少位,而这可以通过预处理$P$串来实现。
总而言之,不论$S$串是什么,只要$P$串给定,则在某次匹配失败后,$P$串后移的位数就已经确定。
(由于在实际操作中,串无法真正“后移”,移动的应该是指针$j$,因此串的后移就等价于指针$j$的前”跳”,也就是在某次匹配失败后指针$j$需要找到它的新的定位$ne[j]$)
进而也有了匹配的代码
//ne[j]就代表p[1,j]这段的最长相同前后缀长度,由于字符串下标从1开始,因此ne[j]也就代表了匹配失败后j的归宿
for(int i = 1 , j = 0; i <= m ; i++)
{
while(j && p[j + 1] != s[i]) j = ne[j]; //如果p[j+1]与s[i]不匹配,j要不断向前跳
if(p[j + 1] == s[i]) j++; //直到两者匹配,j继续向后移动
if(j == n){
cout << i - j << ' ';
j = ne[j];
}
}
3、如何预处理出$ne[]$
假设已经处理了前$i-1$位
- 若$p[j+1]=p[i]$
毫无疑问此时$ne[i]=++j$ - 若$p[j+1]!=p[i]$
此时$j$需要找它的归宿$ne[j]$,进而继续比较,若不存在相同前后缀子串那就只能从头开始
进而得出预处理代码
//p[1,1]只有一个字符因此ne[1]=0
for(int i = 2 , j = 0 ; i <= n ; i++)
{
while(j && p[j + 1] != p[i]) j = ne[j];
//如果p[j+1]和p[i]不匹配,j要不断向前跳。最差情况是p[1,i]不存在相同的前后缀,ne[i]=0
if(p[j + 1] == p[i]) j++;
ne[i] = j;
}
4、完整代码
#include <iostream>
using namespace std;
const int N = 1000010;
char s[N] , p[N];
int ne[N];
int main()
{
int n , m;
cin >> n >> p + 1 >> m >> s + 1;
for(int i = 2 , j = 0 ; i <= n ; i++)
{
while(j && p[j + 1] != p[i]) j = ne[j];
if(p[j + 1] == p[i]) j++;
ne[i] = j;
}
for(int i = 1 , j = 0; i <= m ; i++)
{
while(j && p[j + 1] != s[i]) j = ne[j];
if(p[j + 1] == s[i]) j++;
if(j == n){
cout << i - j << ' ';
j = ne[j];
}
}
return 0;
}