hdu_6096 String

题目:

  Bob has a dictionary with N words in it.
  Now there is a list of words, each of which has lost a contiguous run of letters from its middle part. The middle part does not include the first and last characters.
  We only know the prefix and suffix of each word; the number of missing characters is unknown and may be 0. However, the prefix and suffix of a word cannot overlap.
  For each word in the list, Bob wants to determine, from its prefix and suffix, which dictionary words it could be.
  There may be many possible answers; you only need to count how many dictionary words could be the answer.

思路:

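  统计一种字符串的出现次数,首先想到的应该是AC自动机,但是AC自动机不是用来查询前缀和后缀的,所以试着对字符串进行变换:s=s+'_'+s,比如 abc 就变成了 abc_abc;同样对于一组前缀和后缀,也做类似的操作,把它们拼成 后缀+'_'+前缀,如前缀 ab、后缀 c 就变成 c_ab。由于 '_' 在两边都恰好出现一次,匹配时它必须对齐,所以 c_ab 在 abc_abc 中出现当且仅当 abc 以 c 结尾且以 ab 开头。这样就能直接用AC自动机匹配了。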
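  题目中要求前缀和后缀不能重叠(缺失的字符数可以为 0),也就是说单词长度必须不小于前缀长度与后缀长度之和。对所有询问的“后缀_前缀”串建立一个AC自动机;用每个单词的 w_w 在自动机上匹配时,把该单词的长度加入经过的每个状态的贡献集合中。一个询问的答案就是它对应的状态在fail树子树内、长度不小于 |前缀|+|后缀| 的贡献个数:在fail树上dfs,进入和离开该状态时各用树状数组查询一次,做差即可。复杂度为 O(L log L),其中 L 为所有输入字符串的总长度(建立AC自动机另需 O(27L))。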

代码:

#include <bits/stdc++.h>
// Fill the whole object `a` with byte value `b`. The size comes from `sizeof a`, so `a` must be
// an array or object, not a pointer (a pointer would only clear sizeof(pointer) bytes).
#define mset(a,b) memset(a,b,sizeof a)
// Evaluate the arguments and discard the result (not used in the code shown here).
#define _(...) (void)(__VA_ARGS__)
// Inclusive ascending loop i = a, a+1, ..., b; the bound b is evaluated once into i##_END_.
#define rep(i,a,b) for(int i(a),i##_END_(b);i<=i##_END_;i++)
// Inclusive descending loop i = a, a-1, ..., b; the bound b is evaluated once into i##_END_.
#define drep(i,a,b) for(int i(a),i##_END_(b);i>=i##_END_;i--)
// Raise `a` to `b` when `a < b`. Returns true iff `a` was changed.
template<class T> inline bool tomax(T &a,T b){
    if(!(a<b))
        return false;
    a=b;
    return true;
}
// Lower `a` to `b` when `b < a`. Returns true iff `a` was changed.
template<class T> inline bool tomin(T &a,T b){
    if(!(b<a))
        return false;
    a=b;
    return true;
}
using std::max;
using std::min;
using ll = long long;

// Size limits for the static arrays below:
//   N   - per-count limit (number of words/queries, and the length-indexed Fenwick tree),
//   M   - character-storage limit for the dictionary,
//   ALL - capacity of the automaton node pool and its BFS queue.
constexpr int N=100005,M=500005,ALL=M*2+1;

// Start marker for the rough static-memory estimate (see the commented printf in main).
bool cur1;

// Fixed-capacity FIFO over a plain array. Elements occupy q[hed..tai]; slot 0 is never used.
// No bounds checking: at most N-1 pushes are allowed between two clear() calls.
// clear() must be called before first use.
template<int N,class T> struct Queue{
    int hed,tai;
    T q[N];
    // Reset to empty: head points at slot 1, tail just before it.
    void clear(){
        tai=0;
        hed=1;
    }
    bool empty(){
        return !(hed<=tai);
    }
    void push(T x){
        ++tai;
        q[tai]=x;
    }
    T pop(){
        T front=q[hed];
        ++hed;
        return front;
    }
};

// One state of the Aho-Corasick automaton built over the query patterns.
struct Node{
    // Transitions for 'a'..'z' (0..25) and '_' (26), as mapped by toNum. After getFail every
    // entry is non-null, because missing edges are redirected along the failure links.
    Node *ch[27];
    Node *fail;               // failure link (getFail makes the root link to itself)
    int len;                  // |prefix|+|suffix| if a query pattern ends here (set by insert), else 0
    int ans;                  // answer for that pattern, filled in by dfs
    std::vector<int> upd;     // lengths of dictionary words whose matching path visited this state
    std::vector<Node*> son;   // children of this node in the failure-link tree
};
// Static node pool; newNode never hands out node[0].
Node node[ALL];
// Root of the automaton for the current test case.
Node *rt;

// Number of pool nodes handed out so far; main resets it to 0 at the start of each test case.
int tot_node;
// Take the next node from the static pool (indices start at 1) and reset it to a blank state.
// NOTE(review): there is no check against the pool size ALL.
Node* newNode(){
    Node *fresh=&node[++tot_node];
    memset(fresh->ch,0,sizeof fresh->ch);
    fresh->fail=NULL;
    fresh->len=0;
    fresh->ans=0;
    fresh->upd.clear();
    fresh->son.clear();
    return fresh;
}

// Map an automaton character to its edge index: '_' -> 26, otherwise c - 'a'
// (expected to be a lowercase letter, giving 0..25).
int toNum(char c){
    return c=='_' ? 26 : c-'a';
}

// Add the NUL-terminated pattern `s` to the trie rooted at rt, creating nodes as needed.
// The terminal node records `len` and is returned. Identical patterns end on the same node.
Node* insert(char *s,int len){
    Node *cur=rt;
    for(char *it=s;*it;++it){
        Node *&next=cur->ch[toNum(*it)];
        if(next==NULL)
            next=newNode();
        cur=next;
    }
    cur->len=len;
    return cur;
}

Queue<ALL,Node*> que;
void getFail(){
    que.clear();
    rt->fail=rt;
    rep(i,0,26)
        if(!rt->ch[i])
            rt->ch[i]=rt;
        else {
            rt->ch[i]->fail=rt;
            rt->son.push_back(rt->ch[i]);
            que.push(rt->ch[i]);
        }
    while(!que.empty()){
        Node *o=que.pop();
        rep(i,0,26)
            if(!o->ch[i])
                o->ch[i]=o->fail->ch[i];
            else {
                o->ch[i]->fail=o->fail->ch[i];
                o->fail->ch[i]->son.push_back(o->ch[i]);
                que.push(o->ch[i]);
            }
    }
}

#define low(x) (-(x)&(x))
// Fenwick tree over positive integer values, arranged "backwards": upd walks toward index 1 and
// qry walks toward N. As a result, qry(w) returns how many recorded values v satisfy v >= w.
struct BIT{
    int c[N],a;   // c: tree cells (index 0 unused); a: accumulator used by qry
    // Record one occurrence of value w.
    // - Values <= 0 cannot be stored (index 0 is unused), so they are ignored. Previously a
    //   negative w walked toward INT_MIN and indexed c out of bounds.
    // - Values >= N are clamped to N-1 instead of writing past c. They still count for every
    //   query w <= N-1.
    void upd(int w){
        if(w<=0)
            return;
        if(w>=N)
            w=N-1;
        for(;w;w-=low(w))
            c[w]++;
    }
    // Count recorded values >= w.
    // - w <= 0 is treated as 1, since every stored value is >= 1. Previously w == 0 looped
    //   forever because low(0) == 0.
    // - Queries with w >= N return 0.
    int qry(int w){
        if(w<1)
            w=1;
        for(a=0;w<N;w+=low(w))
            a+=c[w];
        return a;
    }
    void clear(){mset(c,0);}
}sum;

// Not read or written by the code shown here; answers are stored in Node::ans.
int ans[N];
// Walk the failure-link tree rooted at `o` and fill in `ans` for every pattern node in it.
//
// How the answer is computed:
// - Each node's `upd` holds the lengths of dictionary words whose doubled form w_w reached that
//   state during matching.
// - A pattern node's fail-tree subtree contains exactly the states at which that pattern ends.
// - So its answer is the number of recorded lengths >= its own `len` within that subtree.
// - That count is (BIT query after the subtree) - (BIT query before it). Because of this
//   difference, the BIT does not need to be cleared beforehand.
//
// Why an explicit stack: the fail tree can be as deep as the longest inserted pattern
// (|suffix| + 1 + |prefix| characters), and recursion that deep could overflow the call stack.
void dfs(Node *o){
    // .second == false: entering the node; .second == true: its subtree is finished.
    std::vector<std::pair<Node*,bool> > stk;
    stk.push_back(std::make_pair(o,false));
    while(!stk.empty()){
        Node *cur=stk.back().first;
        bool leaving=stk.back().second;
        stk.pop_back();
        if(leaving){
            if(cur->len)
                cur->ans=sum.qry(cur->len)-cur->ans;
            continue;
        }
        if(cur->len)
            cur->ans=sum.qry(cur->len);   // count before this subtree's contributions
        for(int x:cur->upd)
            sum.upd(x);
        // The exit marker sits below the children, so it is popped only after every child subtree.
        stk.push_back(std::make_pair(cur,true));
        for(Node *p:cur->son)
            stk.push_back(std::make_pair(p,false));
    }
}

void find(char *s,int len){
    Node *o=rt;
    while(*s){
        int c=toNum(*s++);
        o=o->ch[c];
        o->upd.push_back(len);
    }
}

// Write s1 + '_' + s2 (NUL-terminated) into t.
// t must have room for strlen(s1) + strlen(s2) + 2 bytes. s1 and s2 may be the same string.
void change(char *s1,char *s2,char *t){
    int k=0;
    for(int i=0;s1[i];++i)
        t[k++]=s1[i];
    t[k++]='_';
    for(int j=0;s2[j];++j)
        t[k++]=s2[j];
    t[k]='\0';
}

// Input buffers.
// w:      all dictionary words stored back to back, each followed by its own '\0'. main advances
//         cpt[i+1] by strlen(cpt[i])+1, so n words need (total word length + n) bytes. Hence
//         M+N rather than M.
// s, p:   the two strings of the current query.
// t:      scratch for change(). change(x, x, t) writes 2*strlen(x)+2 bytes, so t fits dictionary
//         words shorter than N characters.
// cpt[i]: start of dictionary word i inside w (1-based; cpt[n+1] is one past the last word).
char w[M+N],s[N],p[N],t[N*2],*cpt[N];
// pnt[i]: automaton node of query i; dfs stores its answer there.
Node *pnt[N];

// End marker for the rough static-memory estimate (&cur2 - &cur1) in main's commented printf.
bool cur2;

int main(){
//    printf("%lf\n",(&cur2-&cur1)/1024.0/1024);
    int cas,n,m;
    scanf("%d",&cas);
    while(cas--){
        tot_node=0;
        rt=newNode();
        scanf("%d%d",&n,&m);
        cpt[1]=w;
        rep(i,1,n){
            scanf("%s",cpt[i]);
            cpt[i+1]=cpt[i]+strlen(cpt[i])+1;
        }
        rep(i,1,m){
            scanf("%s%s",s,p);
            change(p,s,t);
            pnt[i]=insert(t,strlen(s)+strlen(p));
        }
        getFail();
        rep(i,1,n){
            change(cpt[i],cpt[i],t);
            find(t,strlen(cpt[i]));
        }
        dfs(rt);
        rep(i,1,m)
            printf("%d\n",pnt[i]->ans);
    }
    return 0;
}