被卡常,求助
查看原帖
被卡常,求助
99506
_LHF_楼主2021/3/23 21:27

蒟蒻的代码被卡常了,有7个点TLE了,想知道有没有不开 O2 的方法加速(毕竟考试时不给用 O2 )

#include<cstdio>
#include<algorithm>
#define N 1050000
#define ll long long
using namespace std;
const int mod=998244353;
inline int Mod(int a){return a<mod?a:a-mod;}
inline int fastpow(int a,int b)
{
	int s=1;
	while(b)
	{
		if(b&1) s=1ll*s*a%mod;
		a=1ll*a*a%mod, b>>=1;
	}
	return s;
}
int r[N],inv3=fastpow(3,mod-2),Wn[N],iWn[N];
inline void ntt(int *a,int lim,bool type)
{
	register int x,y,wn;
	for(int i=1;i<lim;i++)
		if(i<r[i]) swap(a[i],a[r[i]]);
	for(register int i=1;i<lim;i<<=1)
	{
		wn=type?Wn[i<<1]:iWn[i<<1];
		for(register int j=0;j<lim;j+=(i<<1))
			for(register int k=0,w=1;k<i;k++,w=1ll*w*wn%mod)
			{
				x=a[j+k];
				y=1ll*a[i+j+k]*w%mod;
				a[j+k]=Mod(x+y);
				a[i+j+k]=Mod(x-y+mod);
			}
	}
	if(type) return;
	x=fastpow(lim,mod-2);
	for(register int i=0;i<lim;i++)
		a[i]=1ll*a[i]*x%mod;
}
inline int init(int n)
{
	register int lim=1,l=0;
	while(lim<=n) lim<<=1,l++;
	for(int i=1;i<lim;i++)
		r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
	return lim;
}
ll fac[N],ifac[N];
int p[N],a[N],b[N];
inline void DFT(int*a,int n)
{
	register int len=init(n*2),i;
	for(i=0;i<=n;i++) p[i]=ifac[i];
	for(i=n+1;i<len;i++) p[i]=0;
	ntt(a,len,1);
	ntt(p,len,1);
	for(i=0;i<len;i++) a[i]=1ll*a[i]*p[i]%mod;
	ntt(a,len,0);
	for(i=0;i<=n;i++) a[i]=a[i]*fac[i]%mod;
	for(i=n+1;i<len;i++) a[i]=0;
}
inline void IDFT(int*a,int n)
{
	register int len=init(n*2),i;
	for(i=0;i<=n;i++) p[i]=(i%2?mod-ifac[i]:ifac[i]);
	for(i=0;i<len;i++) a[i]=a[i]*ifac[i]%mod;
	for(i=n+1;i<len;i++) a[i]=p[i]=0;
	ntt(a,len,1);
	ntt(p,len,1);
	for(i=0;i<len;i++) a[i]=1ll*a[i]*p[i]%mod;
	ntt(a,len,0);
	for(i=n+1;i<len;i++) a[i]=0;
}
int n,m,mx;
int main()
{
	for(register int i=1;i<=1e6;i<<=1) Wn[i]=fastpow(3,(mod-1)/i);
	for(register int i=1;i<=1e6;i<<=1) iWn[i]=fastpow(inv3,(mod-1)/i);//预处理原根
	scanf("%d%d",&n,&m);
	for(register int i=0;i<=n;i++) scanf("%d",&a[i]);
	for(register int i=0;i<=m;i++) scanf("%d",&b[i]);
	mx=(n+m)*2, fac[0]=ifac[0]=1;
	for(register int i=1;i<=mx+1;i++) fac[i]=fac[i-1]*i%mod;
	ifac[mx+1]=fastpow(fac[mx+1],mod-2);
	for(register int i=mx;i;i--) ifac[i]=ifac[i+1]*(i+1)%mod;
	int len=n+m;
	DFT(a,len);
	DFT(b,len);
	for(register int i=0;i<=len;i++) a[i]=1ll*a[i]*b[i]%mod;
	IDFT(a,len);
	for(register int i=0;i<=len;i++) printf("%d ",a[i]);
	return 0;
}

请大佬指点一二,谢谢

2021/3/23 21:27
加载中...