APIO2014 回文串

题目:

  给你一个由小写拉丁字母组成的字符串 。我们定义 的一个子串的存在值为这个子串在 中出现的次数乘以这个子串的长度。
  对于给你的这个字符串 ,求所有回文子串中的最大存在值。

思路:

  如果 是一个回文串,那么一定满足
  先对原串建出一个 SAM,尝试在 SAM 上匹配反串。如下图:
  
  设蓝色的为 ,红色的为 ,现在满足 ,那么中间直线框出来的 显然是一个回文串。
  那现在令 去匹配反串,每次试着算出所有以 为开头的回文串。设当前在 SAM 上匹配到 ,匹配的长度为 ,预处理 。当 时, 就是一个回文串,当前的出现次数就可能 (如果有更大的情况,就会在后面爬 的过程中算到,所以还是正确的),然后再尝试缩两个端点,只要沿着 往上就行了(要沿途更新 )。这样就得到了一个 的解法。
  如果一个在某个结点 满足 ,那么就标记这个结点。然后在沿 向上时不经过被标记的结点,显然这样复杂度就能达到 。但是正确性呢?
  如果在 在区间内,那么下一次再匹配到 一定不在区间内。感性理解:设蓝色为 ,红色为 。为了不让红色向右延伸,那 ,所以下一次匹配成功至少会左移
  再考虑不在区间内的情况,如果红色前段没有回文,那肯定不影响答案。如果有回文,那 总是能在一个时刻缩到 ,这个答案就被计入了。或者说漏掉的情况只有下图中上方的红色区间,但是这个区间肯定也会在第一个红色区间中以上面提到的方式计入。
  

代码:

#include <bits/stdc++.h>
#define rep(i, a, b) for(int i(a), i##_END_(b); i <= i##_END_; i++)
#define drep(i, a, b) for(int i(a), i##_END_(b); i >= i##_END_; i--)
#define File(_) freopen(#_ ".in", "r", stdin), freopen(#_ ".out", "w", stdout)
#define mset(a, b) memset(a, b, sizeof a)
#define mcpy(a, b) memcpy(a, b, sizeof a)
using namespace std;
template<class T> inline bool tomax(T &a, T b) {return a < b ? a = b, 1 : 0;}
template<class T> inline bool tomin(T &a, T b) {return b < a ? a = b, 1 : 0;}
typedef long long ll;
const int N = 300005, M = N * 2;

template<class T>
inline void rd(T &a) {
#define gc getchar()
    char c;
    bool f = false;
    for(c = gc; !isdigit(c); c = gc) f |= (c == '-');
    for(a = 0; isdigit(c); c = gc) a = (a << 1) + (a << 3) + c - '0';
    if(f) a = -a;
#undef gc
}

template<int N, int M, class T> struct Link {
#define erep(k, G, o) for(int k = G.HEAD[o]; k; k = G.NXT[k])
    int HEAD[N], NXT[M], tot; T W[M];
    void add(int x, T w) {NXT[++tot] = HEAD[x]; W[HEAD[x] = tot] = w;}
    T& operator[] (int x) {return W[x];}
};
Link<M, M, int> G;

int ch[M][26], par[M], len[M], mxr[M], cnt[M], tot, lst;
int newNode(int _len, int *_ch) {
    int o = ++tot;
    if(_ch) mcpy(ch[o], _ch);
    else mset(ch[o], 0);
    len[o] = _len;
    return o;
}
int extend(int c) {
    int p = lst, o = newNode(len[p] + 1, NULL);
    for(; p && !ch[p][c]; p = par[p]) ch[p][c] = o;
    if(p == 0) par[o] = 1;
    else {
        int u = ch[p][c];
        if(len[u] == len[p] + 1) par[o] = u;
        else {
            int v = newNode(len[p] + 1, ch[u]);
            par[v] = par[u];
            par[o] = par[u] = v;
            for(; ch[p][c] == u; p = par[p]) ch[p][c] = v;
        }
    }
    return lst = o;
}

void dfs(int o) {
    erep(k, G, o) {
        int v = G[k];
        dfs(v);
        cnt[o] += cnt[v];
        tomax(mxr[o], mxr[v]);
    }
}

char s[N];
bool vis[M];
ll calc(int n) {
    int o = 1, ln = 0;
    ll ans = 0;
    drep(i, n, 1) {
        int c = s[i] - 'a';
        while(o && !ch[o][c]) o = par[o], ln = len[o];
        if(!o) o = 1, ln = 0;
        else o = ch[o][c], ln++;
        int tmp = ln;
        if(i >= mxr[o] - ln + 1) 
            for(int p = o; !vis[p] && p != 1; p = par[p]) {
                tomin(ln, len[p]);
                vis[p] = (ln == len[p]);
                if(i <= mxr[p] && i >= mxr[p] - ln + 1)
                    tomax(ans, (ll) (mxr[p] - i + 1) * cnt[p]);
            }
        ln = tmp;
    }
    return ans;
}

int main() {
    File(palindrome);
    scanf("%s", s + 1);
    int n = strlen(s + 1);
    lst = newNode(0, NULL);
    rep(i, 1, n) {
        int o = extend(s[i] - 'a'); 
        cnt[o] = 1; mxr[o] = i;
    }
    rep(i, 2, tot) G.add(par[i], i);
    dfs(1);
    printf("%lld\n", calc(n));
    return 0;
}