题目描述
给n个字典串s,再给我们一个串S,问我们用多少个字典串可以拼接出S.
1≤L,|S|,∑|s|≤3e5
样例
输入
3
aaaaa
a
aa
aaa
输出
2
分析
问我们拼接处S串的最小字典串数,考虑dp[i]表示拼接处S[1,i]的最小字典串数.
当我们要求dp[i]的时候,我们需要枚举i这个位置用哪个字典串,贪心来看,我们用长度最长的可以匹配的串去计算答案.
假设我们用sj作为S[1,i]的末尾,那该状态可以由dp[i-|sj|],dp[i-|sj|+1],…,dp[i-1]转移过来,在这些数里取一个最小值再+1就是dp[i]的答案.
那就会出现几个子问题:
1.如何求S[1,i]用哪个字典串作为末尾?
当我们去预处理S每个位置用哪个字典串时,假设我们匹配到了AC自动机的u节点,离u最近的终止标记就是S的这个位置应该用的字典串.
如果我们每匹配到一个节点就在fail链上往根跳,这样的时间复杂度会达到O(n2)不可取.
优化:我们在求AC自动机时,当我们位于Trie的第u个节点时,标记一下u的fail链上离u最近的终止标记的位置.
if(st[u])c[u]=u;//如果u自己就是终止节点,那c[u]就是他自己
else c[u]=c[fail[u]];//否则c[u]=c[fail[u]]
2.如何求dp[i-|sj|],dp[i-|sj|+1],…,dp[i-1]的最小值?
如果直接遍历,复杂度也有可能达到O(n2),不可取. //亲身经历直接遍历的话吸氧也能过
这是一个区间查询问题,可以用线段树来维护.
DP+AC自动机
C++ 代码
#include<bits/stdc++.h>
using namespace std;
const int N = 3e5+10;
struct Seg{
int l,r;
int minn;
}seg[N<<2];
int tr[N][26],st[N],idx;
int depth[N];
int fail[N],q[N];
int n;
char S[N],s[N];
int z[N],c[N];
int dp[N];
int len;
void pushup(int u)
{
seg[u].minn=min(seg[u<<1].minn,seg[u<<1|1].minn);
}
void build(int u,int l,int r)
{
seg[u]={l,r};
if(l==r)
{
seg[u].minn=0;
return;
}
int mid=l+r>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
pushup(u);
}
void modify(int u,int pos,int x)
{
if(seg[u].l==seg[u].r)
{
seg[u].minn=x;
return;
}
int mid=seg[u].l+seg[u].r>>1;
if(pos<=mid)modify(u<<1,pos,x);
else modify(u<<1|1,pos,x);
pushup(u);
}
int query(int u,int l,int r)
{
if(seg[u].l>=l&&seg[u].r<=r)return seg[u].minn;
int mid=seg[u].l+seg[u].r>>1;
int res=0x3f3f3f3f;
if(l<=mid)res=query(u<<1,l,r);
if(r>mid)res=min(res,query(u<<1|1,l,r));
return res;
}
void insert()
{
int u=0;
for(int i=0;s[i];i++)
{
int t=s[i]-'a';
if(!tr[u][t])tr[u][t]=++idx;
u=tr[u][t];
}
st[u]=1;
}
void build()
{
int hh=0,tt=0;
for(int i=0;i<26;i++)
if(tr[0][i])
{
depth[tr[0][i]]=1;
if(st[tr[0][i]])c[tr[0][i]]=tr[0][i];
q[tt++]=tr[0][i];
}
while(hh!=tt)
{
int t=q[hh++];
for(int i=0;i<26;i++)
{
int u=tr[t][i];
if(!u)tr[t][i]=tr[fail[t]][i];
else
{
depth[u]=depth[t]+1;
fail[u]=tr[fail[t]][i];
q[tt++]=u;
if(st[u])c[u]=u;
else c[u]=c[fail[u]];
}
}
}
}
void init()
{
int u=0;
for(int i=1;S[i];i++)
{
int t=S[i]-'a';
u=tr[u][t];
z[i]=c[u];
}
}
int main()
{
scanf("%d%s",&n,S+1);
len=strlen(S+1);
for(int i=1;i<=n;i++)
{
scanf("%s",s);
insert();
}
build();
build(1,0,len);
init();
memset(dp,0x3f,sizeof dp);
dp[0]=0;
for(int i=1;S[i];i++)
{
int l=i-depth[z[i]];
dp[i]=min(query(1,l,i-1)+1,0x3f3f3f3f);
modify(1,i,dp[i]);
}
if(dp[len]!=0x3f3f3f3f)printf("%d",dp[len]);
else printf("-1");
return 0;
}