屑求助TLE35分
查看原帖
屑求助TLE35分
157857
ImmortalWatcher楼主2021/10/31 09:52
#include<cstdio>
#include<cmath>
#include<vector>
#include<algorithm>
#define ll long long
#define double long double
#define Mn 250010
using namespace std;
const double Pi=acos(-1);
const double eps=1e-8;
const int G=3;
int iG,in; 
int n,m,mo=998244353,mod;
struct CP
{
	double x,y;
	CP (double xx=0,double yy=0){x=xx,y=yy;}
	CP operator + (const CP o) const {return CP(x+o.x,y+o.y);}
	CP operator - (const CP o) const {return CP(x-o.x,y-o.y);}
	CP operator * (const CP o) const 
	{return CP(x*o.x-y*o.y,x*o.y+y*o.x);}
};
int tr[Mn];
int _inv[Mn];

int ksm(int x,int y,int p)
{
	long long s=1;
	while (y)
	{
		if (y&1) s=s*x%p;
		x=(long long)x*x%p;
		y>>=1;
	}
	return s;
}

struct poly
{
	vector<int> a;
    poly(int x=0)
    {
        if (x) a.push_back(x);
    }
    int size()
    {
        return a.size();
    }
    void resize(int len)
    {
        a.resize(len);
    }
    void shrink()
    {
        for (; !a.empty() &&!a.back(); a.pop_back());
    }
    poly rever()
    {
        poly ret;
        ret.resize(size());
        for (int i=0;i<ret.size();i++)
            ret.a[i]=a[ret.size()-i-1];
        return ret;
    }
	void print()
	{
		for (int i=0;i<size();i++)
			printf("%d",a[i]);
		puts("");
	}
	int operator [] (const int &x) const
	{
		if (x<0||x>=(int)a.size()) return 0;
		return a[x];
	}
	int calculate(int x)
	{
		int n=a.size(),ret=a[n-1];
		for (int i=n-2;i>=0;i--)
			ret=(1ll*ret*x%mo+a[i])%mo;
		return ret;
	}
	poly inte()
	{
		poly ret;
		ret.resize(a.size()+1);
		ret.a[0]=0;
		for (int i=0;i<size();i++)
			ret.a[i+1]=1ll*a[i]*_inv[i+1]%mo;
		return ret;
	}
	poly diff()
	{
		if (a.empty()) return poly();
		poly ret;
		ret.resize(a.size()-1);
		for (int i=0;i<ret.size();i++)
			ret.a[i]=1ll*(i+1)*a[i+1]%mo;
		return ret;
	}
	inline int fastmod(const int &x) {return x>=mo?x-mo:x;}
	void NTT(int flag)
	{
		int n=size();
		static const int G=3,iG=ksm(G,mo-2,mo);
		for (int i=0;i<n;i++)
    		if (i<tr[i]) swap(a[i],a[tr[i]]);
		for(int p=2;p<=n;p<<=1)
		{
			int len=p>>1,tG=ksm(flag?G:iG,(mo-1)/p,mo);
			for(int k=0;k<n;k+=p)
			{
				int buf=1;
				for(int l=k,tt;l<k+len;l++)
				{
					tt=1ll*buf*a[len+l]%mo;
			        a[len+l]=a[l]-tt+mo;
			        if (a[len+l]>=mo) a[len+l]-=mo;
			        a[l]=a[l]+tt;
			        if (a[l]>=mo) a[l]-=mo;
			        buf=1ll*buf*tG%mo;
				}
			}
		}
		if (!flag)
		{
			int in=ksm(n,mo-2,mo);
			for (int i=0;i<n;i++)
				a[i]=1ll*a[i]*in%mo;
		}
	}
	friend inline bool operator == (poly f,poly g)
	{
		f.shrink(),g.shrink();
		if (f.size()!=g.size()) return false;
		for (int i=0;i<f.size();i++)
			if (f[i]!=g[i]) return false;
		return true;
	}
	friend inline bool operator != (poly f,poly g)
	{
		return !(f==g);
	}
	friend inline poly operator + (poly f,poly g)
	{
		f.resize(max(f.size(),g.size()));
		for (int i=0;i<f.size();i++)
			f.a[i]=(f[i]+g[i])%mo;
		return f;
	}
	friend inline poly operator - (poly f,poly g)
	{
		f.resize(max(f.size(),g.size()));
		for (int i=0;i<f.size();i++)
			f.a[i]=(f[i]-g[i]+mo)%mo;
		return f;
	}
	friend inline poly operator * (poly f,poly g)
	{
		if (!f.size()||!g.size()) return poly();
		int n=1;
		for (;n<f.size()+g.size()-1;n<<=1);
		for (int i=0;i<n;i++)
    		tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
		f.resize(n);g.resize(n);
		f.NTT(1),g.NTT(1);
		for (int i=0;i<n;i++)
			f.a[i]=1ll*f[i]*g[i]%mo;
		f.NTT(0),f.shrink();
		return f;
	}
	friend inline poly operator / (poly f,poly g)
	{
		if (f.size()<g.size()) return poly();
		return (f.rever().modx(f.size()-g.size()+1)*g.rever().inv(f.size()-g.size()+1)).modx(f.size()-g.size()+1).rever();
	}
	friend inline poly operator % (poly f,poly g)
	{
		if (f.size()<g.size()) return f;
		poly q=f/g;
		return (f.modx(g.size()-1)-g.modx(g.size()-1)*q.modx(g.size()-1)).modx(g.size()-1);
	}
	poly modx(int n)
	{
		poly g=*this;
		g.resize(min(g.size(),n));
		return g;
	}
	poly inv(int n)
	{
		poly f(ksm(a[0],mo-2,mo));
 		for (int len=1;len<n;len<<=1)
 			f=(f*(poly(2)-f*(*this).modx(len<<1))).modx(len<<1);
		return f.modx(n);
	}
	poly sq(int n)
	{
		poly f(sqrt(a[0]));
		for (int len=1;len<n;len<<=1)
			f=((*this+f*f)*((poly(2)*f).inv(len<<1))).modx(len<<1);
		return f.modx(n);
	}
	poly ln(int n)
	{
		return (diff()*inv(n)).inte().modx(n);
	}
	poly exp(int n)
	{
		poly f(1);
		for (int len=1;len<n;len<<=1)
			f=(f*(poly(1)-f.ln(len<<1)+(*this).modx(len<<1))).modx(len<<1);
		return f.modx(n);
	}
}f,g;

int read()
{
    int s=0;
	char ch;
	while (ch<'0'||ch>'9') ch=getchar();
	while (ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
	return s;
}

void init()
{
    _inv[1]=1;
    for(int i=2;i<=Mn;i++)
        _inv[i]=1ll*(mo-mo/i)*_inv[mo%i]%mo;
}

signed main()
{
	init();
	scanf("%d",&n);
	f.resize(n);g.resize(n);
	for(int i=0;i<n;i++) scanf("%d",&f.a[i]);
 	g=f.sq(n);
	for (int i=0;i<n;i++)
		printf("%d ",g[i]);
  	return 0;
}

实测好像是 ntt 跑的巨慢,也有可能是其他的问题。

求解。

2021/10/31 09:52
加载中...