ntt挂了,求助一下(
查看原帖
ntt挂了,求助一下(
264548
Tangent233楼主2021/6/12 16:53
#include<bits/stdc++.h>
using namespace std;
const int maxn=(1<<18)+10,g=3,mod=998244353;
long long S[maxn],H[maxn],C[maxn],D[maxn];
int lim=1,rev[maxn],len;
inline long long quickpow(long long n,long long k)
{
    long long ans=1;
    while(k)
    {
        if(k&1) ans=(ans*n)%mod;
        n=(n*n)%mod;
        k>>=1;
    }
    return ans;
}
inline void ntt(long long *arr,int f)
{
    for(int i=0;i<lim;i++)
        if(i<rev[i]) swap(arr[i],arr[rev[i]]);
    for(int i=1;i<lim;i<<=1)
    {
        long long gn=quickpow(g,(mod-1)/(i<<1));
        if(f==-1) gn=quickpow(gn,mod-2);
        for(int j=0;j<lim;j+=(i<<1))
        {
            long long w=1;
            for(int k=0;k<i;k++)
            {
                long long x=arr[j+k],y=w*arr[j+k+i]%mod;
                arr[j+k]=(x+y)%mod;
                arr[j+k+i]=((x-y)%mod+mod)%mod;
                w=(w*gn)%mod;
            }
        }
    }
    if(f==-1)
    {
        int invn=quickpow(lim,mod-2);
        for(int i=0;i<=lim;i++) arr[i]=(arr[i]*invn)%mod;
    }
}
inline int read()
{
	int x=0,f=1;char ch=getchar();
	while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
	while (ch>='0'&&ch<='9'){x=x*10+ch-48;ch=getchar();}
	return x*f;
}
const int aaaa=5e4+10;
bool prime[aaaa];
void aishishai(int n)
{
    memset(prime,1,sizeof(prime));
    for(int i=2;i<=n;i++)
        if(prime[i])
        {
            int tmp=i+i;
            while(tmp<=n)
            {
                prime[tmp]=0;
                tmp+=i;
            }
        }
}
string ax;
int main()
{
    aishishai(aaaa-10);
    int a,b,c;
    while(1)
    {
        memset(S,0,sizeof(S));
        memset(H,0,sizeof(H));
        memset(C,0,sizeof(C));
        memset(D,0,sizeof(D));
        memset(rev,0,sizeof(rev));
        lim=1,len=0;
        a=read(),b=read(),c=read();
        if(a==0&&b==0&&c==0) return 0;
        for(int i=0;i<=b;i++)
            if(!prime[i]) S[i]=H[i]=C[i]=D[i]=1;
        for(lim=1;lim<=b*4;lim<<=1) len++;
        for(int i=0;i<lim;i++)
            rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
        for(int i=1;i<=c;i++)
        {
            cin>>ax;
            int oo=ax[0]-'0';
            char pp=ax[1];
            if(pp=='S') S[oo]=0;
            if(pp=='H') H[oo]=0;
            if(pp=='C') C[oo]=0;
            if(pp=='D') D[oo]=0;
        }
        ntt(S,1),ntt(H,1),ntt(C,1),ntt(D,1);
        for(int i=0;i<=lim;i++)
            S[i]=S[i]*H[i]*C[i]*D[i]%mod;
        ntt(S,-1);
        for(int i=a;i<=b;i++)
            cout<<S[i]<<endl;
    }
    return 0;
}

这是为什么啊

2021/6/12 16:53
加载中...