不知道是想错了还是打错了。
想法:正反两次 PAM,可以找到每个回文串的起点和终点,用差分统计每个位置出现几次。
结果:sub1 WA了一个?不是很理解为什么。其他MLE TLE的不是很重要。
#include<iostream>
#include<cstring>
using namespace std;
const int N=3e6+10;
struct edge_node{
int from,to;
int nxt;
};
struct edge{
edge_node e[N];
int head[N],tot;
edge(){memset(head,-1,sizeof head);}
inline void add(int from,int to){
e[tot].from=from;
e[tot].to=to;
e[tot].nxt=head[from];
head[from]=tot++;
}
void clear(){
memset(e,0,sizeof e);
memset(head,-1,sizeof head);
tot=0;
}
};
edge e;
string s;
int a[N];
int fail[N],ch[N][27],lth[N],tot=1;
int sum[N];
long long ans[N];
__int128 res;
bool bo[N];
void dfs1(int u){
if(sum[u]*2>=lth[u]&&u>=2) bo[u]=true;
for(int i=0;i<=26;i++){
if(!ch[u][i]) continue;
int v=ch[u][i];
sum[v]=sum[u];
if(i==26)
if(u!=1) sum[v]+=2;
else sum[v]++;
dfs1(v);
}
}
int dep[N];
void dfs2(int u){
for(int i=e.head[u];~i;i=e.e[i].nxt){
int v=e.e[i].to;
dep[v]=dep[u]+bo[v];
dfs2(v);
}
}
int sta[N],top;
signed main(){
int len;
cin>>len>>s;
for(int i=1;i<=len;i++)
a[i]=(s[i-1]=='?'?26:s[i-1]-'a');
int last=0;
fail[0]=1,fail[1]=0;
lth[0]=0,lth[1]=-1;
a[0]=-1;
for(int i=1;i<=len;i++){
while(a[i-lth[last]-1]!=a[i])
last=fail[last];
if(!ch[last][a[i]]){
lth[++tot]=lth[last]+2;
int j=fail[last];
while(a[i-lth[j]-1]!=a[i])
j=fail[j];
fail[tot]=ch[j][a[i]];
ch[last][a[i]]=tot;
}
last=ch[last][a[i]];
}
dfs1(0),dfs1(1);
for(int i=2;i<=tot;i++)
e.add(fail[i],i);
dfs2(0),dfs2(1);
last=0;
for(int i=1;i<=len;i++){
while(a[i-lth[last]-1]!=a[i])
last=fail[last];
last=ch[last][a[i]];
if(bo[last]) ans[i+1]-=dep[last];
}
e.clear();
memset(fail,0,sizeof fail);
memset(lth,0,sizeof lth);
memset(ch,0,sizeof ch);
memset(bo,false,sizeof bo);
memset(dep,0,sizeof dep);
memset(sum,0,sizeof sum);
tot=1,last=0;
fail[0]=1,fail[1]=0;
lth[0]=0,lth[1]=-1;
a[len+1]=-1;
for(int i=len;i>=1;i--){
while(a[i+lth[last]+1]!=a[i])
last=fail[last];
if(!ch[last][a[i]]){
lth[++tot]=lth[last]+2;
int j=fail[last];
while(a[i+lth[j]+1]!=a[i])
j=fail[j];
fail[tot]=ch[j][a[i]];
ch[last][a[i]]=tot;
}
last=ch[last][a[i]];
}
dfs1(0),dfs1(1);
for(int i=2;i<=tot;i++)
e.add(fail[i],i);
dfs2(0),dfs2(1);
last=0;
for(int i=len;i>=1;i--){
while(a[i+lth[last]+1]!=a[i])
last=fail[last];
last=ch[last][a[i]];
if(bo[last]) ans[i]+=dep[last];
}
for(__int128 i=1;i<=len;i++){
ans[i]+=ans[i-1];
__int128 k=i*ans[i];
if(a[i]==26) res+=k;
}
while(res){
sta[++top]=res%10;
res/=10;
}
while(top) cout<<sta[top--];
}
/*
10
?????????
4
????
*/