求助
查看原帖
求助
160839
Prean楼主2020/8/6 18:03

RT,NTT是对的,单次times也是对的,单次inv也是对的,求助/kel

code:

#include<cstring>
#include<cstdio>
#define clr(f,len) memset(f,0,(len)<<2)
#define cpy(f,g,len) memcpy(f,g,(len)<<2)
const int M=1e6+5,mod=998244353,G=3,invG=332748118;
int n,m,fl,f[M],g[M],t[M];
inline void swap(int&a,int&b){
    a^=b^=a^=b;
}
inline int pow(int a,int b=mod-2){
    int ans=1;
    for(;b;b>>=1,a=1ll*a*a%mod)if(b&1)ans=1ll*a*ans%mod;
    return ans;
}
inline void reverse(int*f,int len){
    for(int i=0;(i<<1|1)<len;++i)swap(f[i],f[len-i-1]);
}
inline void px(int*f,int*g,int len){
    for(int i=0;i<len;++i)f[i]=1ll*f[i]*g[i]%mod;
}
inline void rev(int len){
    for(int i=0;i<len;++i)t[i]=t[i>>1]>>1|(i&1?len>>1:0);
}
void NTT(int*f,bool flag,int n){
    int i,p,k,w,w1,len;
    if(n!=fl)rev(fl=n);
    for(i=0;i<n;++i)if(i<t[i])swap(f[i],f[t[i]]);
    for(p=2;p<=n;p<<=1){
        len=p>>1;w1=pow(flag?G:invG,(mod-1)/p);
        for(k=0;k<n;k+=p){
            w=1;
            for(i=k;i<k+len;++i){
                int t=1ll*f[i+len]*w%mod;
                if((f[i+len]=f[i]-t)<0)f[i+len]+=mod;
                if((f[i]=f[i]+t)>=mod)f[i]-=mod;
                w=1ll*w*w1%mod;
            }
        }
    }
    if(flag)return;
    int inv=pow(n);
    for(i=0;i<n;++i)f[i]=1ll*f[i]*inv%mod;
}
void times(int*f,int*g,int l1,int l2){
    static int sav[M];
    int i,n=1;
    while(n<=(l1+l2))n<<=1;
    clr(f+l1,n-l1);cpy(sav,g,l2);
    NTT(f,1,n);NTT(sav,1,n);
    px(f,sav,n);NTT(f,0,n);
    clr(sav,n);
}
void inv(int*f,int m){
    static int b1[M],b2[M],b3[M];
    int i,n=1,len;
    while(n<m)n<<=1;b1[0]=pow(f[0]);
    for(len=2;len<=n;len<<=1){
        for(i=0;i<(len>>1);++i)b3[i]=(b1[i]<<1)%mod;
        cpy(b2,f,len);
        NTT(b1,1,len<<1);NTT(b2,1,len<<1);
        px(b1,b1,len<<1);px(b1,b2,len<<1);
        NTT(b1,0,len<<1);clr(b1+len,len);
        for(i=0;i<len;++i)b1[i]=(b3[i]-b1[i]+mod)%mod;
    }
    cpy(f,b1,m);clr(b1,n+n);clr(b2,n+n);clr(b3,n+n);
}
void div(int*f,int*g,int n,int m){
    static int b1[M],b2[M];
    int len=n-m+1;
    reverse(f,n);cpy(b1,f,len);inv(b1,len);
    times(b1,g,len,len);clr(b1+len,len);
    reverse(b1,len);
    cpy(b2,b1,len);times(b2,g,len,m);
    for(int i=0;i<m-1;++i)b2[i]=(f[i]-b2[i]+mod)%mod;
    cpy(f,b1,len);clr(f+len,m);
    cpy(g,b2,m-1);
}
signed main(){
    int i;
    scanf("%d%d",&n,&m);++n;++m;
    for(i=0;i<n;++i)scanf("%d",f+i);
    for(i=0;i<m;++i)scanf("%d",g+i);
    div(f,g,n,m);
    for(i=0;i<(n-m+1);++i)printf("%d ",f[i]);printf("\n");
    for(i=0;i<m-1;++i)printf("%d ",g[i]);
}
2020/8/6 18:03
加载中...