洛谷能过,但是本校 OJ 上 TLE 过不去,求助
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<complex>
const int mod=998244353,g=3,invg=998244354/3;
using namespace std;
int f[200005],a[200005],n,k,tr[200005],wk[200005],tmp[200005],gy[200005];
int s[200005]={1},t[200005]={0,1};
int Power(int x,int y,int mod) {
int ret=1;
while(y) {
if(y&1)ret=1ll*ret*x%mod;
x=1ll*x*x%mod,y>>=1;
}
return ret;
}
void GetTr(int l) {
for(int i=0; i<l; i++)tr[i]=(tr[i>>1]>>1)|((i&1)?(l>>1):0);
}
void NTT(int a[],int n,int flag) {
for(int i=0; i<n; i++)if(tr[i]<i)swap(a[i],a[tr[i]]);
for(int i=1; i<n; i<<=1) {
int w=Power(flag==1?g:invg,(mod-1)/(i<<1),mod);
wk[0]=1;
for(int j=1; j<i; j++)wk[j]=1ll*wk[j-1]*w%mod;
for(int j=0; j<n; j+=(i<<1)) {
for(int k=0; k<i; k++) {
int t=1ll*wk[k]*a[i+j+k]%mod;
a[i+j+k]=(a[j+k]-t+mod)%mod;
a[j+k]=(a[j+k]+t)%mod;
}
}
}
if(flag==-1)for(int i=0,t=Power(n,mod-2,mod); i<n; i++)a[i]=1ll*a[i]*t%mod;
}
void Calcinv(int a[],int b[],int l){
if(l==1){
b[0]=Power(a[0],mod-2,mod);
return ;
}
Calcinv(a,b,(l+1)/2);
int len=1;
while(len<l*2)len<<=1;
GetTr(len);
memcpy(tmp,a,sizeof(int)*l);
for(int i=l;i<len;i++)tmp[i]=0;
NTT(tmp,len,1),NTT(b,len,1);
for(int i=0;i<len;i++)b[i]=(2-1ll*tmp[i]*b[i]%mod+mod)%mod*b[i]%mod;
NTT(b,len,-1);
for(int i=l;i<len;i++)b[i]=0;
}
void CalcMod(int f[],int l[],int n,int m,int c[]){
if(n<m){
memset(c,0,sizeof(c));
for(int i=0;i<m;i++)c[i]=f[i];
return ;
}
static int t[200005];
memset(t,0,sizeof(t));
reverse(l,l+m+1);
reverse(f,f+n+1);
Calcinv(l,t,n-m+1);
int len=1;
while(len<=2*n)len<<=1;
GetTr(len);
NTT(t,len,1),NTT(f,len,1);
for(int i=0;i<len;i++)t[i]=1ll*t[i]*f[i]%mod;
NTT(t,len,-1);
reverse(t,t+n-m+1);
for(int i=n-m+1;i<len;i++)t[i]=0;
reverse(l,l+m+1);
NTT(t,len,1),NTT(l,len,1);
for(int i=0;i<len;i++)t[i]=1ll*t[i]*l[i]%mod;
NTT(t,len,-1),NTT(f,len,-1),NTT(l,len,-1),reverse(f,f+n+1);
memset(c,0,sizeof(c));
for(int i=0;i<m;i++)c[i]=(f[i]-t[i]+mod)%mod;
}
int main() {
//freopen("1.in","r",stdin);
scanf("%d%d",&n,&k);
for(int i=1;i<=k;i++)scanf("%d",&f[i]),f[i]=(f[i]%mod+mod)%mod;
for(int i=0;i<k;i++)scanf("%d",&a[i]),a[i]=(a[i]%mod+mod)%mod;
reverse(f,f+k+1),f[k]=1;
for(int i=0;i<k;i++)f[i]=mod-f[i];
int len=1;
while(len<=2*k)len<<=1;
int css=0,cst=1;
while(n){
GetTr(len);
NTT(t,len,1);
if(n&1){
NTT(s,len,1);
for(int i=0;i<len;i++)s[i]=1ll*s[i]*t[i]%mod;
NTT(s,len,-1),css+=cst;
CalcMod(s,f,css,k,gy),css=min(css,k-1);
memcpy(s,gy,sizeof(gy));
}
for(int i=0;i<len;i++)t[i]=1ll*t[i]*t[i]%mod;
GetTr(len);
NTT(t,len,-1),cst*=2;
CalcMod(t,f,cst,k,gy),cst=min(cst,k-1);
memcpy(t,gy,sizeof(gy));
n>>=1;
//cout<<n<<endl;
}
int ans=0;
for(int i=0;i<k;i++)ans=(ans+1ll*s[i]*a[i])%mod;
cout<<ans;
return 0;
}