#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;
// Integer power by binary exponentiation: returns a^b for b >= 0 using
// O(log b) multiplications. The caller must ensure a^b fits in long long.
// The base is squared only while exponent bits remain, so no square larger
// than the result is ever formed (an unconditional final square would overflow,
// e.g. for ksm(10, 16), even though 10^16 fits).
// Negative exponents are not supported; they return 1 instead of looping forever.
inline long long ksm(long long a, long long b)
{
    long long r = 1;
    while (b > 0)
    {
        if (b & 1)
        {
            r *= a;
        }
        b >>= 1;
        if (b > 0)
        {
            a *= a; // only square when another exponent bit will consume it
        }
    }
    return r;
}
// Number of decimal digits packed into each limb. The range 2 < CUTOFF < 9 is
// enforced at compile time; the upper bound keeps every limb product
// (< BASE^2 <= 10^16) far below the long long limit.
const int CUTOFF = 8;
static_assert(CUTOFF > 2 && CUTOFF < 9, "CUTOFF must satisfy 2 < CUTOFF < 9");
// Limb radix 10^CUTOFF: normalized limbs hold values in [0, BASE).
const long long BASE = ksm(10, CUTOFF);
// Makes sure v can hold at least n elements without reallocating.
// Never shrinks the capacity and never touches v's size or contents.
template<typename T>
inline void vec_reserve(vector<T> &v, size_t n)
{
    if (n <= v.capacity())
    {
        return; // already large enough
    }
    v.reserve(n);
}
// Parses a decimal string into little-endian base-BASE limbs.
// Limb 0 holds the last CUTOFF digits of s, limb 1 the CUTOFF digits before
// them, and so on; the most significant limb may hold fewer digits.
// Precondition: s contains only '0'..'9' (no sign, no whitespace); this is not
// validated.
// Leading zero limbs (from leading zeros in s) are trimmed, and an empty string
// yields {0}, so the result always has at least one limb.
inline vector<long long> str_to_num(const string &s)
{
    const size_t width = static_cast<size_t>(CUTOFF);
    vector<long long> res;
    res.reserve(s.size() / width + 1);
    // Walk s from its least significant end, one limb-sized window at a time.
    for (size_t end = s.size(); end > 0;)
    {
        const size_t begin = end > width ? end - width : 0;
        long long num = 0;
        for (size_t i = begin; i < end; ++i)
        {
            num = num * 10 + (s[i] - '0');
        }
        res.push_back(num);
        end = begin;
    }
    while (res.size() > 1 && res.back() == 0)
    {
        res.pop_back();
    }
    if (res.empty())
    {
        res.push_back(0);
    }
    return res;
}
// Formats little-endian base-BASE limbs as a decimal string.
// Leading zero limbs are skipped; an empty or all-zero input yields "0".
// The most significant non-zero limb is printed as is; every lower limb is
// left-padded with zeros to exactly CUTOFF digits.
// Assumes every limb is in [0, BASE).
inline string num_to_str(const vector<long long> &num)
{
    // top = one past the most significant non-zero limb.
    size_t top = num.size();
    while (top > 0 && num[top - 1] == 0)
    {
        --top;
    }
    if (top == 0)
    {
        return "0";
    }
    string res;
    res.reserve(top * static_cast<size_t>(CUTOFF));
    res += to_string(num[top - 1]);
    for (size_t i = top - 1; i-- > 0;)
    {
        const string digits = to_string(num[i]);
        if (digits.size() < static_cast<size_t>(CUTOFF))
        {
            res.append(static_cast<size_t>(CUTOFF) - digits.size(), '0');
        }
        res += digits;
    }
    return res;
}
// Adds b into a in place (a += b) on little-endian base-BASE limbs.
// Assumes non-negative limbs. Every processed limb is reduced to [0, BASE) by
// carrying. a grows to max(a.size(), b.size()) limbs, plus one more if a final
// carry remains. Existing leading zero limbs are not trimmed.
inline void add_inplace(vector<long long> &a, const vector<long long> &b)
{
    const size_t max_size = max(a.size(), b.size());
    vec_reserve(a, max_size + 1); // room for the final carry
    long long carry = 0;
    for (size_t i = 0; i < max_size || carry != 0; ++i)
    {
        if (i == a.size())
        {
            a.push_back(0);
        }
        const long long bb = i < b.size() ? b[i] : 0;
        a[i] += bb + carry;
        carry = a[i] / BASE;
        a[i] %= BASE;
    }
}
// Returns a + b without modifying either operand: accumulates b into a copy
// of a via add_inplace.
inline vector<long long> add(const vector<long long> &a, const vector<long long> &b)
{
    vector<long long> sum(a.begin(), a.end());
    add_inplace(sum, b);
    return sum;
}
// Subtracts b from a in place (a -= b) on little-endian base-BASE limbs.
// Precondition: a >= b numerically and all limbs are in [0, BASE). A final
// borrow (which would mean a < b) is silently dropped, so violating the
// precondition yields a meaningless result.
// Limbs of b at positions >= a.size() are ignored; they must be zero under the
// precondition. Leading zero limbs of the result are trimmed, keeping at least one.
inline void sub_inplace(vector<long long> &a, const vector<long long> &b)
{
    long long borrow = 0;
    for (size_t i = 0; i < a.size(); ++i)
    {
        const long long bb = i < b.size() ? b[i] : 0;
        a[i] -= borrow + bb;
        borrow = a[i] < 0 ? 1 : 0;
        if (borrow)
        {
            a[i] += BASE;
        }
    }
    while (a.size() > 1 && a.back() == 0)
    {
        a.pop_back();
    }
}
// Schoolbook product of two little-endian base-BASE limb vectors.
// Assumes every limb is in [0, BASE). Then each a[i] * b[j] < BASE^2, and the
// running value res[i + j] + a[i] * b[j] + carry stays within long long for
// BASE up to 10^9.
// Returns the normalized product with leading zero limbs trimmed; the product
// of empty operands is {0}.
inline vector<long long> multiply_simple(const vector<long long> &a, const vector<long long> &b)
{
    vector<long long> res(a.size() + b.size(), 0);
    for (size_t i = 0; i < a.size(); ++i)
    {
        long long carry = 0;
        for (size_t j = 0; j < b.size(); ++j)
        {
            res[i + j] += a[i] * b[j] + carry;
            carry = res[i + j] / BASE;
            res[i + j] %= BASE;
        }
        // Earlier rows only wrote indices below i + b.size(), so this slot
        // simply receives the carry out of the current row.
        res[i + b.size()] += carry;
    }
    while (res.size() > 1 && res.back() == 0)
    {
        res.pop_back();
    }
    if (res.empty())
    {
        res.push_back(0);
    }
    return res;
}
inline vector<long long> karatsuba_opt(const vector<long long> &a, const vector<long long> &b, int l, int r)
{
if (r - l <= CUTOFF)
{
return multiply_simple(vector<long long>(a.begin() + l, a.begin() + r), vector<long long>(b.begin() + l, b.begin() + r));
}
int mid = (r - l) >> 1, shift = mid << 1;
vector<long long> lo = karatsuba_opt(a, b, l, l + mid), hi = karatsuba_opt(a, b, l + mid, r), a_sum(mid + 1), b_sum(mid + 1);
for (register int i = 0; i < mid; ++i)
{
a_sum[i] = (l + mid + i < a.size() ? a[l + mid + i] : 0) + a[l + i];
b_sum[i] = (l + mid + i < b.size() ? b[l + mid + i] : 0) + b[l + i];
}
vector<long long> mid_term = karatsuba_opt(a_sum, b_sum, 0, mid);
mid_term.resize(mid_term.size() + shift, 0);
rotate(mid_term.rbegin(), mid_term.rbegin() + shift, mid_term.rend());
vector<long long> res(lo.size() + shift, 0);
copy(lo.begin(), lo.end(), res.begin());
add_inplace(res, hi);
add_inplace(res, mid_term);
sub_inplace(res, lo);
sub_inplace(res, hi);
return res;
}
inline vector<long long> karatsuba(const vector<long long> &a, const vector<long long> &b)
{
int size = max(a.size(), b.size());
vector<long long> a_pad = a, b_pad = b;
a_pad.resize(size, 0);
b_pad.resize(size, 0);
return karatsuba_opt(a_pad, b_pad, 0, size);
}
int main()
{
ios::sync_with_stdio(false), cin.tie(nullptr), cout.tie(nullptr);
string a, b;
cin >> a >> b;
vector<long long> num_a = str_to_num(a), num_b = str_to_num(b), product = karatsuba(num_a, num_b);
cout << num_to_str(product) << '\n';
return 0;
}
/*
 * 时间复杂度 $O\left(\frac{n^{\log_2 3}}{w}\right)$(其中 $w$ 为压位位数,此处为 8),
 * 简单估算约为 $1.2 \times 10^{9}$ 次运算,无可救药了吗?
 * 使用 __int128 的话,$w$ 约为 40,可能勉强……卡不过!!!
 */