找不到哪里没清空了QAQ
查看原帖
找不到哪里没清空了QAQ
114181
ChasingAft楼主2021/3/30 21:54

样例能过,但是 0pts ,全是 WA ,多项式求逆和 ln/exp 都是粘的能过得板子。

#include <bits/stdc++.h>


using namespace std;
typedef long long ll;
const int N = 4e5 + 10;
const int Mod = 998244353;

inline ll exp(ll a,ll b) {
	ll res = 1;
	while(b) {
		if(b & 1) {
			res = res * a % Mod;
		}
		a = a * a % Mod;
		b >>= 1;
	}
	return res;
}
int pre[N];
inline void NTT(ll * f,int id,int maxn) {
	for(int i = 0;i < maxn;i++) {
		if(i < pre[i]) swap(f[i],f[pre[i]]);
	}
	for(int i = 2;i <= maxn;i <<= 1) {
		int m = i / 2;
		ll Omg1,Omg;
		Omg1 = exp((id == 1 ? 3 : exp(3,Mod - 2)),(Mod - 1) / i);
		for(int j = 0;j < maxn;j += i) {
			Omg = 1;
			for(int k = 0;k < m;k++) {
				ll t = Omg * f[j + k + m] % Mod; 
				f[j + k + m] = f[j + k] - t;
				if(f[j + k + m] < 0) f[j + k + m] += Mod;
				f[j + k] = f[j + k] + t;
				if(f[j + k] > Mod) f[j + k] -= Mod;
				Omg = Omg * Omg1 % Mod;
			}
		}
	}
}
ll f[N],g[N],maxn = 1,INV,B[N],C[N],tmp1[N],tmp2[N];
inline void Inv(ll *a,ll *ans,int siz) {
	ans[0] = exp(a[0],Mod - 2);
	for(int l = 2;l <= siz;l <<= 1) {
		for(int i = 0;i < (l >> 1);i++) f[i] = ans[i];
		for(int i = 0;i < l;i++) g[i] = a[i];
		for(int i = 0;i < (l << 1);i++) pre[i] = (pre[i >> 1] >> 1) | (i & 1 ? l : 0);
		NTT(f,1,l << 1);
		for(int i = 0;i < (l << 1);i++) f[i] = f[i] * f[i] % Mod;
		NTT(g,1,l << 1);
		for(int i = 0;i < (l << 1);i++) f[i] = f[i] * g[i] % Mod;
		NTT(f,-1,l << 1);
		INV = exp((l << 1),Mod - 2);
		for(int i = 0;i < l;i++) {
			ans[i] = (ans[i] << 1) - f[i] * INV % Mod;
			if(ans[i] < 0) ans[i] += Mod;
			if(ans[i] > Mod) ans[i] -= Mod;
		}
		for(int i = 0;i < (l << 1);i++) f[i] = g[i] = 0;
	}
}
inline void Dao(ll *a,ll *ans,int siz) {
	for(int i = 0;i < siz - 1;i++) {
		ans[i] = a[i + 1] * (i + 1) % Mod;
	}
	ans[siz - 1] = 0;
}
ll inv[N];
inline void Ji(ll *a,ll *ans,int siz) {
	for(int i = 1;i < siz;i++) {
		ans[i] = a[i - 1] * inv[i] % Mod;
	}
	ans[0] = 0;
}
inline void Ln(ll *a,ll *ans,int siz,int id) {
	int n = 1;
	if(!id)
	while(n < siz) n <<= 1; 
	else n = siz;
	Dao(a,B,siz);
	Inv(a,C,n);
	for(int i = siz;i < n;i++) C[i] = 0;
	NTT(B,1,n << 1);
	NTT(C,1,n << 1);
	for(int i = 0;i < (n << 1);i++) B[i] = B[i] * C[i] % Mod;
	NTT(B,-1,n << 1);
	INV = exp(n << 1,Mod - 2);
	for(int i = 0;i < siz;i++) {
		C[i] = B[i] * INV % Mod;
	} 
	Ji(C,ans,siz);
	for(int i = 0;i < (n << 1);i++) B[i] = C[i] = 0;
}
inline void exp(ll *a,ll *ans,int siz) {
	ans[0] = 1;
	for(int l = 2;l <= siz;l <<= 1) {
		Ln(ans,tmp1,l,1); // 注意这里是 % x^l 意义下的,所以 ln 的范围要取 l 
		for(int i = 0;i < l;i++) {
			tmp1[i] = a[i] - tmp1[i];
			if(i < (l >> 1))tmp2[i] = ans[i];
			if(tmp1[i] < 0) tmp1[i] += Mod;
		}tmp1[0]++;
		for(int i = 0;i < (l << 1);i++) pre[i] = (pre[i >> 1] >> 1) | (i & 1 ? l : 0);
		if(tmp1[0] > Mod) tmp1[0] -= Mod;
		NTT(tmp1,1,l << 1);
		NTT(tmp2,1,l << 1);
		for(int i = 0;i < (l << 1);i++) tmp1[i] = tmp1[i] * tmp2[i] % Mod;
		NTT(tmp1,-1,l << 1);
		INV = exp(l << 1,Mod - 2);
		for(int i = 0;i < l;i++) ans[i] = tmp1[i] * INV % Mod;
		for(int i = 0;i < (l << 1);i++) tmp1[i] = tmp2[i] = 0; 
	}
}
inline int read() {
	int s = 0;
	char ch = getchar();
	while(ch > '9' || ch < '0') {
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9') {
		s = (s << 1) + (s << 3) + (ch ^ 48);
		s %= Mod;
		ch = getchar();
	}
	return s;
}
ll A[N],a[N],Ans[N];
int n,k;
int main() {
	inv[0] = inv[1] = 1;
    for(int i = 2;i < N / 4;i++) inv[i] = (Mod - Mod / i) * inv[Mod % i] % Mod;
	scanf("%d", &n);
	k = read();
	while(maxn <= n) maxn <<= 1;
	for(int i = 0;i < n;i++) {
		cin >> A[i];
	}
	Ln(A,a,n,0);
	for(int i = 0;i < maxn;i++) {
		a[i] = a[i] * k % Mod;
	}
	exp(a,Ans,maxn);
	for(int i = 0;i < n;i++) {
		cout << Ans[i] << ' ';
	}
	return 0;
}
2021/3/30 21:54
加载中...