蝴蝶变换的时候为啥不能这样判断?
查看原帖
蝴蝶变换的时候为啥不能这样判断?
105820
阿尔托莉雅丶楼主2021/12/29 20:42
#include <iostream>
#include <math.h>
using namespace std;
const int N = 2e6 + 5;   //remember to modify the range of the data!!
const int mod = 1e9 + 7;
const double PI = acos(-1);
typedef long long ll;

int n, m, t;
int limit = 1; //大于等于结果系数个数的最小2的幂
int bitnum = 0; //上述幂的次数
int r[N << 1]; //记录反转二进制后的值,表示


struct complex
{
    double a, b;
    complex(double x = 0, double y = 0) //构造函数
    {
        a = x, b = y;
    }
    complex operator + (complex &y)
    {
        return complex(a + y.a, b + y.b);
    }
    complex operator - (complex &y)
    {
        return complex(a - y.a, b - y.b);
    }
    complex operator * (complex &y)
    {
        return complex(a * y.a - b * y.b, a * y.b + b * y.a);
    }
} f[N << 1], g[N << 1];

void FFT(complex a[], int type) //type = 1 fft, type = -1 ifft
{
    for(int i = 0; i < limit; i++) 
        if(a[i].a != r[i])     //为什么不能这样判断?
            swap(a[i], a[r[i]]);
    for(int mid = 1;mid < limit; mid <<= 1) //从底层往上合并 枚举待合并区间长度的一半
    {
        //最开始是两个长度为1的序列合并,mid = 1;
        complex Wn(cos(PI / mid), type * sin((PI / mid)));

        for(int len = mid << 1, pos = 0; pos < limit; pos += len)
        {
            complex w(1, 0); //幂,一直乘,得到平方,三次方...
            for(int k = 0; k < mid; k++, w = w * Wn)
            {
                complex x = a[pos + k]; //左边部分
                complex y = w * a[pos + mid + k]; //右边部分
                a[pos + k] = x + y; //左边加
                a[pos + mid + k] = x - y; //右边减
            }
        }
    }
    if(type == 1)
        return;
    for(int i = 0; i <= limit; i++)//逆fft最后要除以limit也就是补成了2的
        a[i].a /= limit;       //整数幂的那个N,将点值转换为系数
}

int main(void)
{
    cin >> n >> m;
    for(int i = 0; i <= n; i++)
        cin >> f[i].a;

    for(int i = 0; i <= m; i++)
        cin >> g[i].a;
    // 补成2的整次幂
    while(limit < n + m + 1)
        limit <<= 1, bitnum++;
    for(int i = 0; i < limit; i++) //二进制反转
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bitnum - 1));
    
    FFT(f, 1);
    FFT(g, 1);
    for(int i = 0; i < limit; i++) //点值相乘
        f[i] = f[i] * g[i];
    FFT(f, -1); //逆fft 

    for(int i = 0; i <= n + m; i++)
        cout << int(f[i].a + 0.5) << ' ';//注意要+0.5,精度问题
    return 0;
}
2021/12/29 20:42
加载中...