modint 为何会增大代码常数?能否减小其常数使之快于 long long?
  • 板块学术版
  • 楼主SAMSHAWCRAFT
  • 当前回复8
  • 已保存回复8
  • 发布时间2022/1/18 20:50
  • 上次更新2023/10/28 12:01:09
查看原帖
modint 为何会增大代码常数?能否减小其常数使之快于 long long?
496840
SAMSHAWCRAFT楼主2022/1/18 20:50

以前写多项式题都是不管三七二十一直接开 long long,就导致我的多项式板子常数不小,今天自己实现了一个 modint 交到多项式乘法逆板子上,发现比我用 long long 做的慢了 3 倍。又交了一发 -O2modintlong long 快了约 20%。

我写的 modint 开不开 -O2 会有很大速度差距,但是我用 long long 的话开不开 -O2 速度差距不大。

请问有没有卡常大师或者明白原理的大佬能解释一下为什么会有这种现象,以及能否优化 modint 直至不开 -O2 也比long long 快?

为了方便,下面是我的 modint 实现以及用 modint 的 NTT:

#define qaq inline
using ll=long long;
template<int mod=998244353>struct fp{
  int v; static int get_mod(){ return mod; }
  int inv()const{
    int tmp,a=v,b=mod,x=1,y=0;
    while(b) tmp=a/b,a-=tmp*b,std::swap(a,b),x-=tmp*y,std::swap(x,y);
    if(x<0) x+=mod;
    return x;
  }
  qaq fp(ll x=0){ init(x%mod+mod); }
  qaq fp& init(int x){ v=(x<mod?x:x-mod); return *this; }
  qaq fp operator-()const{ return fp()-*this; }
  fp pow(ll t)const{ fp res=1,b=*this; while(t){if(t&1)res*=b;b*=b;t>>=1;} return res; }
  fp unitrt(int l)const{ return pow((mod-1)/l); }
  qaq fp& operator+=(const fp& x){ return init(v+x.v); }
  qaq fp& operator-=(const fp& x){ return init(v-x.v+mod); }
  qaq fp& operator*=(const fp& x){ v=1LL*v*x.v%mod; return *this; }
  qaq fp& operator/=(const fp& x){ v=1LL*v*x.inv()%mod; return *this; }
  qaq fp operator+(const fp& x)const{ return fp(*this)+=x; }
  qaq fp operator-(const fp& x)const{ return fp(*this)-=x; }
  qaq fp operator*(const fp& x)const{ return fp(*this)*=x; }
  qaq fp operator/(const fp& x)const{ return fp(*this)/=x; }
  bool operator==(const fp& x){ return v==x.v; }
  bool operator!=(const fp& x){ return v!=x.v; }
};
using Fp=fp<>;
const Fp g=Fp(3),invg=Fp(g.inv());
void NTT(int limit,Fp *arr,int sign){
  for(int cx=0;cx<limit;++cx)
    if(cx<revid[cx]) std::swap(arr[cx],arr[revid[cx]]);
  for(int l=2;l<=limit;l<<=1){
    Fp Wn=((sign==1)?g:invg).unitrt(l);
    for(int cx=0;cx<limit;cx+=l){
      Fp w=1;
      for(int cy=0;cy<(l>>1);++cy,w*=Wn){
        Fp tmp1=arr[cx+cy],tmp2=arr[cx+cy+(l>>1)]*w;
        arr[cx+cy]=tmp1+tmp2,arr[cx+cy+(l>>1)]=tmp1-tmp2;
      }
    }
  }
  if(sign==-1){
    int invlim=Fp(limit).inv();
    for(int cx=0;cx<limit;++cx)
      arr[cx]*=invlim;
  }
}
2022/1/18 20:50
加载中...