萌新60求助
查看原帖
萌新60求助
341034
abcde777楼主2021/3/17 21:46
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#define I inline
#define int long long
#define nnq isnnq
#define mod 998244353
using namespace std ;

const int INVG = 332748118;
const int inv2 = 499122177;
const int N = 300100 ;
const int G = 3 ;
int A[N],B[N],c[N],y[N],rev[N],ANS[N],RET[N],ma[N],mb[N],inv[N],revt[N] ;
int _n,_m ;

I int qpow(int x,int y) {
	if(y==0) return 1;
	if(y==1) return x;
	int mk=qpow(x,y/2)%mod;
	return y%2 ? 1ll*mk*mk%mod*x%mod:1ll*mk*mk%mod;
}
void ntt(int *x,int n,int type) {
	for(int i=0;i<n;++i) if(rev[i]<i) swap(x[i],x[rev[i]]) ;
	for(int i=1;i<n;i<<=1) {
		int w = qpow(G,(mod-1)/(i<<1)) ;
		if(type!=1) w=qpow(INVG,(mod-1)/(i<<1)) ;
		for(int j=0;j<n;j+=(i<<1)) {
			int ng=1,mk1,mk2;
			for(int p=0;p<i;++p,ng=ng*w%mod) {
			 	mk1=x[j+p]; mk2=x[j+p+i];
				x[j+p]=(mk1+mk2*ng%mod)%mod;	
				x[j+p+i]=(mk1-mk2*ng%mod+mod)%mod;	
			} 
		}
	}
	if(type==1) return ;
	int ny=qpow(n,mod-2);
	for(int i=0;i<n;++i) x[i]=x[i]*ny%mod;
}
I void getn(int n,int *a,int *b) { 
	if(n==1) {
		b[0]=qpow(a[0],mod-2) ;
		return ;
	}
	getn((n+1)>>1,a,b);
	int rn=1,len=0; while((n<<1)>rn) rn<<=1,len++; 
	for(int i=1;i<rn;++i) 
		rev[i]=(rev[i>>1]>>1)|((i&1)<<len-1);
	for(int i=0;i<n;++i) c[i]=a[i];
	for(int i=n;i<rn;++i) c[i]=0;
	ntt(c,rn,1); ntt(b,rn,1);
	for(int i=0;i<rn;++i) b[i]=b[i]*1ll*((1ll*2-1ll*c[i]*b[i]%mod+mod)%mod)%mod;
	ntt(b,rn,-1); 
	for(int i=n;i<rn;++i) b[i]=0;
}
I void reva(int *aa,int num) {
	for(int i=0;i<num;++i) revt[i]=aa[i];
	for(int i=0;i<num;++i) aa[i]=revt[num-i-1];
}
I void divi(int n,int m,int *a,int *b) {
	for(int i=0;i<n;++i) ma[i]=a[i];
	for(int i=0;i<m;++i) mb[i]=b[i];
	reva(ma,n); reva(mb,m);
	getn(n-m+1,mb,inv);
	int rn=1,len=0; while((n<<1)>rn) rn<<=1,len++; 
	for(int i=1;i<rn;++i) 
		rev[i]=(rev[i>>1]>>1)|((i&1)<<len-1);
	ntt(ma,rn,1);
	ntt(inv,rn,1);
	for(int i=0;i<rn;++i) ANS[i]=ma[i]*inv[i]%mod;
	ntt(ANS,rn,-1); reva(ANS,n-m+1);
	for(int i=0;i<n-m+1;++i) cout<<ANS[i]<<' '; cout<<endl; //这里输出的都还是对的
	ntt(b,rn,1);
	ntt(ANS,rn,1);
	for(int i=0;i<rn;++i) ANS[i]=ANS[i]%mod*b[i]%mod;
	ntt(ANS,rn,-1);
	for(int i=0;i<rn;++i) RET[i]=(a[i]%mod-ANS[i]%mod+mod)%mod;
	for(int i=0;i<m-1;++i) cout<<RET[i]<<' '; cout<<endl;
}
I int read() {
	int w=1,ret=0; char ch;
	while((ch=getchar())>'9'||ch<'0'&&ch!='-'); if(ch=='-') w=-1; else ret=ch-'0';
	while((ch=getchar())>='0'&&ch<='9') ret=ret*10+ch-'0';
	return ret*w;
}
signed main()
{
	_n=read()+1; _m=read()+1;
	for(int i=0;i<_n;++i) A[i]=read()%mod;
	for(int i=0;i<_m;++i) B[i]=read()%mod;
	divi(_n,_m,A,B);
	return 0 ;
}

2021/3/17 21:46
加载中...