欸，虽然错过了发题解的时机，但突发奇想写了一个，还挺好懂的()
思路：把输入的 a 和 b 分别看作二维平面上的两个带权点 (a, 0)（权值为 a）和 (0, b)（权值为 b），用 2D 树（K-D Tree）存储这两个点，再查询覆盖整个平面的矩形内所有点的权值之和，结果就是 a + b()
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Reads the next decimal integer from stdin.
// Any non-digit characters before the number are skipped; if one of them is
// '-', the result is negated. Returns 0 if input ends before a digit is seen.
inline ll read() {
    ll x = 0, flag = 1;
    int ch = getchar();  // int, not char: it must be able to hold EOF
    // Stop at EOF; otherwise this loop would spin forever on exhausted input.
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') flag = -1;
        ch = getchar();
    }
    // isdigit(EOF) is false, so this loop also terminates at end of input.
    while (isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * flag;
}
// Writes x to stdout in decimal, with a leading '-' if negative (no newline).
// Digits are produced from the unsigned magnitude, so LLONG_MIN is printed
// correctly instead of hitting signed overflow on negation.
void out(long long x) {
    unsigned long long mag = static_cast<unsigned long long>(x);
    if (x < 0) {
        putchar('-');
        mag = 0ULL - mag;  // modular negation: well-defined for every value
    }
    char digits[20];  // enough for any 64-bit unsigned value
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    while (len > 0) putchar(digits[--len]);
}
// Capacity of the point scratch buffer pt[] and of the node-index stack buf[].
constexpr int N = 500006;

// A weighted 2-D point: coordinates x[0], x[1] and weight w.
struct Point {
    int x[2];
    int w;
};
Point pt[N];  // scratch array: a subtree is flattened here before being rebuilt

int n;    // not referenced anywhere in the visible code
int idx;  // last freshly allocated node index (index 0 is never handed out)
int rt;   // root node of the tree; 0 while the tree is empty

int buf[N];  // stack of node indices released during a rebuild, for reuse
int tot;     // number of entries currently on buf[]
struct III {
int l, r;
Point p;
int L[2], R[2];
int sum, sz;
} tr[N];
int add() {
if (!tot) return ++idx;
return buf[tot--];
}
void push_up(int u) {
auto L = tr[tr[u].l], R = tr[tr[u].r];
tr[u].sum = tr[u].p.w + L.sum + R.sum;
tr[u].sz = L.sz + R.sz + 1;
for (int i = 0; i <= 1; i++) {
tr[u].L[i] = min({tr[u].p.x[i], L.L[i], R.L[i]});
tr[u].R[i] = max({tr[u].p.x[i], L.R[i], R.R[i]});
}
}
const double Al = 0.72;
void get_seq(int u, int cnt) {
if (tr[u].l) get_seq(tr[u].l, cnt);
buf[++tot] = u;
pt[tr[tr[u].l].sz + 1 + cnt] = tr[u].p;
if (tr[u].r) get_seq(tr[u].r, cnt + tr[tr[u].l].sz + 1);
}
int rebuild(int l, int r, int k) {
if (l > r) return 0;
int mid = l + r >> 1, u = add();
nth_element(pt + l, pt + mid, pt + r + 1, [&](Point a, Point b) {
return a.x[k] < b.x[k];
});
tr[u].p = pt[mid];
tr[u].l = rebuild(l, mid - 1, k ^ 1);
tr[u].r = rebuild(mid + 1, r, k ^ 1);
push_up(u);
return u;
}
void maintain(int &u, int k) {
if (tr[u].sz * Al < tr[tr[u].l].sz || tr[u].sz * Al < tr[tr[u].r].sz)
get_seq(u, 0), u = rebuild(1, tot, k);
}
void insert(int &u, Point p, int k) {
if (!u) {
u = add();
tr[u].l = tr[u].r = 0;
tr[u].p = p;
push_up(u);
return;
}
if (p.x[k] <= tr[u].p.x[k])
insert(tr[u].l, p, k ^ 1);
else insert(tr[u].r, p, k ^ 1);
push_up(u);
maintain(u, k);
}
bool In(III t, int x1, int y1, int x2, int y2) {
return t.L[0] >= x1 && t.R[0] <= x2 && t.L[1] >= y1 && t.R[1] <= y2;
}
bool In(Point t, int x1, int y1, int x2, int y2) {
return t.x[0] >= x1 && t.x[0] <= x2 && t.x[1] >= y1 && t.x[1] <= y2;
}
bool Out(III t, int x1, int y1, int x2, int y2) {
return t.R[0] < x1 || t.L[0] > x2 || t.R[1] < y1 || t.L[1] > y2;
}
int query(int u, int x1, int y1, int x2, int y2) {
if (In(tr[u], x1, y1, x2, y2)) return tr[u].sum;
if (Out(tr[u], x1, y1, x2, y2)) return 0;
int res = 0;
if (In(tr[u].p, x1, y1, x2, y2)) res += tr[u].p.w;
res += query(tr[u].l, x1, y1, x2, y2);
res += query(tr[u].r, x1, y1, x2, y2);
return res;
}
int main() {
tr[0].L[0] = tr[0].L[1] = N + 5;
tr[0].R[0] = tr[0].R[1] = -1;
int a = read(), b = read();
insert(rt, {a, 0, a}, 0);
insert(rt, {0, b, b}, 0);
int sum = query(rt, -1e9, -1e9, 1e9, 1e9);
out(sum);
return 0;
}