求卡常
查看原帖
求卡常
42156
feecle6418机器人楼主2020/7/17 12:06

洛谷能过,但是本校 OJ 上 TLE 过不去,求助

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<complex>
const int mod=998244353,g=3,invg=998244354/3;
using namespace std;
int f[200005],a[200005],n,k,tr[200005],wk[200005],tmp[200005],gy[200005];
int s[200005]={1},t[200005]={0,1};
int Power(int x,int y,int mod) {
	int ret=1;
	while(y) {
		if(y&1)ret=1ll*ret*x%mod;
		x=1ll*x*x%mod,y>>=1;
	}
	return ret;
}
void GetTr(int l) {
	for(int i=0; i<l; i++)tr[i]=(tr[i>>1]>>1)|((i&1)?(l>>1):0);
}
void NTT(int a[],int n,int flag) {
	for(int i=0; i<n; i++)if(tr[i]<i)swap(a[i],a[tr[i]]);
	for(int i=1; i<n; i<<=1) {
		int w=Power(flag==1?g:invg,(mod-1)/(i<<1),mod);
		wk[0]=1;
		for(int j=1; j<i; j++)wk[j]=1ll*wk[j-1]*w%mod;
		for(int j=0; j<n; j+=(i<<1)) {
			for(int k=0; k<i; k++) {
				int t=1ll*wk[k]*a[i+j+k]%mod;
				a[i+j+k]=(a[j+k]-t+mod)%mod;
				a[j+k]=(a[j+k]+t)%mod;
			}
		}
	}
	if(flag==-1)for(int i=0,t=Power(n,mod-2,mod); i<n; i++)a[i]=1ll*a[i]*t%mod;
}
void Calcinv(int a[],int b[],int l){
	if(l==1){
		b[0]=Power(a[0],mod-2,mod);
		return ;
	}
	Calcinv(a,b,(l+1)/2);
	int len=1;
	while(len<l*2)len<<=1;
	GetTr(len);
	memcpy(tmp,a,sizeof(int)*l);
	for(int i=l;i<len;i++)tmp[i]=0;
	NTT(tmp,len,1),NTT(b,len,1);
	for(int i=0;i<len;i++)b[i]=(2-1ll*tmp[i]*b[i]%mod+mod)%mod*b[i]%mod;
	NTT(b,len,-1);
	for(int i=l;i<len;i++)b[i]=0;
}
void CalcMod(int f[],int l[],int n,int m,int c[]){
	if(n<m){
		memset(c,0,sizeof(c));
		for(int i=0;i<m;i++)c[i]=f[i];
		return ;
	}
	static int t[200005];
	memset(t,0,sizeof(t));
	reverse(l,l+m+1);
	reverse(f,f+n+1);
	Calcinv(l,t,n-m+1);
	int len=1;
	while(len<=2*n)len<<=1;
	GetTr(len);
	NTT(t,len,1),NTT(f,len,1);
	for(int i=0;i<len;i++)t[i]=1ll*t[i]*f[i]%mod;
	NTT(t,len,-1);
	reverse(t,t+n-m+1);
	for(int i=n-m+1;i<len;i++)t[i]=0;
	reverse(l,l+m+1);
	NTT(t,len,1),NTT(l,len,1);
	for(int i=0;i<len;i++)t[i]=1ll*t[i]*l[i]%mod;
	NTT(t,len,-1),NTT(f,len,-1),NTT(l,len,-1),reverse(f,f+n+1);
	memset(c,0,sizeof(c));
	for(int i=0;i<m;i++)c[i]=(f[i]-t[i]+mod)%mod;
}
int main() {
	//freopen("1.in","r",stdin);
	scanf("%d%d",&n,&k);
	for(int i=1;i<=k;i++)scanf("%d",&f[i]),f[i]=(f[i]%mod+mod)%mod;
	for(int i=0;i<k;i++)scanf("%d",&a[i]),a[i]=(a[i]%mod+mod)%mod;
	reverse(f,f+k+1),f[k]=1;
	for(int i=0;i<k;i++)f[i]=mod-f[i];
	int len=1;
	while(len<=2*k)len<<=1;
	int css=0,cst=1;
	while(n){
		GetTr(len);
		NTT(t,len,1);
		if(n&1){
			NTT(s,len,1);
			for(int i=0;i<len;i++)s[i]=1ll*s[i]*t[i]%mod;
			NTT(s,len,-1),css+=cst;
			CalcMod(s,f,css,k,gy),css=min(css,k-1);
			memcpy(s,gy,sizeof(gy));
		}
		for(int i=0;i<len;i++)t[i]=1ll*t[i]*t[i]%mod;
		GetTr(len);
		NTT(t,len,-1),cst*=2;
		CalcMod(t,f,cst,k,gy),cst=min(cst,k-1);
		memcpy(t,gy,sizeof(gy));
		n>>=1;
		//cout<<n<<endl;
	}
	int ans=0;
	for(int i=0;i<k;i++)ans=(ans+1ll*s[i]*a[i])%mod;
	cout<<ans;
	return 0;
}
2020/7/17 12:06
加载中...