线段树+SAM,小哥哥能看一下哪里有错吗
查看原帖
线段树+SAM,小哥哥能看一下哪里有错吗
261773
youwike楼主2021/1/1 22:31
/*
    \  | ^  ^  \
   -- | #    # \
   \_|   	\
*/
#include <iostream>
#include <cstdio>
#include <cstring>
#define ls(x) (x << 1)
#define rs(x) ((x << 1) | 1)

using namespace std;

const int N = 3e6 + 10;

struct SAM
{
    int ch[26], link, len, siz;
}tr[N];

int val[N * 4];
int las = 1, tot = 1;
int n;
int head[N], nxt[N], to[N], e_tot;
char s[N];

void link(int x, int y)
{
    nxt[++e_tot] = head[x];
    head[x] = e_tot;
    to[e_tot] = y;
}

void add(int c)
{
    int p = las, cur = ++tot;
    las = cur;
    tr[cur].len = tr[p].len + 1;
    while (p && !tr[p].ch[c])
    {
        tr[p].ch[c] = cur;
        p = tr[p].link;
    }
    if (!p) tr[cur].link = 1;
    else
    {
        int q = tr[p].ch[c];
        if (tr[q].len == tr[p].len + 1) tr[cur].link = q;
        else
        {
            int clome = ++tot;
            tr[clome] = tr[q], tr[clome].len = tr[p].len + 1;
            while (p && tr[p].ch[c] == q)
            {
                tr[p].ch[c] == clome;
                p = tr[p].link;
            }
            tr[q].link = clome, tr[cur].link = clome;
        }
    }
}

void change(int u, int l, int r, int x, int y, int v)
{
    if (x <= l && r <= y) return val[u] = max(val[u], v), void();
    int mid = (l + r) >> 1; 
    if (x <= mid) change(ls(u), l, mid, x, y, v);
    if (y > mid) change(rs(u), mid + 1, r, x, y, v);
}

int query(int u, int l, int r, int x)
{
    if (l == r) return val[u];
    int mid = (l + r) >> 1;
    if (x <= mid) return max(val[u], query(ls(u), l, mid, x));
    else return max(val[u], query(rs(u), mid + 1, r, x));
}

void get_siz(int u)
{
    for (int i = head[u]; i; i = nxt[i])
    {
        int v = to[i];
        get_siz(v);
        tr[u].siz += tr[v].siz;
    }
}

void dfs(int u)
{
    for (int i = head[u]; i; i = nxt[i])
    {
        int v = to[i];
        change(1, 1, n, tr[u].len + 1, tr[v].len, tr[v].siz);
        dfs(v);
    }
}

int main()
{
    #ifndef ONLINE_JUDGE
        freopen("in.txt", "r", stdin);
        freopen("out.txt", "w", stdout);
    #endif
    scanf("%s", s + 1);
    n = strlen(s + 1);
    for (int i = 1; i <= n; ++i) add(s[i] - 'a');
    for (int i = 2; i <= tot; ++i)
    {
        link(tr[i].link, i);
    }
    for (int i = 1, p = 1; i <= n; ++i)
    {
        p = tr[p].ch[s[i] - 'a'];
        ++tr[p].siz;
    }
    get_siz(1);
    dfs(1);
    for (int i = 1; i <= n; ++i)
    {
        printf("%d\n", query(1, 1, n, i));
    }
    return 0;
}
2021/1/1 22:31
加载中...