过不了样例的萌新求助
查看原帖
过不了样例的萌新求助
337894
expane楼主2020/6/23 19:03

代码在 $n>4$ 时结果鬼畜。。。

输入:
4
0 6 0 0
输出:
1 6 18 36
输入:
5
0 6 0 0 0
输出:
1 356800975 655027127 222337742 367223101

保证ln的板子能过洛谷模板

#include <cstdio>
#include <ctime>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cassert>
using namespace std;

// All arithmetic is modulo the NTT-friendly prime M = 998244353 = 119 * 2^23 + 1;
// ROOT = 3 is the primitive root ntt() raises to obtain roots of unity.
#define M 998244353
#define ROOT 3
#define ll long long

// N is the capacity of every coefficient / scratch buffer. Local builds
// (ONLINE_JUDGE not defined) use a tiny value to spare the machine; the judge
// build uses the full size.
#ifndef ONLINE_JUDGE
#define N 800
#else
#define N 800100
#endif

// Returns x^y mod M by square-and-multiply.
// NOTE: x is not reduced first, so callers should pass 0 <= x < M
// (x * x must not overflow long long); y must be non-negative.
int power(ll x, ll y)
{
	int result = 1;
	while (y) {
		if (y & 1)
			result = result * x % M;
		x = x * x % M;
		y >>= 1;
	}
	return result;
}
// Returns the smallest k such that 2^k >= n (0 when n <= 1).
int limit(int n)
{
	int k = 0;
	while ((1 << k) < n)
		++k;
	return k;
}

int rev_buf[N << 1], *rp[30], *rev_ptr = rev_buf;
void ntt(int *a, int lim, int flag)
{
	if (!rp[lim]) { // 菜鸡的内置init()写法。。。
		int *rev = rp[lim] = rev_buf;
		int n = 1 << lim, tmp = n >> 1;
		for (int i = 0; i < n; i += 2) {
			rev[i] = rev[i >> 1] >> 1;
			rev[i | 1] = rev[i >> 1] >> 1 | tmp;
		}
	}
	int *rev = rp[lim], n = 1 << lim;
	for (int i = 0; i < n; i++)
		if (i < rev[i])
			swap(a[i], a[rev[i]]);
	for (int m = 1, k = 2; m < n; m <<= 1, k <<= 1) {
		int omega = power(ROOT, (M - 1) + flag * (M - 1) / k);
		for (int i = 0; i < n; i += k) {
			ll omg = 1;
			for (int j = i; j < i + m; j++) {
				int t = omg * a[j + m] % M;
				a[j + m] = (M + a[j] - t) % M;
				a[j] = (a[j] + t) % M;
				omg = omg * omega % M;
			}
		}
	}
	if (flag == -1) {
		ll inv = power(n, M - 2);
		for (int i = 0; i < n; i++)
			a[i] = a[i] * inv % M;
	}
}

// Zeroes the half-open range a[l, r). An empty or reversed range (r <= l) is a
// no-op rather than a negative byte count converted to a huge size_t.
inline void clear(int *a, int l, int r)
{
	if (r > l)
		memset(a + l, 0, sizeof(int) * (r - l));
}

// A polynomial stored as a fixed-capacity coefficient array.
struct Poly {
	int f[N]; // f[i] is the coefficient of x^i
	int n;    // number of terms in use
	// Implicit decay to the coefficient array, so a Poly can be indexed and
	// passed wherever an int * is expected.
	operator int * () { return f; }
};

int tmp1[N], tmp2[N], tmp3[N]; // scratch buffers: tmp2 is used by mulby(), tmp3 by inv()

// f *= g, keeping the full product: afterwards f.n == old f.n + g.n - 1.
// Zeroes f's buffer from its old length up to the transform length and
// overwrites tmp2. g is copied into tmp2 before f is touched, so f and g may be
// the same object.
void mulby(Poly &f, Poly &g)
{
	// A product involving an empty polynomial is empty; without this check
	// prod_n would go negative and f.n would end up as -1.
	if (f.n <= 0 || g.n <= 0) {
		f.n = 0;
		return;
	}
	int prod_n = f.n + g.n - 1;
	// Transform length 2^lim >= prod_n + 1 > prod_n, so the cyclic convolution
	// does not wrap around and yields the exact product.
	int lim = limit(prod_n + 1), n = 1 << lim;
	memcpy(tmp2, g, g.n * sizeof(int));
	clear(f, f.n, n);
	clear(tmp2, g.n, n);
	ntt(f, lim, 1);
	ntt(tmp2, lim, 1);
	for (int i = 0; i < n; i++)
		f[i] = (ll)f[i] * tmp2[i] % M;
	ntt(f, lim, -1);
	f.n = prod_n;
}
// f = g^{-1} mod x^{g.n}, by Newton iteration F <- F * (2 - F * G), which
// doubles the number of correct terms each round.
// Requires g[0] to be invertible modulo M.
// Side effects: zeroes g's buffer on [g.n, 2 * 2^limit(g.n)) and overwrites tmp3.
void inv(Poly &f, Poly &g)
{
	assert(g[0] % M != 0); // otherwise no inverse exists
	int lim = limit(g.n);
	clear(g, g.n, 2 << lim);
	clear(f, 0, 2 << lim);
	f[0] = power(g[0], M - 2);
	// Round with n target terms (l = limit(n)): f holds n/2 correct terms and g
	// is truncated to n terms. Both are transformed at length 2n = 2^(l+1), which
	// is enough that f * (2 - f * g), of degree <= 2n - 3, does not wrap around.
	for (int n = 2, l = 1; l <= lim; n <<= 1, l++) {
		memcpy(tmp3, g, n * sizeof(int));
		clear(tmp3, n, n << 1);
		ntt(f, l + 1, 1);
		ntt(tmp3, l + 1, 1);
		for (int i = 0; i < (n << 1); i++)
			f[i] = (ll)f[i] * (M + 2 - (ll)f[i] * tmp3[i] % M) % M;
		ntt(f, l + 1, -1);
		// Keep only the n terms that are now correct.
		clear(f, n, n << 1);
	}
	f.n = g.n;
}

Poly tmpln1, tmpln2; // scratch for ln(): tmpln1 = g' (then g' / g), tmpln2 = 1 / g

// f = ln(g) mod x^{g.n}, computed as the integral of g' * g^{-1}.
// Requires g[0] == 1; the constant term of the result is 0.
// Note: inv() zeroes g's buffer past g.n as a side effect.
void ln(Poly &f, Poly &g)
{
	assert(g[0] == 1);
	f.n = g.n;
	inv(tmpln2, g);

	// Derivative: the x^(k-1) coefficient of g' is k * g[k].
	int k = 1;
	while (k < g.n) {
		tmpln1[k - 1] = (ll)g[k] * k % M;
		k++;
	}
	tmpln1.n = g.n - 1;
	mulby(tmpln1, tmpln2);

	// Integral: the x^k coefficient is (x^(k-1) coefficient of g'/g) / k.
	for (k = 1; k < f.n; k++) {
		ll inv_k = power(k, M - 2);
		f[k] = inv_k * tmpln1[k - 1] % M;
	}
	f[0] = 0;
}

Poly tmpexp; // scratch for exp(): holds 1 - ln(f) + g

// f = exp(g) mod x^{g.n}. Requires g[0] == 0.
// Newton iteration: if f agrees with exp(g) on its first m terms, then
// f * (1 - ln(f) + g) agrees on the first 2m terms.
void exp(Poly &f, Poly &g)
{
	assert(g[0] == 0);
	int lim = limit(g.n);
	f[0] = 1;
	// m: terms of f already correct, n = 2m: terms targeted this round, l = limit(n)
	for (int m = 1, n = 2, l = 1; l <= lim; m <<= 1, n <<= 1, l++) {
		clear(f, m, n); // drop the higher terms left over from the last product
		f.n = n;
		ln(tmpexp, f);
		for (int i = 0; i < n; i++) {
			// g only defines its first g.n terms; treat the rest as 0 instead of
			// trusting whatever the buffer holds past g.n.
			int gi = i < g.n ? g[i] : 0;
			tmpexp[i] = (M + gi - tmpexp[i]) % M;
		}
		tmpexp[0] = (1 + tmpexp[0]) % M;
		mulby(f, tmpexp);
	}
	f.n = g.n;
}

Poly f, g;
int main()
{
	scanf("%d", &f.n);
	for (int i = 0; i < f.n; i++)
		scanf("%d", f + i);
	exp(g, f);
	for (int i = 0; i < g.n; i++)
		printf("%d ", g[i]);
}
2020/6/23 19:03
加载中...