萌新刚学OI,求助WA 20pts
查看原帖
萌新刚学OI,求助WA 20pts
341373
Autofreeze楼主2021/2/21 11:50

如题:对三个模数各做 5 次 NTT(共 15 次),再用 CRT 合并结果,目测是合并的时候出错了。

屑代码

#include<bits/stdc++.h>
// Capacity of every global coefficient buffer.
#define N 2001001
// `register` is no longer a valid storage-class specifier in C++17, so the
// shorthand expands to nothing; declarations written with `re` keep compiling.
#define re
#define MAX 2001
#define eps 1e-10
using namespace std;
typedef long long ll;
typedef double db;
// Three NTT-friendly primes that share the primitive root g = 3:
//   mod1 = 119 * 2^23 + 1,  mod2 = 479 * 2^21 + 1,  mod3 = 7 * 2^26 + 1.
// mod3 must be 469762049; the similar-looking 409762049 is 1600633 * 2^8 + 1,
// so it has no root of unity for any transform longer than 256 points.
const ll mod1=998244353,mod2=1004535809,mod3=469762049,g=3;
inline void read(re ll &ret)
{
	ret=0;re bool pd=false;re char c=getchar();
	while(!isdigit(c)){(c=='-')&&(pd=true);c=getchar();}
	while(isdigit(c)){ret=(ret<<1)+(ret<<3)+(c^48);c=getchar();}	
	ret=pd?-ret:ret;
	return;
}
// Shared program state.
//   n, m          : index bounds of the two inputs (main reads a[0..n] and b[0..m]).
//   p             : modulus applied to every printed coefficient.
//   a, b          : input coefficient arrays filled by main.
//   ans1..ans3    : product computed modulo mod1 / mod2 / mod3, still missing the
//                   1/num scaling that main applies before merging.
//   num           : transform length chosen in main (a power of two >= n+m+1).
//   rev           : bit-reversal table rewritten by every ntt call.
//   inv1g..inv3g  : inverse of g modulo mod1 / mod2 / mod3, set at the start of main.
ll n,m,p,a[N],b[N],ans1[N],ans2[N],ans3[N],num,rev[N],inv1g,inv2g,inv3g;
/**
 * Modular exponentiation by repeated squaring: returns a^b mod p.
 *
 * @param a base; any value is accepted, it is reduced into [0, p) first so the
 *          squarings below cannot overflow for large or negative bases.
 * @param b exponent; a non-positive exponent yields 1 % p.
 * @param p modulus, p >= 1. Intermediate products are (p-1)^2 at most, so p
 *          must stay below roughly 3e9 for them to fit in a signed 64-bit value.
 * @return  the result in [0, p).
 */
inline long long qpow(long long a,long long b,long long p)
{
	a%=p;
	if(a<0)
		a+=p;
	long long ret=1%p;	// 1 % p so that p == 1 gives 0
	for(;b>0;b>>=1)
	{
		if(b&1)
			ret=ret*a%p;
		a=a*a%p;
	}
	return ret;
}
// Returns x^(p-2) mod p. By Fermat's little theorem this is the modular
// inverse of x whenever p is prime and x is not a multiple of p.
inline ll inv(ll x,ll p)
{
	const ll fermatExponent=p-2;
	return qpow(x,fermatExponent,p);
}
inline void ntt1(re ll a[],re ll n,re ll typ)
{
	re ll num=1,bit=0;
	while(num<n)
	{
		num<<=1;
		bit++;
	}
	for(re int i=0;i<n;i++)
	{
		rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
		if(i<rev[i])
			swap(a[i],a[rev[i]]);
	}
	for(re int mid=1;mid<n;mid<<=1)
	{
		re ll wn=qpow((typ==1)?g:inv1g,(mod1-1)/(mid<<1),mod1);
		for(re int j=0;j<n;j+=mid<<1)
		{
			re ll w=1;
			for(re int i=0;i<mid;i++,w*=wn,w%=mod1)
			{
				re ll x=a[i+j],y=a[i+j+mid]*w%mod1;
				a[i+j]=(x+y)%mod1,a[i+j+mid]=(x-y+mod1)%mod1;
			}
		}
	}
	return;
}
inline void ntt2(re ll a[],re ll n,re ll typ)
{
	re ll num=1,bit=0;
	while(num<n)
	{
		num<<=1;
		bit++;
	}
	for(re int i=0;i<n;i++)
	{
		rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
		if(i<rev[i])
			swap(a[i],a[rev[i]]);
	}
	for(re int mid=1;mid<n;mid<<=1)
	{
		re ll wn=qpow((typ==1)?g:inv2g,(mod2-1)/(mid<<1),mod2);
		for(re int j=0;j<n;j+=mid<<1)
		{
			re ll w=1;
			for(re int i=0;i<mid;i++,w*=wn,w%=mod2)
			{
				re ll x=a[i+j],y=a[i+j+mid]*w%mod2;
				a[i+j]=(x+y)%mod2,a[i+j+mid]=(x-y+mod2)%mod2;
			}
		}
	}
	return;
}
inline void ntt3(re ll a[],re ll n,re ll typ)
{
	re ll num=1,bit=0;
	while(num<n)
	{
		num<<=1;
		bit++;
	}
	for(re int i=0;i<n;i++)
	{
		rev[i]=(rev[i>>1]>>1)|((i&1)<<bit-1);
		if(i<rev[i])
			swap(a[i],a[rev[i]]);
	}
	for(re int mid=1;mid<n;mid<<=1)
	{
		re ll wn=qpow((typ==1)?g:inv3g,(mod3-1)/(mid<<1),mod3);
		for(re int j=0;j<n;j+=mid<<1)
		{
			re ll w=1;
			for(re int i=0;i<mid;i++,w*=wn,w%=mod3)
			{
				re ll x=a[i+j],y=a[i+j+mid]*w%mod3;
				a[i+j]=(x+y)%mod3,a[i+j+mid]=(x-y+mod3)%mod3;
			}
		}
	}
	return;
}
inline void solve1()
{
	ntt1(a,num,1);
	ntt1(b,num,1);
	for(re int i=0;i<num;i++)
		ans1[i]=a[i]*b[i]%mod1;
	ntt1(ans1,num,-1);
	ntt1(a,num,-1);
	ntt1(b,num,-1);
	for(re int i=0;i<=n;i++)
		a[i]=a[i]*inv(num,mod1)%mod1;
	for(re int i=0;i<=m;i++)
		b[i]=b[i]*inv(num,mod1)%mod1;
	return;
}
inline void solve2()
{
	ntt2(a,num,1);
	ntt2(b,num,1);
	for(re int i=0;i<num;i++)
		ans2[i]=a[i]*b[i]%mod2;
	ntt2(ans2,num,-1);
	ntt2(a,num,-1);
	ntt2(b,num,-1);
	for(re int i=0;i<=n;i++)
		a[i]=a[i]*inv(num,mod2)%mod2;
	for(re int i=0;i<=m;i++)
		b[i]=b[i]*inv(num,mod2)%mod2;
	return;
}
inline void solve3()
{
	ntt3(a,num,1);
	ntt3(b,num,1);
	for(re int i=0;i<num;i++)
		ans3[i]=a[i]*b[i]%mod3;
	ntt3(ans3,num,-1);
	ntt3(a,num,-1);
	ntt3(b,num,-1);
	for(re int i=0;i<=n;i++)
		a[i]=a[i]*inv(num,mod3)%mod3;
	for(re int i=0;i<=m;i++)
		b[i]=b[i]*inv(num,mod3)%mod3;
	return;
}
int main()
{
	inv1g=inv(g,mod1);
	inv2g=inv(g,mod2);
	inv3g=inv(g,mod3);
	read(n);
	read(m);
	read(p);
	for(re int i=0;i<=n;i++)
		read(a[i]);
	for(re int i=0;i<=m;i++)
		read(b[i]);
	num=1;
	while(num<n+m+1)
		num<<=1;
	solve1();
	solve2();
	solve3();
	for(re int i=0;i<n+m+1;i++)
	{
		re ll x1=ans1[i]*inv(num,mod1)%mod1,x2=ans2[i]*inv(num,mod2)%mod2,x3=ans3[i]*inv(num,mod3)%mod3;
		re ll k1,k2,k3;
		k1=((x2-x1)%mod2+mod2)%mod2*inv(mod1,mod2)%mod2;
		re ll x4=x1+k1*mod1;
		re ll lcm=mod1*mod2%mod3;
		re ll k4=((x3-x4)%mod3+mod3)%mod3*inv(lcm,mod3)%mod3;
		if(i)
			putchar(' ');
		printf("%lld",(x4%p+(k4%p)*(mod1*mod2%p)%p)%p);
	}
	putchar('\n');
	exit(0);
}
2021/2/21 11:50
加载中...