子集卷积 + 快速幂,cf 上跑了 4.13s。不知道如何卡进 3s。
#include <bits/stdc++.h>
using namespace std;
// Capacity of every index-addressed table: 2^19 entries plus a small margin
// (1 << 19 | 5 evaluates to 524293).
const int Maxn = 1 << 19 | 5;
// n, k, p    : the three integers read in main(); p is the modulus used everywhere.
// maxi       : largest table index in use; main() sets it to (a power of two) - 1.
// ct[i]      : number of set bits of i, filled in main().
// f          : running result of the repeated subset convolution (main() starts it as f[0] = 1).
// tmp1, tmp3 : scratch tables of multi(), indexed as [rank][index].
// tmp2       : [rank][index] table of the factor being exponentiated, kept in
//              FWT-transformed form between calls (built in main(), squared by multi2()).
int n, k, p, maxi, ct[Maxn], f[Maxn], tmp1[20][Maxn], tmp2[20][Maxn], tmp3[20][Maxn];
// fac[i] = i! mod p; inv[i] = fac[i]^(p-2) mod p, i.e. the inverse factorial
// when p is prime (both filled by init()).
long long fac[Maxn], inv[Maxn];
// Computes x^y modulo the global modulus p by binary exponentiation.
//   x : base; it is first reduced into [0, p), so negative or very large
//       bases neither produce negative results nor overflow the squaring step
//   y : exponent, expected to be non-negative
// Returns the power reduced modulo p (1 when y == 0).
long long fast_pow(long long x, long long y)
{
    long long base = x % p;
    if (base < 0)
        base += p;
    long long result = 1;
    for (; y; y >>= 1)
    {
        // Multiply in the current square whenever the matching exponent bit is set.
        if (y & 1)
            result = result * base % p;
        base = base * base % p;
    }
    return result;
}
// Fills fac[0..maxi] with factorials modulo p and inv[0..maxi] with the
// matching inverse factorials. The top inverse comes from one modular
// exponentiation (exponent p - 2); the rest follow by walking downwards
// with inv[v - 1] = inv[v] * v.
void init(void)
{
    long long running = 1;
    fac[0] = running;
    for (int v = 1; v <= maxi; ++v)
    {
        running = running * v % p;
        fac[v] = running;
    }

    long long inverse = fast_pow(fac[maxi], p - 2);
    inv[maxi] = inverse;
    for (int v = maxi; v > 0; --v)
    {
        inverse = inverse * v % p;
        inv[v - 1] = inverse;
    }
}
// Returns the smallest power of two that is not less than x
// (1 for every x <= 1).
int lower(int x)
{
    int power = 1;
    while (power < x)
        power <<= 1;
    return power;
}
// In-place subset-sum transform over the global modulus p.
//   now  : array of len values, each expected to lie in [0, p)
//   len  : number of entries; the butterfly indexing requires a power of two
//   type : false -> forward transform (each index accumulates the values of
//          the indices obtained by clearing one of its bits, level by level);
//          true  -> the inverse transform (the same butterflies, subtracting)
// Every entry is left in [0, p).
void FWT(int now[], int len, bool type = false)
{
    for (int half = 1; half < len; half <<= 1)
        for (int base = 0; base < len; base += (half << 1))
            for (int idx = base; idx < base + half; idx++)
            {
                // Subtraction is done by adding p - value, which lies in [0, p].
                const int addend = type ? p - now[idx] : now[idx];
                // Compute now + addend - p directly: it stays within [-p, p - 1],
                // so the int never overflows (the plain sum could reach 2p - 1),
                // and one conditional correction replaces the division.
                int sum = now[half + idx] - (p - addend);
                if (sum < 0)
                    sum += p;
                now[half + idx] = sum;
            }
}
// Replaces f by the subset convolution of f with the factor whose transformed,
// rank-indexed table is stored in tmp2.
// Steps: split f by popcount into tmp1, transform every rank row, multiply the
// rows as polynomials in the rank, transform back, and read off for each index
// the row that matches its popcount.
// Only ranks 0..top are touched, where 2^top >= maxi + 1: every index has at
// most top set bits, and an output row c depends only on input rows <= c, so
// higher rows can never reach f.
void multi(void)
{
    const int size = maxi + 1;
    int top = 0;
    while (top < 19 && (1 << top) < size)   // the tables have 20 rank rows
        top++;
    const std::size_t rowBytes = sizeof(int) * size;

    for (int r = 0; r <= top; r++)
    {
        memset(tmp1[r], 0, rowBytes);
        memset(tmp3[r], 0, rowBytes);
    }
    for (int s = 0; s < size; s++)
        tmp1[ct[s]][s] = f[s];
    for (int r = 0; r <= top; r++)
        FWT(tmp1[r], size);

    // Rank-wise product: row q of tmp1 times row l of tmp2 feeds row q + l.
    for (int q = 0; q <= top; q++)
        for (int l = 0; q + l <= top; l++)
        {
            const int *lhs = tmp1[q];
            const int *rhs = tmp2[l];
            int *out = tmp3[q + l];
            for (int s = 0; s < size; s++)
                out[s] = (out[s] + lhs[s] * (long long) rhs[s]) % p;
        }

    for (int r = 0; r <= top; r++)
        FWT(tmp3[r], size, true);
    for (int s = 0; s < size; s++)
        f[s] = tmp3[ct[s]][s];
}
void multi2(void)
{
static int tmp[20][Maxn];
for (int i = 0; i <= 19; i++)
memset(tmp[i], 0, sizeof(int[maxi + 1]));
for (int q = 0; q <= 19; q++)
for (int l = 0; q + l <= 19; l++)
for (int s = 0; s <= maxi; s++)
tmp[q + l][s] = (tmp[q + l][s] + tmp2[q][s] * (long long) tmp2[l][s]) % p;
for (int i = 0; i <= 19; i++)
FWT(tmp[i], maxi + 1, true);
for (int i = 0; i <= 19; i++)
memset(tmp2[i], 0, sizeof(int[maxi + 1]));
for (int i = 0; i <= maxi; i++)
tmp2[ct[i]][i] = tmp[ct[i]][i];
for (int i = 0; i <= 19; i++)
FWT(tmp2[i], maxi + 1);
}
// Reads n, k and the modulus p, then prints the answer without a trailing newline:
//   - odd n  : k^n mod p
//   - even n : (k^n - n! * f[n]) mod p, where f is the k-th subset-convolution
//              power of the table i -> 1 / i!, obtained by binary exponentiation
//              (multi() multiplies f by the current factor, multi2() squares it).
// Returns 1 without printing when the input is malformed or does not fit the
// fixed-size tables.
int main()
{
    // Negative k would never terminate the exponentiation loop below, and a
    // non-positive modulus is meaningless.
    if (scanf("%d%d%d", &n, &k, &p) != 3 || n < 0 || k < 0 || p < 1)
        return 1;
    if (n & 1)
    {
        printf("%lld", fast_pow(k, n));
        return 0;
    }
    // Everything is 0 modulo 1; init() cannot run here because its exponent
    // p - 2 would be negative.
    if (p == 1)
    {
        printf("0");
        return 0;
    }
    // The tables hold indices below 2^19 and ranks 0..19 only.
    if (n >= (1 << 19))
        return 1;

    f[0] = 1;
    maxi = lower(n + 1) - 1;
    init();
    for (int i = 0; i <= maxi; i++)
        ct[i] = __builtin_popcount(i);
    for (int i = 0; i <= maxi; i++)
        tmp2[ct[i]][i] = inv[i];
    // maxi is all ones, so its popcount is the highest rank that occurs;
    // rows above it are all zero and need no transform.
    const int ranks = ct[maxi];
    for (int r = 0; r <= ranks; r++)
        FWT(tmp2[r], maxi + 1);

    int res_k = k;   // k is consumed by the loop; keep the original for k^n
    while (k)
    {
        if (k & 1) multi();
        k >>= 1;
        if (k) multi2();
    }
    printf("%lld", (fast_pow(res_k, n) - fac[n] * f[n] % p + p) % p);
    return 0;
}