本蒟蒻用的是 hash 和 kmp。
只 WA 了第 11 个点。
求助。
代码:
#include<bits/stdc++.h>
#define ull unsigned long long
#define int long long
using namespace std;
const int N=3e6+10,p=(1ll<<32);
string s1,s2;
int n,m,nxt[N],s[N],ss[N],sss[N],ans;
ull ha[N],hb[N],t[N],b=131;
int gets(int l,int r){
if(l>r)return 0;
int mid=(l+r)/2;
int L=((ss[mid]-ss[l-1])%p-(s[mid]-s[l-1])*(l-1)%p+p)%p;
int R=((sss[mid+1]-sss[r+1])%p-(s[r]-s[mid])*(n-r)%p+p)%p;
// printf("%d %d %d\n",l,r,(L+R)%p);
return (L+R)%p;
}
signed main(){
cin>>n>>m>>s1>>s2;
// n=s1.size(),m=s2.size();
s1=' '+s1,s2=' '+s2;
t[0]=1;
for(int i=1;i<=n;i++)t[i]=t[i-1]*b;
for(int i=2,j=0;i<=m;i++){
while(j>0&&s2[j+1]!=s2[i])j=nxt[j];
if(s2[j+1]==s2[i])j++;
nxt[i]=j;
}
for(int i=1,j=0;i<=n;i++){
while(j>0&&s2[j+1]!=s1[i])j=nxt[j];
if(s2[j+1]==s1[i])j++;
if(j==m){
s[i]=1;
j=nxt[j];
}
ha[i]=ha[i-1]*b+s1[i];
ss[i]=(ss[i-1]+s[i]*i%p)%p;
}
for(int i=n;i>=1;i--)sss[i]=(sss[i+1]+s[i]*(n-i+1)%p)%p,hb[i]=hb[i+1]*b+s1[i];
for(int i=1;i<=n;i++)s[i]+=s[i-1];
for(int i=2;i<n;i++){
int l=0,r=min(i,n-i+1);
while(l+1<r){
int mid=(l+r)/2;
if(ha[i]-ha[i-mid-1]*t[mid+1]==hb[i]-hb[i+mid+1]*t[mid+1])l=mid;
else r=mid;
}
ans=(ans+gets(i-l+m-1,i+l))%p;
}
printf("%lld\n",ans);
}