如何优化
查看原帖
如何优化
49093
_sys楼主2020/9/3 18:33

做法是子集卷积 + 快速幂,在 Codeforces 上跑了 4.13 s,不知道如何优化才能卡进 3 s 的时限。

#include <bits/stdc++.h>
using namespace std;

// Capacity of every mask-indexed array: 2^19 slots plus a little slack.
const int Maxn = (1 << 19) | 5;

// Values read from standard input by main(); p is the modulus used throughout.
int n;
int k;
int p;
// Largest mask index in use; main() sets it to a power of two minus one.
int maxi;
// ct[m] is the number of set bits of m; f[m] is the running result per mask.
int ct[Maxn];
int f[Maxn];
// Work layers for the subset convolution; the first index is a popcount (rank).
int tmp1[20][Maxn];
int tmp2[20][Maxn];
int tmp3[20][Maxn];

// fac[i] = i! mod p; inv[i] = fac[i]^(p-2) chain from init(), i.e. the
// modular inverse of i! provided p is prime.
long long fac[Maxn];
long long inv[Maxn];
// Binary exponentiation: returns x^y reduced modulo the global p.
// The exponent is consumed bit by bit from the least significant end;
// y is expected to be non-negative.
long long fast_pow(long long x, long long y)
{
	long long result = 1;
	long long base = x;
	for (; y != 0; y >>= 1)
	{
		// Fold in the current power of the base when this exponent bit is set.
		if (y & 1)
			result = result * base % p;
		base = base * base % p;
	}
	return result;
}
// Fills fac[0..maxi] with factorials mod p and inv[0..maxi] with the
// matching inverse factorials (one modular exponentiation, then a
// downward recurrence).
void init(void)
{
	fac[0] = 1;
	for (int i = 0; i < maxi; ++i)
		fac[i + 1] = fac[i] * (i + 1) % p;

	// Invert only the largest factorial via fast_pow(., p - 2) ...
	inv[maxi] = fast_pow(fac[maxi], p - 2);
	// ... and derive the rest from inv[i - 1] = inv[i] * i.
	for (int i = maxi; i > 0; --i)
		inv[i - 1] = inv[i] * i % p;
}
// Returns the smallest power of two that is not less than x
// (1 for any x <= 1).
int lower(int x)
{
	int pw = 1;
	while (pw < x)
		pw *= 2;
	return pw;
}
// In-place butterfly over now[0..len): for every block, each element of the
// lower half is added to (type == false) or subtracted from (type == true)
// its partner in the upper half, modulo the global p.
// len is expected to be a power of two and entries to lie in [0, p).
void FWT(int now[], int len, bool type = false)
{
	for (int half = 1; half < len; half *= 2)
	{
		const int step = half * 2;
		for (int base = 0; base < len; base += step)
		{
			for (int lo = base; lo < base + half; lo++)
			{
				const int hi = lo + half;
				// Subtraction is done by adding the negation p - now[lo].
				const int delta = type ? p - now[lo] : now[lo];
				now[hi] = (now[hi] + delta) % p;
			}
		}
	}
}
// Replaces f by the subset convolution of f with the series whose
// transformed, popcount-ranked layers are stored in tmp2.
// Uses tmp1 (ranked + transformed f) and tmp3 (ranked product) as scratch.
void multi(void)
{
	// Explicit byte count; the former sizeof(int[maxi + 1]) named a
	// variable-length array type, which is not standard C++.
	const size_t bytes = sizeof(int) * (maxi + 1);
	for (int j = 0; j <= 19; j++)
	{
		memset(tmp1[j], 0, bytes);
		memset(tmp3[j], 0, bytes);
	}
	// Spread f into layers by popcount, then transform each layer.
	for (int j = 0; j <= maxi; j++)
		tmp1[ct[j]][j] = f[j];
	for (int j = 0; j <= 19; j++)
		FWT(tmp1[j], maxi + 1);
	// Pointwise product, convolving over the rank index (ranks above 19 are dropped).
	for (int q = 0; q <= 19; q++)
		for (int l = 0; q + l <= 19; l++)
			for (int s = 0; s <= maxi; s++)
				tmp3[q + l][s] = (tmp3[q + l][s] + tmp1[q][s] * (long long) tmp2[l][s]) % p;
	// Inverse transform, then keep only the entry whose rank equals popcount(j).
	for (int j = 0; j <= 19; j++)
		FWT(tmp3[j], maxi + 1, true);
	for (int j = 0; j <= maxi; j++)
		f[j] = tmp3[ct[j]][j];
}
// Squares the series held in tmp2 (transformed, popcount-ranked layers) under
// subset convolution and stores the result back into tmp2 in the same form.
void multi2(void)
{
	static int tmp[20][Maxn];
	// Explicit byte count; the former sizeof(int[maxi + 1]) named a
	// variable-length array type, which is not standard C++.
	const size_t bytes = sizeof(int) * (maxi + 1);
	for (int i = 0; i <= 19; i++)
		memset(tmp[i], 0, bytes);
	// Pointwise product of tmp2 with itself, convolving over the rank index.
	for (int q = 0; q <= 19; q++)
		for (int l = 0; q + l <= 19; l++)
			for (int s = 0; s <= maxi; s++)
				tmp[q + l][s] = (tmp[q + l][s] + tmp2[q][s] * (long long) tmp2[l][s]) % p;
	// Back to the untransformed domain.
	for (int i = 0; i <= 19; i++)
		FWT(tmp[i], maxi + 1, true);
	// Rebuild tmp2 from only the entries whose rank matches the popcount,
	// discarding everything else, then transform again for the next product.
	for (int i = 0; i <= 19; i++)
		memset(tmp2[i], 0, bytes);
	for (int i = 0; i <= maxi; i++)
		tmp2[ct[i]][i] = tmp[ct[i]][i];
	for (int i = 0; i <= 19; i++)
		FWT(tmp2[i], maxi + 1);
}
// Reads n, k and the modulus p, then prints a single residue modulo p.
// Odd n: prints k^n. Even n: raises the series with coefficients inv[i]
// (inverse factorials) to the k-th power under subset convolution and prints
// k^n - fac[n] * f[n].
int main()
{
	// Stop on malformed input instead of running with a zero modulus.
	if (scanf("%d%d%d", &n, &k, &p) != 3)
		return 1;
	if (n & 1)
	{
		printf("%lld", fast_pow(k, n));
		return 0;
	}
	// f starts as the identity of the subset-convolution product.
	f[0] = 1;
	// All-ones mask that covers every index 0..n.
	maxi = lower(n + 1) - 1;
	init();
	for (int i = 0; i <= maxi; i++)
		ct[i] = __builtin_popcount(i);
	// Base series: inverse factorials, placed in their popcount layer and transformed.
	for (int i = 0; i <= maxi; i++)
		tmp2[ct[i]][i] = inv[i];
	for (int i = 0; i <= 19; i++)
		FWT(tmp2[i], maxi + 1);
	// k is consumed by the loop below; keep the original for the final formula.
	int res_k = k;
	// Square-and-multiply: f accumulates base^k. The last squaring is skipped
	// because its result would never be used.
	while (k)
	{
		if (k & 1) multi();
		k >>= 1;
		if (k) multi2();
	}
	printf("%lld", (fast_pow(res_k, n) - fac[n] * f[n] % p + p) % p);
	return 0;
}
2020/9/3 18:33
加载中...