生成函数求助 RE on test #10
查看原帖
生成函数求助 RE on test #10
760655
tkdqmx（楼主） 2025/6/26 09:51
#include<bits/stdc++.h>
using namespace std;
// Capacity of the global arrays; r[] (the NTT bit-reversal table) must hold
// the largest transform length used by poly::operator*.
#define N 2000005
#define LL long long
// n, m: sizes read in main; r: bit-reversal permutation shared by the NTT.
// NOTE(review): the global arrays fac and inv are not referenced anywhere in
// the visible code (poly::inv is a member function, not this array).
LL n,m,r[N],fac[N],inv[N];
// NTT modulus 998244353. p = 3 is the root used by the forward NTT,
// pi = 332748118 is 3^{-1} mod 998244353 (root of the inverse NTT),
// inv2 = 499122177 is 2^{-1} mod 998244353 (used by poly::sqrt).
const LL mod=998244353,p=3,pi=332748118,inv2=499122177;
// Binary exponentiation: returns res * x^y reduced modulo mod (res defaults to 1).
// NOTE(review): assumes 0 <= x < mod so that x*x fits in long long — confirm at call sites.
LL quick_pow(LL x,LL y,LL res=1){
    while(y){
        if(y&1)  res=res*x%mod;   // fold in the current bit of the exponent
        x=x*x%mod;                // square the base for the next bit
        y>>=1;
    }
    return res;
}
// Polynomial / truncated power series over Z_mod (mod = 998244353).
// v[i] is the coefficient of x^i. The operations below assume every stored
// coefficient is already reduced into [0, mod).
struct poly{
    vector<LL> v;
    // Zero polynomial with n+1 coefficients (x^0 .. x^n).
    poly(int n=0){v.resize(n+1,0);}
    // Iterative in-place NTT over the first n entries of a. n must be a power
    // of two and the global bit-reversal table r[0..n) must already be filled
    // for n (operator* does this). type==1 uses root p (forward transform);
    // any other value uses pi = p^{-1} (inverse transform, NOT scaled by 1/n).
    static void NTT(vector<LL> &a,int n,int type){
        for(int i=0;i<n;i++)  if(i<r[i])  swap(a[i],a[r[i]]);
        for(int mid=1;mid<n;mid<<=1){
            LL wn=quick_pow(type==1?p:pi,(mod-1)/(mid<<1));
            for(int i=0;i<n;i+=(mid<<1)){
                LL w=1;
                for(int j=0;j<mid;j++,w=w*wn%mod){
                    LL x=a[i+j],y=w*a[i+j+mid]%mod;
                    a[i+j]=(x+y)%mod,a[i+j+mid]=(x-y+mod)%mod;
                }
            }
        }
    }
    // Coefficient-wise sum; the shorter operand is zero-padded.
    friend poly operator+(poly a, poly b){
        int n=max(a.v.size(),b.v.size());a.v.resize(n);b.v.resize(n);
        for(int i=0;i<n;i++)  a.v[i]=(a.v[i]+b.v[i])%mod;
        return a;
    }
    // Coefficient-wise difference a - b; the shorter operand is zero-padded.
    friend poly operator-(poly a, poly b){
        int n=max(a.v.size(),b.v.size());a.v.resize(n);b.v.resize(n);
        for(int i=0;i<n;i++)  a.v[i]=(a.v[i]-b.v[i]+mod)%mod;
        return a;
    }
    // Full product a*b with (size(a)-1)+(size(b)-1)+1 coefficients, via NTT.
    friend poly operator*(poly a,poly b){
        // An empty operand has no coefficients, so the product is empty; the
        // size arithmetic below would otherwise underflow.
        if(a.v.empty()||b.v.empty()){a.v.clear();return a;}
        int n=a.v.size()-1,m=b.v.size()-1;
        int limit=1,l=0;
        while(limit<=n+m)  limit<<=1,l++;
        a.v.resize(limit,0),b.v.resize(limit,0);
        // Bit-reversal permutation for limit = 2^l. r[0] is always 0; starting
        // at i=1 also avoids shifting by l-1 == -1 (UB) when l == 0.
        r[0]=0;
        for(int i=1;i<limit;i++)  r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
        NTT(a.v,limit,1),NTT(b.v,limit,1);
        for(int i=0;i<limit;i++)  a.v[i]=a.v[i]*b.v[i]%mod;
        NTT(a.v,limit,-1);
        LL inv_limit=quick_pow(limit,mod-2);  // 1/limit completes the inverse NTT
        for(int i=0;i<=n+m;i++)  a.v[i]=a.v[i]*inv_limit%mod;
        a.v.resize(n+m+1);return a;
    }
    // a reduced modulo x^n: truncated or zero-padded to exactly n coefficients.
    static poly Mod(poly a,int n){a.v.resize(n,0);return a;}
    // Power-series inverse by Newton iteration A <- A(2 - aA) mod x^{2^i}.
    // Requires a.v[0] != 0 (mod). The result has the smallest power-of-two
    // length >= a.v.size(); all its terms are correct for a viewed as a
    // polynomial (higher coefficients zero).
    static poly inv(poly a){
        if(a.v.empty())  return a;  // nothing to invert; avoid reading a.v[0]
        poly A(0),B(0);A.v[0]=quick_pow(a.v[0],mod-2),B.v[0]=2;  // B is the constant 2
        for(int i=1;(1<<(i-1))<(int)a.v.size();i++)  A=Mod(A*(B-Mod(a,1<<i)*A),1<<i);
        return A;
    }
    // Power-series square root by Newton iteration b <- (b + a/b) / 2.
    // b starts at 1, so a.v[0] must be 1. Returns exactly a.v.size() terms.
    static poly sqrt(poly a){
        int n=a.v.size();
        poly b(0),A(0),B(0);b.v[0]=1;
        for(int len=1;len<(n<<1);len<<=1){
            b.v.resize(len,0),A.v.resize(len,0),B.v.resize(len,0);
            // len can exceed n (it runs up to just below 2n): pad with zeros
            // instead of reading past the end of a.v. That unchecked read was
            // undefined behaviour and crashes (RE) on large inputs.
            for(int i=0;i<len;i++)  A.v[i]=i<n?a.v[i]:0;
            B=inv(b),B.v.resize(len,0),A=A*B;
            for(int i=0;i<len;i++)  b.v[i]=(b.v[i]+A.v[i])%mod*inv2%mod;
        }
        b.v.resize(n,0);return b;
    }
};
int main(){
    scanf("%lld%lld",&n,&m);poly a(m);
    for(int i=1,x;i<=n;i++){
        scanf("%d",&x);
        if(x<=m)  a.v[x]=1;
    }
    for(int i=0;i<=m;i++)  a.v[i]=(-4*a.v[i]%mod+mod)%mod;
    a.v[0]=1,a=a.sqrt(a),a.v[0]=(a.v[0]+1)%mod,a=a.inv(a);
    // a=2*a.inv(1+a.sqrt(1-4*a));
    for(int i=1;i<=m;i++)  printf("%lld\n",a.v[i]*2%mod);
}
2025/6/26 09:51
加载中...