各位巨佬,请问为什么下面两段 FFT 和 NTT 的代码在本地会 RE,但在洛谷上会 AC?
代码(FFT):
#include <iostream>
#include <cstdio>
#include <cmath>
using namespace std;
typedef struct {
double real;
double imaginary;
} Complex;
const int N = 4e6 + 7;
typedef struct {
int n;
Complex a[N];
} Polynomial;
const double pi = acos(-1.0);
int r[N];
inline Complex new_complex(double real, double imaginary){
Complex ans;
ans.real = real;
ans.imaginary = imaginary;
return ans;
}
Complex operator +(const Complex a, const Complex b){
return new_complex(a.real + b.real, a.imaginary + b.imaginary);
}
Complex operator -(const Complex a, const Complex b){
return new_complex(a.real - b.real, a.imaginary - b.imaginary);
}
Complex operator *(const Complex a, const Complex b){
return new_complex(a.real * b.real - a.imaginary * b.imaginary, a.real * b.imaginary + a.imaginary * b.real);
}
Complex operator *=(Complex &a, const Complex b){
return a = a * b;
}
inline void FFT(Polynomial &a, int r[], int limit, int type){
for (register int i = 0; i < limit; i++){
if (i < r[i]) swap(a.a[i], a.a[r[i]]);
}
for (register int i = 1; i < limit; i <<= 1){
int t = i << 1;
Complex wn = new_complex(cos(pi / i), type * sin(pi / i));
for (register int j = 0; j < limit; j += t){
Complex w = new_complex(1.0, 0.0);
for (register int k = 0; k < i; k++){
Complex x = a.a[j + k], y = w * a.a[i + j + k];
a.a[j + k] = x + y;
a.a[i + j + k] = x - y;
w *= wn;
}
}
}
}
Polynomial quick_mul(Polynomial a, Polynomial b, int &limit){
int l = -1, t = a.n + b.n;
Polynomial ans;
while (limit <= t){
limit <<= 1;
l++;
}
for (register int i = 0; i < limit; i++){
r[i] = (r[i >> 1] >> 1) | ((i & 1) << l);
}
FFT(a, r, limit, 1);
FFT(b, r, limit, 1);
ans.n = t;
for (register int i = 0; i <= limit; i++){
ans.a[i] = a.a[i] * b.a[i];
}
FFT(ans, r, limit, -1);
return ans;
}
inline int read(){
int x = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9'){
if (ch == '-') f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9'){
x = x * 10 + (ch ^ 48);
ch = getchar();
}
return x * f;
}
void write(int n){
if (n >= 10) write(n / 10);
putchar(n % 10 + '0');
}
int main(){
int n = read(), m = read(), limit = 1, t = n + m;
Polynomial a, b, ans;
a.n = n;
b.n = m;
for (register int i = 0; i <= n; i++){
a.a[i].real = read();
}
for (register int i = 0; i <= m; i++){
b.a[i].real = read();
}
ans = quick_mul(a, b, limit);
for (register int i = 0; i <= t; i++){
write(ans.a[i].real / limit + 0.5);
putchar(' ');
}
return 0;
}
代码(NTT):
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;
const int N = 4e6 + 7;
typedef struct {
int n;
ll a[N];
} Polynomial;
const int mod = 998244353, mod_g = 3;
int r[N];
inline ll quick_pow(ll x, ll p, ll mod){
ll ans = 1;
while (p){
if (p & 1) ans = ans * x % mod;
x = x * x % mod;
p >>= 1;
}
return ans;
}
inline void NTT(Polynomial &a, int r[], int limit, /*int mod, int mod_g, */int type){
ll mod_g_inv;
for (register int i = 0; i < limit; i++){
if (i < r[i]) swap(a.a[i], a.a[r[i]]);
}
if (type == -1){
mod_g_inv = quick_pow(mod_g, mod - 2, mod);
}
for (register int i = 1; i < limit; i <<= 1){
int t = i << 1;
ll wn = quick_pow(type == 1 ? mod_g : mod_g_inv, (mod - 1) / t, mod);
for (register int j = 0; j < limit; j += t){
ll w = 1;
for (register int k = 0; k < i; k++){
ll x = a.a[j + k], y = w * a.a[i + j + k] % mod;
a.a[j + k] = (x + y) % mod;
a.a[i + j + k] = ((x - y) % mod + mod) % mod;
w = w * wn % mod;
}
}
}
if (type == -1){
ll limit_inv = quick_pow(limit, mod - 2, mod);
for (int i = 0; i < limit; i++){
a.a[i] = a.a[i] * limit_inv % mod;
}
}
}
Polynomial quick_mul(Polynomial a, Polynomial b, int &limit/*, int mod, int mod_g*/){
int l = -1, t = a.n + b.n;
Polynomial ans;
while (limit <= t){
limit <<= 1;
l++;
}
for (register int i = 0; i < limit; i++){
r[i] = (r[i >> 1] >> 1) | ((i & 1) << l);
}
NTT(a, r, limit, /*mod, mod_g, */1);
NTT(b, r, limit, /*mod, mod_g, */1);
ans.n = t;
for (register int i = 0; i <= limit; i++){
ans.a[i] = a.a[i] * b.a[i] % mod;
}
NTT(ans, r, limit, /*mod, mod_g, */-1);
return ans;
}
inline int read(){
int x = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9'){
if (ch == '-') f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9'){
x = x * 10 + (ch ^ 48);
ch = getchar();
}
return x * f;
}
void write(int n){
if (n >= 10) write(n / 10);
putchar(n % 10 + '0');
}
int main(){
int n = read(), m = read(), limit = 1, t = n + m;
Polynomial a, b, ans;
a.n = n;
b.n = m;
for (register int i = 0; i <= n; i++){
a.a[i] = read();
}
for (register int i = 0; i <= m; i++){
b.a[i] = read();
}
ans = quick_mul(a, b, limit/*, mod, mod_g*/);
for (register int i = 0; i <= t; i++){
write(ans.a[i]);
putchar(' ');
}
return 0;
}