逛了一圈题解发现没有佬用KDT写的
  • 板块P1001 A+B Problem
  • 楼主EXR_FAL
  • 当前回复0
  • 已保存回复0
  • 发布时间2025/6/20 10:02
  • 上次更新2025/6/20 10:04:11
查看原帖
逛了一圈题解发现没有佬用KDT写的
714325
EXR_FAL楼主2025/6/20 10:02

欸虽然错过了发题解,但是突发奇想心血来潮写了一个,还挺好懂的()
思路是把输入的a和b分别看作二维平面上的两个带权点(a,0)和(0,b),权值分别为a和b;把这两个点插入一棵2D树(K-D Tree),再用一个覆盖整个平面的矩形查询所有点的权值和,结果就是a+b()

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

/// Reads the next integer from stdin.
///
/// Characters that are not decimal digits are skipped; a '-' seen while
/// skipping makes the result negative. Digits are then accumulated until the
/// first non-digit, which is consumed.
///
/// @return the parsed value, or 0 if end of input is reached before any digit
///         (previously this case looped forever).
inline ll read() {
    // Plain range test: well-defined for EOF and for bytes >= 0x80, unlike
    // std::isdigit applied to a possibly negative `char`.
    const auto is_digit = [](int c) { return c >= '0' && c <= '9'; };
    ll x = 0, flag = 1;
    int ch = getchar();  // int, not char: EOF must stay distinguishable from data bytes
    while (ch != EOF && !is_digit(ch)) {
        if (ch == '-') flag = -1;
        ch = getchar();
    }
    while (is_digit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * flag;
}

/// Writes x to stdout in decimal: a leading '-' for negatives, no trailing
/// newline.
///
/// The magnitude is computed as unsigned long long, so LLONG_MIN prints
/// correctly. Negating it as a signed value would overflow (undefined
/// behaviour).
void out(ll x) {
    unsigned long long mag = static_cast<unsigned long long>(x);
    if (x < 0) {
        putchar('-');
        mag = 0ULL - mag;  // modular negation: exact magnitude, no signed overflow
    }
    // Collect digits least-significant first (20 covers any 64-bit value),
    // then emit them in reverse order.
    char digits[20];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    while (len > 0) putchar(digits[--len]);
}

// Capacity of the point scratch array and of the node pool (500006 slots).
const int N = 500000 + 6;

// A weighted 2-D point: x[0] and x[1] are its coordinates, w is the weight
// that range queries add up. Member order matters: points are built with
// aggregate initialisation in the order {x0, x1, w}.
struct Point {
    int x[2];
    int w;
};
// Scratch buffer: get_seq flattens a subtree into pt[1..size] and rebuild
// reads it back from there.
Point pt[N];

int n;       // not referenced by the code shown here
int idx;     // highest node id minted so far (id 0 is the null node)
int rt;      // root node id of the tree; 0 while the tree is empty
int buf[N];  // stack of node ids recycled by get_seq
int tot;     // number of ids currently on the buf stack

// One K-D tree node together with its subtree aggregates.
struct III {
    int l, r;   // left/right child ids, 0 = no child
    Point p;    // point stored at this node
    int L[2];   // per-axis minimum coordinate over the subtree
    int R[2];   // per-axis maximum coordinate over the subtree
    int sum;    // total weight of the subtree
    int sz;     // number of nodes in the subtree
};
III tr[N];

/// Hands out a node id. The id most recently pushed onto buf is reused when
/// the stack is non-empty; otherwise a fresh id is minted (the first is 1).
int add() {
    return tot ? buf[tot--] : ++idx;
}

void push_up(int u) {
    auto L = tr[tr[u].l], R = tr[tr[u].r];
    tr[u].sum = tr[u].p.w + L.sum + R.sum;
    tr[u].sz = L.sz + R.sz + 1;
    for (int i = 0; i <= 1; i++) {
        tr[u].L[i] = min({tr[u].p.x[i], L.L[i], R.L[i]});
        tr[u].R[i] = max({tr[u].p.x[i], L.R[i], R.R[i]});
    }
}

// Scapegoat balance factor: maintain() rebuilds a subtree once either child
// holds more than Al * (subtree size) nodes.
constexpr double Al = 0.72;
 
/// Flattens the subtree rooted at u in in-order. Points go to
/// pt[cnt + 1 .. cnt + size], and every visited node id is pushed onto buf so
/// that rebuild can reuse it. cnt is the number of in-order positions already
/// filled before this subtree.
void get_seq(int u, int cnt) {
    const int left = tr[u].l;
    const int right = tr[u].r;
    const int rank = cnt + tr[left].sz + 1;  // in-order slot of u itself
    if (left) get_seq(left, cnt);
    buf[++tot] = u;
    pt[rank] = tr[u].p;
    if (right) get_seq(right, rank);
}

int rebuild(int l, int r, int k) {
    if (l > r) return 0;
    int mid = l + r >> 1, u = add();
    nth_element(pt + l, pt + mid, pt + r + 1, [&](Point a, Point b) {
        return a.x[k] < b.x[k];
    });
    tr[u].p = pt[mid];
    tr[u].l = rebuild(l, mid - 1, k ^ 1);
    tr[u].r = rebuild(mid + 1, r, k ^ 1);
    push_up(u);
    return u;
}

/// Scapegoat check for node u, whose split axis is k. If either child subtree
/// holds more than Al of u's nodes, u's subtree is flattened and rebuilt
/// balanced. The new root id is written back through u, which is the caller's
/// link to this subtree.
void maintain(int &u, int k) {
    const double limit = tr[u].sz * Al;
    const bool heavyLeft = tr[tr[u].l].sz > limit;
    if (!heavyLeft && !(tr[tr[u].r].sz > limit)) return;
    get_seq(u, 0);
    // NOTE(review): rebuild(1, tot) treats tot as the subtree size, which only
    // holds if buf was empty before get_seq.
    u = rebuild(1, tot, k);
}

void insert(int &u, Point p, int k) {
    if (!u) {
        u = add();
        tr[u].l = tr[u].r = 0;
        tr[u].p = p;
        push_up(u);
        return;
    }
    if (p.x[k] <= tr[u].p.x[k]) 
        insert(tr[u].l, p, k ^ 1);
    else insert(tr[u].r, p, k ^ 1);
    push_up(u);
    maintain(u, k);
}

/// True when node t's whole bounding box lies inside the closed rectangle
/// [x1, x2] x [y1, y2].
bool In(III t, int x1, int y1, int x2, int y2) {
    const bool xInside = x1 <= t.L[0] && t.R[0] <= x2;
    const bool yInside = y1 <= t.L[1] && t.R[1] <= y2;
    return xInside && yInside;
}

/// True when point t lies inside the closed rectangle [x1, x2] x [y1, y2].
bool In(Point t, int x1, int y1, int x2, int y2) {
    const int px = t.x[0];
    const int py = t.x[1];
    return x1 <= px && px <= x2 && y1 <= py && py <= y2;
}

/// True when node t's bounding box is disjoint from the closed rectangle
/// [x1, x2] x [y1, y2], i.e. the two do not overlap on at least one axis.
bool Out(III t, int x1, int y1, int x2, int y2) {
    const bool overlapX = t.R[0] >= x1 && t.L[0] <= x2;
    const bool overlapY = t.R[1] >= y1 && t.L[1] <= y2;
    return !(overlapX && overlapY);
}

int query(int u, int x1, int y1, int x2, int y2) {
    if (In(tr[u], x1, y1, x2, y2)) return tr[u].sum;
    if (Out(tr[u], x1, y1, x2, y2)) return 0;
    int res = 0;
    if (In(tr[u].p, x1, y1, x2, y2)) res += tr[u].p.w;
    res += query(tr[u].l, x1, y1, x2, y2);
    res += query(tr[u].r, x1, y1, x2, y2);
    return res;
}

int main() {
    // The null node 0 must carry an "empty" bounding box, L = +inf and
    // R = -inf, so the min/max in push_up ignore it for every int coordinate.
    // Finite sentinels such as N + 5 / -1 clip the boxes of points whose
    // coordinates exceed N + 5 or are below -1 (e.g. a = 10^9 or a negative b).
    tr[0].L[0] = tr[0].L[1] = std::numeric_limits<int>::max();
    tr[0].R[0] = tr[0].R[1] = std::numeric_limits<int>::min();

    int a = read(), b = read();

    // Encode the inputs as weighted points (a, 0) and (0, b); their weights
    // add up to a + b.
    insert(rt, {a, 0, a}, 0);
    insert(rt, {0, b, b}, 0);

    // Query the whole int plane so every stored point is counted.
    const int lo = std::numeric_limits<int>::min();
    const int hi = std::numeric_limits<int>::max();
    int sum = query(rt, lo, lo, hi, hi);
    out(sum);

    return 0;
}
2025/6/20 10:02
加载中...