Bootstrap

[BZOJ3998]TJOI2015弦论|后缀自动机

  对SAM不太熟做这题想了很久才想清楚。。大爷们的博客都写的好简(我太弱
  首先对原串建SAM。。如果能求出f[i]表示以root走到状态i的路径为开头往后能得到的串的数量,我们就可以像线段树那样的查询了(26分?)。。设num[i]为已确定的一条root到i的路径对应的子串数量,当T=0的时候,显然num[i]=1,num[root]=0;对于T=1,root到i的路径对应的串还可以作为另一个子串的后缀出现,那么这里要再次计数,还有一个性质,就是root到i的任何路径对应的串都是最长路径的一个后缀,且仅作为最长路径的后缀出现(回忆构建过程,如果不是的话会被连出虚拟节点去),而parent指针都是连在最长路径上的,因此如果最长路径串在原串中出现了k次那么所有root到i的串在原串中都出现了k次!那么num[i]显然就等于i在parent树中的子树大小了。。那么初始时所有状态num为1,在parent树上跑就可以预处理出num数组了。这里还有一点,虚拟节点的问题,由于虚拟节点nq的原节点q的parent指针指向了nq,造成了root到nq的串出现了2次这种假象,所以每到一个虚拟节点都要把num-1,那么不妨初始时就把虚拟节点的nq置为0。。到这里这道题就差不多清晰了,SAM是一个DAG,在上面做DP就可以把f算出来,f[i]=sigma(f[son[i]])+num[i],求出f就dfs下去26分这样就就好了,注意dfs到一个点i的时候,如果num[i]>=k就不要dfs下去了,因为root到i的串已经覆盖到了k,否则要把num[i]减掉再dfs下去。。

#include<cstdio>
#include<iostream>
#include<memory.h>
#define N 1000005
using namespace std;
char s[N];
int n,nd=1,k,t,now,i,ne=0,a[N][26],sz[N],fa[N],f[N],u[N],l[N],h[N],que[N];
struct edge{
    int e,next;
}ed[N];
void add(int s,int e)
{
    ed[++ne].e=e;ed[ne].next=h[s];h[s]=ne;
}
void extend(int c)
{
    int p=now,np=now=++nd;
    l[np]=l[p]+1;sz[np]=1;
    while (p&&!a[p][c]) a[p][c]=np,p=fa[p];
    if (!p) fa[np]=1;
    else
    {
        int q=a[p][c];
        if (l[p]+1==l[q]) fa[np]=q;
        else
        {
            int nq=++nd;
            l[nq]=l[p]+1;
            sz[nq]=t^1;fa[nq]=fa[q];
            fa[q]=fa[np]=nq;
            for (int i=0;i<26;i++) a[nq][i]=a[q][i];
            while (p&&a[p][c]==q) a[p][c]=nq,p=fa[p];
        }
    }
}
void build()
{
    now=1;
    for (i=1;i<=n;i++) extend(s[i]-'a');
}
void dp(int x)
{
    u[x]=1;f[x]=sz[x];
    for (int i=0;i<26;i++)
    if (a[x][i])
    {
        if (!u[a[x][i]]) dp(a[x][i]);
        f[x]+=f[a[x][i]];
    }
}
void dfs(int x,int k)
{
    if (k<=sz[x]) return;
    k-=sz[x];
    for (int i=0;i<26;i++)
    if (a[x][i])
        if (f[a[x][i]]>=k) 
        {
            printf("%c",i+'a');
            dfs(a[x][i],k);
            return;
        }
        else k-=f[a[x][i]];
}
void bfs()
{
    int head=1,tail=1,get;
    que[1]=1;
    while (head<=tail)
    {
        get=que[head++];
        for (int i=h[get];i;i=ed[i].next)
            que[++tail]=ed[i].e;
    }
    for (int i=tail;i>1;i--) sz[fa[que[i]]]+=sz[que[i]];
}
int main()
{
    scanf("%s",s+1);
    n=strlen(s+1);
    scanf("%d%d",&t,&k);
    fa[1]=0;sz[1]=0;
    memset(a,0,sizeof(a));
    build();
    memset(h,0,sizeof(h));
    if (t)
    {
        for (i=2;i<=nd;i++) add(fa[i],i);
        bfs();
    }
    memset(u,0,sizeof(u));
    dp(1);f[1]-=sz[1];sz[1]=0;
    if (f[1]<k) printf("-1");else dfs(1,k);
}
;