萌新刚学多项式，求助：AC 代码开 O2 后全部 RE
查看原帖
萌新刚学多项式，求助：AC 代码开 O2 后全部 RE
108111
Lumos壹玖贰壹（楼主） 2021/3/4 11:10

RT，刚学多项式，过了之后想开下 O2，看看与大佬的差距，
然后全 RE 了。
应该是哪里挂了，求助大佬 QAQ
没开O2
开了O2

#include<bits/stdc++.h>
// Shorthand for the 64-bit integer type used for all coefficients.
#define ll long long
// Loop-counter shorthand. `register` was removed in C++17 (ill-formed there;
// clang rejects it outright), so this is now a plain int.
#define ri int
using namespace std;
// p  : NTT-friendly prime 998244353 = 119 * 2^23 + 1.
// _G : primitive root mod p, used to build the roots of unity in gn[].
// maxn: size of every global buffer; gn[] is filled for each power of two below it.
// NOTE(review): `_G` (underscore + uppercase) is a reserved identifier; kept as-is
// because main() refers to it by this name.
const int p = 998244353,_G = 3,maxn = (1<<21)+1;
// Reads the next integer from stdin.
// Skips any non-digit characters. A '-' seen among them makes the result negative.
// Then it accumulates consecutive decimal digits.
// The character is held in an int, not a char, for two reasons:
//  - isdigit() is only defined for EOF and values representable as unsigned char;
//  - EOF must be detectable.
// At end of input, returns whatever has been read so far (0 if nothing),
// instead of spinning forever.
inline int rd(){
	int res = 0;
	bool neg = false;
	int ch = getchar();
	while(ch != EOF && !isdigit(ch)){
		if(ch == '-') neg = true;
		ch = getchar();
	}
	for(;ch != EOF && isdigit(ch);ch = getchar())
		res = res * 10 + (ch - '0');
	return neg ? -res : res;
}
// Modular exponentiation: returns x^k mod p by binary exponentiation.
// The base is first reduced into [0, p). This keeps x * x within 64 bits and
// gives the correct residue for negative bases (rd() can return negatives).
// k is expected to be >= 0. A negative k is treated like 0 (returns 1) instead
// of looping forever under the arithmetic right shift.
inline ll qp(ll x,ll k){
	x %= p;
	if(x < 0) x += p;
	ll res = 1;
	for(;k > 0;k >>= 1){
		if(k & 1) res = res * x % p;
		x = x * x % p;
	}
	return res;
}
// Modular inverse of 2: (p + 1) / 2, since 2 * (p + 1) / 2 = p + 1 ≡ 1 (mod p).
const ll inv2 = (p + 1) / 2;
// Global buffers, each sized for the largest supported transform.
ll F[maxn];   // input polynomial coefficients
ll G[maxn];   // result of getsqrt(F, G, n)
ll gn[maxn];  // gn[len] = _G^((p-1)/len) for power-of-two len (filled in main)
ll a[maxn];   // NTT scratch used by getinv/getsqrt
ll b[maxn];   // inverse scratch used by getsqrt
int n;        // number of coefficients read from input
int r[maxn];  // bit-reversal permutation for the current transform length
inline ll ntt(ll *f,int lim,int type){
	for(ri i = 1;i < lim;++i) if(i < r[i]) swap(f[i],f[r[i]]);
	for(ri i = 2;i <= lim;i <<= 1){
		ll g = gn[i];
		for(ri j = 0;j < lim;j += i){
			ll gi = 1;
			for(ri k = j;k < j + (i >> 1);++k,gi = gi * g % p){
				ll x = f[k],y = f[k + (i >> 1)] * gi % p;
				f[k] = (x + y) % p;
				f[k + (i >> 1)] = (x - y + p) % p;
			}
		}
	}
	if(type == -1){
		ll inv = qp(lim,p-2);
		reverse(f + 1,f + lim);
		for(ri i = 0;i < lim;++i)
			f[i] = f[i] * inv % p;
	}
}
// Polynomial inverse modulo x^deg, by Newton iteration that doubles the
// precision at each level:
//   g <- g * (2 - f * g)   (mod x^deg)
// Requires f[0] to be invertible mod p.
// Writes g[0..deg-1] and leaves g zero from deg up to the transform length.
// Uses the global scratch buffer a[] and overwrites the global r[].
void getinv(ll *f,ll *g,int deg){
	if(deg <= 0) return;  // nothing to compute; deg == 0 used to recurse forever
	if(deg == 1){
		g[0] = qp(f[0],p-2);
		return;
	}
	const int mid = (deg + 1) >> 1;
	getinv(f,g,mid);
	int lim = 1;
	while(lim < (deg << 1)) lim <<= 1;
	for(int i = 1;i < lim;++i)
		r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
	for(int i = 0;i < deg;++i) a[i] = f[i];
	for(int i = deg;i < lim;++i) a[i] = 0;
	// Only g[0..mid-1] is meaningful here. Clear the rest explicitly rather
	// than relying on the caller having zeroed the buffer.
	for(int i = mid;i < lim;++i) g[i] = 0;
	ntt(a,lim,1);
	ntt(g,lim,1);
	for(int i = 0;i < lim;++i)
		g[i] = g[i] * (2 - a[i] * g[i] % p + p) % p;
	ntt(g,lim,-1);
	for(int i = deg;i < lim;++i) g[i] = 0;
}
// Polynomial square root modulo x^deg, by Newton iteration:
//   g <- (f * g^{-1} + g) / 2   (mod x^deg)
// Assumes f[0] == 1, because the base case sets g[0] = 1. Any other
// constant term would need a modular square root of f[0].
// Writes g[0..deg-1] and leaves g zero from deg up to the transform length.
// Uses the global buffers a[], b[] and r[].
void getsqrt(ll *f,ll *g,int deg){
	if(deg <= 0) return;  // nothing to compute; deg == 0 used to recurse forever
	if(deg == 1){
		g[0] = 1;
		return;
	}
	const int mid = (deg + 1) >> 1;
	getsqrt(f,g,mid);
	int lim = 1;
	while(lim < (deg << 1)) lim <<= 1;
	// Only g[0..mid-1] is meaningful. getinv() reads g[0..deg-1], so the tail
	// must be zero before that call.
	for(int i = mid;i < lim;++i) g[i] = 0;
	// b = g^{-1} mod x^deg. getinv() clobbers a[] and r[], so both are rebuilt below.
	for(int i = 0;i < lim;++i) b[i] = 0;
	getinv(g,b,deg);
	for(int i = 1;i < lim;++i)
		r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (lim >> 1) : 0);
	for(int i = 0;i < deg;++i) a[i] = f[i];
	for(int i = deg;i < lim;++i) a[i] = b[i] = 0;
	ntt(a,lim,1);
	ntt(b,lim,1);
	ntt(g,lim,1);
	for(int i = 0;i < lim;++i)
		g[i] = inv2 * ((a[i] * b[i] % p + g[i]) % p) % p;
	ntt(g,lim,-1);
	for(int i = deg;i < lim;++i) g[i] = 0;
}
// Reads n and the n coefficients of F, then prints the first n coefficients
// of sqrt(F) mod x^n, each followed by a space and then a final newline.
int main(){
	// Root of unity for every power-of-two transform length used by ntt().
	for(int len = 2;len < maxn;len <<= 1)
		gn[len] = qp(_G,(p - 1) / len);

	n = rd();
	for(int i = 0;i < n;++i)
		F[i] = rd();

	getsqrt(F,G,n);

	for(int i = 0;i < n;++i)
		printf("%lld ",G[i]);
	puts("");
	return 0;
}
2021/3/4 11:10
加载中...