m>n的情况WA了,求条
查看原帖
m>n的情况WA了,求条
1010650
Expert_Dreamer楼主2025/8/3 16:16
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define N 2000001
const int mod=998244353;
int lim,rev[N],Gi=3,m,C[N],a[N],b[N],d[N],e[N];
int ksm(int a,int k){
    int res=1;
    while(k){
        if(k&1) res=res*a%mod;
        a=a*a%mod;
        k>>=1;
    }
    return res;
}
void ntt(int *A,int typ){
    for(int i=0;i<lim;i++) if(i<rev[i]) swap(A[i],A[rev[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        int Wn=ksm(Gi,(mod-1)/(mid<<1));
        for(int j=0;j<lim;j+=mid<<1){
            int W=1;
            for(int k=0;k<mid;k++,W=W*Wn%mod){
                int x=A[j+k],y=W*A[j+k+mid]%mod;
                A[j+k]=(x+y)%mod;
                A[j+k+mid]=(x-y+mod)%mod;
            }
        }
    }
    if(typ==1) return;
    int inv=ksm(lim,mod-2);
    reverse(A+1,A+lim);
    for(int i=0;i<lim;i++) A[i]=A[i]*inv%mod;
}
void GetInv(int *f,int *g,int len)
{
    if(len==1)
    {
        g[0]=ksm(f[0],mod-2);
        return;
    }
    GetInv(f,g,len+1>>1);
    lim=1;
    m=0;
    while(lim<(len<<1))
    {
        lim<<=1;
        m++;
    }
    for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(m-1));
    for(int i=0;i<len;i++) C[i]=f[i];
    for(int i=len;i<lim;i++) C[i]=0;
    ntt(C,1);
    ntt(g,1);
    for(int i=0;i<lim;i++) g[i]=((2ll-g[i]*C[i]%mod)+mod)%mod*g[i]%mod;
    ntt(g,-1);
    for(int i=len;i<lim;i++) g[i]=0;
}
void GetDev(int *f,int *g,int len)
{
    for(int i=1;i<len;i++) g[i-1]=i*f[i]%mod;
    g[len-1]=0;
}
void GetInvDev(int *f,int *g,int len)
{
    for(int i=1;i<len;i++) g[i]=f[i-1]*ksm(i,mod-2)%mod;
    g[0]=0;
}
void GetLn(int *f,int *g,int len)
{
    memset(a,0,sizeof a);
    memset(b,0,sizeof b);
    GetDev(f,a,len);
    GetInv(f,b,len);
    lim=1;
    m=0;
    while(lim<(len<<1))
    {
        lim<<=1;
        m++;
    }
    for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(m-1));
    ntt(a,1);
    ntt(b,1);
    for(int i=0;i<lim;i++) a[i]=a[i]*b[i]%mod;
    ntt(a,-1);
    GetInvDev(a,g,len);
}
void GetExp(int *f,int *g,int len)
{
    if(len==1)
    {
        g[0]=1;
        return;
    }
    GetExp(f,g,len+1>>1);
    lim=1;
    m=0;
    while(lim<(len<<1))
    {
        lim<<=1;
        m++;
    }
    for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(m-1));
    for(int i=0;i<(len<<1);i++) d[i]=e[i]=0;
    GetLn(g,d,len);
    for(int i=0;i<len;i++) e[i]=f[i];
    ntt(g,1);
    ntt(d,1);
    ntt(e,1);
    for(int i=0;i<lim;i++) g[i]=(1ll-d[i]+e[i]+mod)*g[i]%mod;
    ntt(g,-1);
    for(int i=len;i<lim;i++) g[i]=0;
}
void mul(int *a,int *b,int n,int m){
    int L=0;lim=1;
	while(lim<=n+m) lim<<=1,L++;
	for(int i=0;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));	
	ntt(a,1);ntt(b,1);	
	for(int i=0;i<lim;i++) a[i]=(a[i]*b[i])%mod;
	ntt(a,-1);
}
int f[N],g[N],fac[N],inv[N],A[N];
void stirling(int n){
    f[0]=1;
    for(int i=1;i<=n;i++)f[i]=f[i-1]*i%mod;
    f[n]=ksm(f[n],mod-2);
    for(int i=n;i>=1;i--)f[i-1]=f[i]*i%mod;
    for(int i=0;i<=n;i++){
        if(i%2)g[i]=mod-f[i];
        else g[i]=f[i];
        f[i]=f[i]*ksm(i,n)%mod;
    }
    mul(f,g,n,n);
}
int n,s,l,aa,bb,c,B[N],ifac[N];
int cc(int n,int m) { 
    return n<m?0:fac[n]*ifac[m]%mod*ifac[n-m]%mod; 
}
inline void Add(int &x, int y) { x += y, x -= x >= mod ? mod : 0; }
signed main(){
    cin>>n>>m;
    stirling(max(n,m));
    int nn=n,mm=m;
    fac[0] = ifac[0] = fac[1] = ifac[1] = inv[1] = 1;
	for(int i = 2; i <= n+m; i++) {
		fac[i] = 1LL * fac[i - 1] * i % mod;
		inv[i] = 1LL * (mod - mod / i) * inv[mod % i] % mod;
		ifac[i] = 1LL * ifac[i - 1] * inv[i] % mod;
	}
    memset(A, 0, sizeof(A));
	for(int i = 1; i <= m; i++)
		for(int j = i; j <= n; j += i)
			Add(A[j], mod - inv[j / i]);
	memset(B, 0, sizeof(B)), GetExp(A, B, n + 1);
	memset(A, 0, sizeof(A)), GetInv(B, A, n + 1);
	n=nn,m=mm;
    //.1
    cout<<ksm(m,n)<<endl;
    //
    
    //.2
    int ans=1;
    for(int i=0;i<n;i++) ans=ans*(m-i)%mod;
    cout<<ans<<endl;
    //
    
    //.3
    ans=0;
    for(int i=0;i<=m;i++){
    	if(i&1){
    		int cnt=cc(m,i)*ksm(m-i,n)%mod;
    		ans=(ans-cnt+mod)%mod;
		}else{
			int cnt=cc(m,i)*ksm(m-i,n)%mod;
    		ans=(ans+cnt+mod)%mod;
		}
	}
	cout<<ans<<endl;
    //
    
    //.4
    ans=0;
    for(int i=1;i<=m;i++) ans=(ans+f[i])%mod;
	cout<<ans<<endl; 
    //
    
    //.5
    cout<<(n<=m?1:0)<<endl;
    //
    
    //.6
    cout<<f[m]<<endl;
    //
    
    //.7
    cout<<cc(n+m-1,m-1)<<endl;
    //
    
    //.8
    cout<<cc(m,n)<<endl;
    //
    
    //.9
    cout<<cc(n-1,m-1)<<endl;
    //
    
    //.10
    cout<<A[n]<<endl;
    //
    
    //.11
    cout<<(n<=m?1:0)<<endl;
    //
    
    //.12
    cout<<A[n-m]<<endl;
    //
}
2025/8/3 16:16
加载中...