蒟蒻求助
  • 板块学术版
  • 楼主Leasier
  • 当前回复11
  • 已保存回复11
  • 发布时间2020/6/6 22:44
  • 上次更新2023/11/7 01:04:30
查看原帖
蒟蒻求助
201007
Leasier楼主2020/6/6 22:44

各位巨佬，请问为什么下面两段 FFT 和 NTT 的代码在本地运行会 RE（Runtime Error，运行时错误），但提交到洛谷上却能 AC？

代码(FFT):

#include <iostream>
#include <cstdio>
#include <cmath>

using namespace std;

// A complex number stored as a (real, imaginary) pair of doubles.
struct Complex {
	double real;
	double imaginary;
};

// Capacity (number of coefficients) of every Polynomial and of the table r.
const int N = 4000007;

// A polynomial of degree n whose coefficients live in a[0..n].
// NOTE: one Polynomial holds N Complex values, i.e. roughly 64 MB with 8-byte
// doubles -- far larger than a typical default stack, so instances should not
// be automatic (stack) variables or be passed/returned by value.
struct Polynomial {
	int n;
	Complex a[N];
};

const double pi = std::acos(-1.0);
// Bit-reversal permutation table; quick_mul fills it before calling FFT.
int r[N];

// Constructs the complex number `real + imaginary * i`.
inline Complex new_complex(double real, double imaginary){
	Complex value = {real, imaginary};
	return value;
}

// Component-wise complex addition.
Complex operator +(const Complex a, const Complex b){
	Complex sum = a;
	sum.real += b.real;
	sum.imaginary += b.imaginary;
	return sum;
}

// Component-wise complex subtraction (a - b).
Complex operator -(const Complex a, const Complex b){
	const double re = a.real - b.real;
	const double im = a.imaginary - b.imaginary;
	return new_complex(re, im);
}

// Complex multiplication:
// (ar + ai*i)(br + bi*i) = (ar*br - ai*bi) + (ar*bi + ai*br)*i
Complex operator *(const Complex a, const Complex b){
	const double re = a.real * b.real - a.imaginary * b.imaginary;
	const double im = a.real * b.imaginary + a.imaginary * b.real;
	return new_complex(re, im);
}

// Multiplies a by b in place.
// Returns a reference to a, like the built-in compound assignment operators.
// Returning a copy would let `(x *= y) *= z` silently modify a temporary.
Complex &operator *=(Complex &a, const Complex b){
	a = a * b;
	return a;
}

// In-place iterative radix-2 FFT of a.a[0 .. limit-1].
//   r     - bit-reversal permutation of [0, limit), as filled in by quick_mul.
//   limit - transform length; must be a power of two (the butterfly loops
//           below assume it).
//   type  - 1 for the forward transform, -1 for the inverse.
// The inverse is NOT scaled by 1/limit here; the caller must divide by limit.
inline void FFT(Polynomial &a, int r[], int limit, int type){
	// Reorder into bit-reversed order so the butterflies can run in place.
	for (int i = 0; i < limit; i++){
		if (i < r[i]) swap(a.a[i], a.a[r[i]]);
	}
	// i is half the length t of the sub-transforms merged in this pass.
	for (int i = 1; i < limit; i <<= 1){
		int t = i << 1;
		// wn = exp(type * 2*pi*i / t): principal t-th root of unity,
		// conjugated for the inverse.
		Complex wn = new_complex(cos(pi / i), type * sin(pi / i));
		for (int j = 0; j < limit; j += t){
			// w advances by repeated multiplication, so round-off accumulates.
			// Precompute the twiddles with cos/sin if precision becomes a problem.
			Complex w = new_complex(1.0, 0.0);
			for (int k = 0; k < i; k++){
				Complex x = a.a[j + k], y = w * a.a[i + j + k];
				a.a[j + k] = x + y;
				a.a[i + j + k] = x - y;
				w *= wn;
			}
		}
	}
}

// Multiplies polynomials a and b with FFT.
//   a, b  - inputs; only a.a[0..a.n] and b.a[0..b.n] are read (the rest is
//           treated as zero).
//   limit - in/out. It must be a positive power of two on entry (callers
//           pass 1). On return it is the transform length: the first value
//           reached by doubling that exceeds a.n + b.n.
// Returns a reference to an internal static buffer holding the product
// (n = a.n + b.n). The real parts must still be divided by limit, because
// FFT(..., -1) does not scale.
// The buffer is overwritten by the next call, so copy it if it must survive;
// the function is not reentrant.
// The work buffers are static because each Polynomial is tens of MB.
// Passing or returning one by value puts a copy on the stack, which overflows
// a typical default-sized stack.
const Polynomial &quick_mul(const Polynomial &a, const Polynomial &b, int &limit){
	static Polynomial fa, fb;
	int t = a.n + b.n;
	// Grow the transform length until it can hold all t + 1 product coefficients.
	while (limit <= t) limit <<= 1;
	// l = log2(limit) - 1: where the lowest bit of i lands under bit reversal.
	int l = -1;
	for (int p = limit; p > 1; p >>= 1) l++;
	// limit == 1 (l == -1) needs no reordering. Without this check the
	// shift below would use a negative count.
	for (int i = 0; i < limit; i++){
		r[i] = l < 0 ? 0 : ((r[i >> 1] >> 1) | ((i & 1) << l));
	}
	// Copy the inputs and zero-pad up to limit explicitly, rather than
	// trusting whatever the caller left beyond a.n / b.n.
	for (int i = 0; i < limit; i++){
		fa.a[i] = i <= a.n ? a.a[i] : new_complex(0.0, 0.0);
		fb.a[i] = i <= b.n ? b.a[i] : new_complex(0.0, 0.0);
	}
	FFT(fa, r, limit, 1);
	FFT(fb, r, limit, 1);
	// Pointwise product of the transforms over indices 0 .. limit-1.
	for (int i = 0; i < limit; i++){
		fa.a[i] = fa.a[i] * fb.a[i];
	}
	FFT(fa, r, limit, -1);
	fa.n = t;
	return fa;
}

// Reads the next integer from stdin.
// Non-digit characters are skipped; a '-' seen while skipping makes the
// result negative.
// Returns 0 if end of input is reached before any digit.
// ch is an int so that EOF can be distinguished from real characters.
inline int read(){
	int x = 0, f = 1;
	int ch = getchar();
	while (ch < '0' || ch > '9'){
		if (ch == EOF) return 0;  // without this the skip loop never terminates at EOF
		if (ch == '-') f = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9'){
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}

// Writes n to stdout in decimal, with a leading '-' for negative values and
// no trailing separator.
// The magnitude is computed in unsigned arithmetic so INT_MIN is printed
// correctly.
void write(int n){
	unsigned int u = static_cast<unsigned int>(n);
	if (n < 0){
		putchar('-');
		u = 0u - u;
	}
	char digits[10];  // enough for the 10 digits of UINT_MAX
	int len = 0;
	do {
		digits[len++] = static_cast<char>('0' + u % 10);
		u /= 10;
	} while (u != 0);
	while (len > 0) putchar(digits[--len]);
}

int main(){
	int n = read(), m = read(), limit = 1, t = n + m;
	Polynomial a, b, ans;
	a.n = n;
	b.n = m;
	for (register int i = 0; i <= n; i++){
		a.a[i].real = read();
	}
	for (register int i = 0; i <= m; i++){
		b.a[i].real = read();
	}
	ans = quick_mul(a, b, limit);
	for (register int i = 0; i <= t; i++){
		write(ans.a[i].real / limit + 0.5);
		putchar(' ');
	}
	return 0;
}

代码(NTT):

#include <iostream>
#include <cstdio>

using namespace std;

typedef long long ll;

// Capacity (number of coefficients) of every Polynomial and of the table r.
const int N = 4000007;

// A polynomial of degree n whose coefficients live in a[0..n].
// NOTE: one Polynomial holds N long longs, i.e. roughly 32 MB -- far larger
// than a typical default stack, so instances should not be automatic (stack)
// variables or be passed/returned by value.
struct Polynomial {
	int n;
	ll a[N];
};

// NTT-friendly prime 998244353 = 119 * 2^23 + 1, with primitive root 3.
const int mod = 998244353;
const int mod_g = 3;
// Bit-reversal permutation table; quick_mul fills it before calling NTT.
int r[N];

// Computes x^p mod `mod` by binary exponentiation. The result is in [0, mod).
// Requires mod >= 1 and (mod - 1)^2 to fit in a long long (true for 998244353).
// Non-positive exponents return 1 % mod; a negative p would otherwise never
// terminate, because an arithmetic right shift keeps it at -1.
inline ll quick_pow(ll x, ll p, ll mod){
	// Reduce the base first. This keeps every product below mod^2 and turns
	// a negative base into its residue.
	x %= mod;
	if (x < 0) x += mod;
	ll ans = 1 % mod;  // 1 % mod so that mod == 1 yields 0
	while (p > 0){
		if (p & 1) ans = ans * x % mod;
		x = x * x % mod;
		p >>= 1;
	}
	return ans;
}

// In-place iterative radix-2 number-theoretic transform of a.a[0 .. limit-1].
// It works modulo `mod` (998244353) with `mod_g` (3) as the primitive root.
//   r     - bit-reversal permutation of [0, limit), as filled in by quick_mul.
//   limit - transform length. It must be a power of two no larger than 2^23,
//           so that every stage length t divides mod - 1 = 119 * 2^23.
//   type  - 1 for the forward transform, -1 for the inverse. The inverse also
//           multiplies every entry by limit^{-1}, so its output is fully scaled.
// Inputs are expected to be non-negative. The butterfly sum below is reduced
// with a plain %, so a negative input could leave a negative residue.
inline void NTT(Polynomial &a, int r[], int limit, int type){
	// Reorder into bit-reversed order so the butterflies can run in place.
	for (int i = 0; i < limit; i++){
		if (i < r[i]) swap(a.a[i], a.a[r[i]]);
	}
	// The forward transform uses powers of g; the inverse uses powers of g^{-1}.
	// g^{-1} is obtained via Fermat's little theorem because mod is prime.
	const ll root = type == 1 ? static_cast<ll>(mod_g) : quick_pow(mod_g, mod - 2, mod);
	for (int i = 1; i < limit; i <<= 1){
		int t = i << 1;
		// Principal t-th root of unity modulo mod.
		ll wn = quick_pow(root, (mod - 1) / t, mod);
		for (int j = 0; j < limit; j += t){
			ll w = 1;
			for (int k = 0; k < i; k++){
				ll x = a.a[j + k], y = w * a.a[i + j + k] % mod;
				a.a[j + k] = (x + y) % mod;
				a.a[i + j + k] = ((x - y) % mod + mod) % mod;
				w = w * wn % mod;
			}
		}
	}
	if (type == -1){
		// Scale by limit^{-1} so the inverse transform is exact.
		ll limit_inv = quick_pow(limit, mod - 2, mod);
		for (int i = 0; i < limit; i++){
			a.a[i] = a.a[i] * limit_inv % mod;
		}
	}
}

// Multiplies polynomials a and b modulo `mod` with NTT.
//   a, b  - inputs; only a.a[0..a.n] and b.a[0..b.n] are read (the rest is
//           treated as zero). Coefficients may be any long long; they are
//           reduced into [0, mod) before transforming.
//   limit - in/out. It must be a positive power of two on entry (callers
//           pass 1). On return it is the transform length: the first value
//           reached by doubling that exceeds a.n + b.n.
// Returns a reference to an internal static buffer holding the product
// (n = a.n + b.n, coefficients in [0, mod)).
// The buffer is overwritten by the next call, so copy it if it must survive;
// the function is not reentrant.
// The work buffers are static because each Polynomial is tens of MB.
// Passing or returning one by value puts a copy on the stack, which overflows
// a typical default-sized stack.
const Polynomial &quick_mul(const Polynomial &a, const Polynomial &b, int &limit){
	static Polynomial fa, fb;
	int t = a.n + b.n;
	// Grow the transform length until it can hold all t + 1 product coefficients.
	while (limit <= t) limit <<= 1;
	// l = log2(limit) - 1: where the lowest bit of i lands under bit reversal.
	int l = -1;
	for (int p = limit; p > 1; p >>= 1) l++;
	// limit == 1 (l == -1) needs no reordering. Without this check the
	// shift below would use a negative count.
	for (int i = 0; i < limit; i++){
		r[i] = l < 0 ? 0 : ((r[i >> 1] >> 1) | ((i & 1) << l));
	}
	// Copy the inputs as residues in [0, mod) and zero-pad up to limit
	// explicitly, rather than trusting whatever the caller left beyond a.n / b.n.
	for (int i = 0; i < limit; i++){
		fa.a[i] = i <= a.n ? (a.a[i] % mod + mod) % mod : 0;
		fb.a[i] = i <= b.n ? (b.a[i] % mod + mod) % mod : 0;
	}
	NTT(fa, r, limit, 1);
	NTT(fb, r, limit, 1);
	// Pointwise product of the transforms over indices 0 .. limit-1.
	for (int i = 0; i < limit; i++){
		fa.a[i] = fa.a[i] * fb.a[i] % mod;
	}
	NTT(fa, r, limit, -1);
	fa.n = t;
	return fa;
}

// Reads the next integer from stdin.
// Non-digit characters are skipped; a '-' seen while skipping makes the
// result negative.
// Returns 0 if end of input is reached before any digit.
// ch is an int so that EOF can be distinguished from real characters.
inline int read(){
	int x = 0, f = 1;
	int ch = getchar();
	while (ch < '0' || ch > '9'){
		if (ch == EOF) return 0;  // without this the skip loop never terminates at EOF
		if (ch == '-') f = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9'){
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}

// Writes n to stdout in decimal, with a leading '-' for negative values and
// no trailing separator.
// The magnitude is computed in unsigned arithmetic so INT_MIN is printed
// correctly.
void write(int n){
	unsigned int u = static_cast<unsigned int>(n);
	if (n < 0){
		putchar('-');
		u = 0u - u;
	}
	char digits[10];  // enough for the 10 digits of UINT_MAX
	int len = 0;
	do {
		digits[len++] = static_cast<char>('0' + u % 10);
		u /= 10;
	} while (u != 0);
	while (len > 0) putchar(digits[--len]);
}

int main(){
	int n = read(), m = read(), limit = 1, t = n + m;
	Polynomial a, b, ans;
	a.n = n;
	b.n = m;
	for (register int i = 0; i <= n; i++){
		a.a[i] = read();
	}
	for (register int i = 0; i <= m; i++){
		b.a[i] = read();
	}
	ans = quick_mul(a, b, limit/*, mod, mod_g*/);
	for (register int i = 0; i <= t; i++){
		write(ans.a[i]);
		putchar(' ');
	}
	return 0;
}
2020/6/6 22:44
加载中...