稍微有点难度……不过没有孔姥爷毒瘤(
Tag
AC自动机+线段树 优化DP
题意
给定一个单词表,每个单词有权值,取出一部分(不改变顺序)使得这部分的每一个字符串都是后一个的子串,问得到的最大权值。
思路
设 f[i] 表示选了第 i 个字符串之后得到的最大值(截止)$f[i]=max(f[j])+w[i]$, iff s[j]是s[i]的子串且j<i;
反向建 fail 树,那么对于串 s[i] 的最后一位指向的孩子,均是包含s[i]的串,所以 s[i] 最后一位的子树中孩子节点均包含 s[i]
那么对串 s[1]~s[n] 进行计算时,可以把结果用线段树更新到子树,计算时只需考虑s[i]的每一位能得到的最大值(单点查询),
最后取 max+w[i] 为 s[i] 最大值,更新即可。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=2e4+10,M=3e5+10;
struct edge
{
int nxt,to;
}e[M];
char s[M];
int w[N],pos[N],tr[M][26],fail[M],rt,tot,head[M],cnt,n;
int in[M],out[M],tp,tx[M<<2],tf[M<<4],L,R,tmp;
int newnode()
{
tot++; memset( tr[tot],0,sizeof(tr[tot]) );
fail[tot]=0; return tot;
}
void add( int u,int v )
{
e[cnt].to=v; e[cnt].nxt=head[u]; head[u]=cnt++;
}
void insert( char *s )
{
int p=rt;
for ( int i=0; s[i]; i++ )
{
int ch=s[i]-'a';
if ( !tr[p][ch] ) tr[p][ch]=newnode();
p=tr[p][ch];
}
}
void build()
{
queue<int> q; q.push(rt);
while ( !q.empty() )
{
int now=q.front(); q.pop();
if ( now!=rt ) add( fail[now],now );
for ( int i=0; i<26; i++ )
if ( tr[now][i] )
{
if ( now!=rt ) fail[tr[now][i]]=tr[fail[now]][i];
q.push( tr[now][i] );
}
else tr[now][i]=tr[fail[now]][i];
}
}
void dfs( int now )
{
in[now]=++tp;
for ( int i=head[now]; i; i=e[i].nxt )
dfs( e[i].to );
out[now]=tp;
}
void pushdown( int i )
{
if ( !tf[i] ) return;
int pre=tf[i];
tf[i<<1]=max( tf[i<<1],pre ); tf[i<<1|1]=max( tf[i<<1|1],pre );
tx[i<<1]=max( tx[i<<1],pre ); tx[i<<1|1]=max( tx[i<<1|1],pre );
tf[i]=0;
}
int query( int l,int r,int p )
{
if ( l==r ) return tx[p];
int mid=(l+r)>>1;
pushdown(p);
if ( L<=mid ) return query( l,mid,p<<1 );
else return query( mid+1,r,p<<1|1 );
}
void update( int l,int r,int p )
{
if ( L<=l && r<=R )
{
tf[p]=max( tf[p],tmp ); tx[p]=max( tx[p],tmp ); return;
}
int mid=(l+r)>>1; pushdown(p);
if ( L<=mid ) update( l,mid,p<<1 );
if ( R>mid ) update( mid+1,r,p<<1|1 );
tx[p]=max( tx[p<<1],tx[p<<1|1] );
}
void init()
{
tot=-1; cnt=1; tp=0; rt=newnode();
memset( head,0,sizeof(head) ); memset( fail,0,sizeof(fail) );
memset( tx,0,sizeof(tx) ); memset( tf,0,sizeof(tf) );
}
int main()
{
int T; scanf( "%d",&T );
for ( int cas=1; cas<=T; cas++ )
{
init(); scanf( "%d",&n );
for ( int i=1; i<=n; i++ )
{
scanf( "%s%d",s+pos[i-1],w+i );
pos[i]=pos[i-1]+strlen(s+pos[i-1]); insert( s+pos[i-1] );
}
build(); dfs( rt ); int ans=0;
for ( int i=1; i<=n; i++ )
{
tmp=0; int now=rt;
for ( int j=pos[i-1]; j<pos[i]; j++ )
{
now=tr[now][s[j]-'a']; L=R=in[now];
int res=query( 1,tp,1 ); tmp=max( tmp,res );
}
tmp+=w[i]; ans=max( ans,tmp );
L=in[now]; R=out[now]; update( 1,tp,1 );
}
printf( "Case #%d: %d\n",cas,ans );
}
}