LOJ能过,到洛谷只有30,剩下的全部TLE,希望各位大佬能帮忙解答!
#include <bits/stdc++.h>
using namespace std;
#define int long long
int n,k,ans,sa[500005],ht[500005];
int rnk[500005],tmp[500005];
char str[500005];
int top,sk[500005],ll[500005],rr[500005];
bool cmp(int i,int j)
{
if(rnk[i]!=rnk[j])
return rnk[i]<rnk[j];
int ri=(i+k<=n ? rnk[i+k] : -1);
int rj=(j+k<=n ? rnk[j+k] : -1);
return ri<rj;
}
void getsa(int n,char *str)
{
for(int i=1;i<=n;i++)
sa[i] = i, rnk[i] = str[i];
for(k=1;k<=n;k*=2)
{
sort(sa+1,sa+1+n,cmp);
tmp[sa[1]] = 1;
for(int i=2;i<=n;i++)
tmp[sa[i]] = tmp[sa[i-1]]+cmp(sa[i-1],sa[i]);
for(int i=1;i<=n;i++)
rnk[i] = tmp[i];
}
}
void gethgt(int n,char *str)
{
for(int i=1;i<=n;i++)
rnk[sa[i]] = i;
int h=0;
ht[1] = 0;
for(int i=1;i<=n;i++)
{
int j=sa[rnk[i]-1];
if(h>0)
h--;
for(;j+h<=n&&i+h<=n;h++)
if(str[j+h]!=str[i+h])
break;
ht[rnk[i]] = h;
}
}
signed main()
{
scanf("%s",str+1);
n = strlen(str+1);
getsa(n,str), gethgt(n,str);
top = 1, sk[1] = 1;
for(int i=2;i<=n;i++)
{
while(top && ht[sk[top]]>ht[i])
rr[sk[top]] = i, top--;
ll[i] = sk[top];
sk[++top] = i;
}
while(top)
rr[sk[top]] = n+1, top--;
ans = n*(n-1)*(n+1)*1ll/2;
for(int i=2;i<=n;i++)
ans -= 2ll*(rr[i]-i)*(i-ll[i])*ht[i];
cout<<ans<<endl;
return 0;
}