treap求调
查看原帖
treap求调
942161
Tomle（楼主） 2024/9/17 20:24
#include <bits/stdc++.h>
// Short-hand macros used throughout this solution.
// il: marks small helper functions `inline`.
#define il inline
// re / ri: `register` storage hint; `ri` declares int parameters and loop indices.
// NOTE(review): `register` is deprecated since C++11 and removed in C++17;
// compilers warn about (GCC) or reject (Clang) it under -std=c++17.
#define re register
#define ri re int
// fr(a, b, c): ascending loop, a runs from b to c inclusive.
#define fr(a, b, c) for (ri a = (b); a <= (c); ++a)
// ur(a, b, c): descending loop, a runs from b down to c inclusive (unused in the visible code).
#define ur(a, b, c) for (ri a = (b); a >= (c); --a)
// inf: "infinity" key (0x3f3f3f3f, about 1.06e9) stored in the two boundary sentinels.
#define inf (0x3f3f3f3f)
using namespace std;

// One treap node. Children are indices into the global pool `t`; index 0 is
// the null child (every routine stops on a 0 index), so no heap allocation is needed.
struct node {
    int l, r;  // left / right child index (0 = none)
    int val;   // key stored in this node
    int cnt;   // multiplicity of `val`; equal keys share a single node
    int dat;   // random priority; insert rotates a child above its parent when the child's dat is larger
    int siz;   // total multiplicity of the subtree rooted at this node
};

// Node pool. Slot 0 is the null node; build() allocates the -inf / +inf
// sentinels first. NOTE(review): capacity presumably covers n initial keys,
// up to m inserts, and the 2 sentinels; confirm against the problem limits.
node t[1100005];

// Field accessors by node index: l(p) is t[p].l, and so on.
#define l(p) t[p].l
#define r(p) t[p].r
#define val(p) t[p].val
#define cnt(p) t[p].cnt
#define dat(p) t[p].dat
#define siz(p) t[p].siz

// Input sizes: n initial keys (kept in a[1..n]) and m operations.
int n, m;
int a[100005];
// Treap bookkeeping: tot = pool slots handed out so far, rt = root index.
int tot, rt;
// Per-query state: last = previous answer (XOR-decodes the next x),
// op / x = current operation and operand, ans = XOR of all answers.
int last, op, x, ans;

// Reads the next non-negative decimal integer from stdin into `a`.
// Every non-digit character before it (spaces, newlines, a '-') is skipped,
// so negative numbers are not supported. `ch` is only scratch storage.
// If end-of-file arrives before any digit, `a` is set to 0 and the function
// returns; previously the skip loop spun forever on EOF.
inline void read(int &a, int ch = 0) {
    while (!isdigit(ch = getchar())) {
        if (ch == EOF) {
            a = 0;
            return;
        }
    }
    for (a = 0; isdigit(ch); ch = getchar()) a = a * 10 + (ch - '0');
}
il int New(ri val) {
    val(++tot) = val, dat(tot) = rand(), cnt(tot) = siz(tot) = 1;
    return tot;
}
il void pushup(ri p) {
    siz(p) = siz(l(p)) + siz(r(p)) + cnt(p);
}
il void zig(ri &p) {
    ri q = l(p);
    l(p) = r(q), r(q) = p, p = q;
    pushup(r(p)), pushup(p);
}
il void zag(ri &p) {
    ri q = r(p);
    r(p) = l(q), l(q) = p, p = q;
    pushup(l(p)), pushup(p);
}
void insert(ri &p, ri val) {
    if (!p) return void(p = New(val));
    if (val(p) == val) return ++cnt(p), pushup(p);
    if (val < val(p)) {
        insert(l(p), val);
        if (dat(l(p)) > dat(p)) zig(p);
    } else {
        insert(r(p), val);
        if (dat(r(p)) > dat(p)) zag(p);
    }
    pushup(p);
}
void remove(ri &p, ri val) {
    if (!p) return;
    if (val(p) == val) {
        if (cnt(p) > 1) return --cnt(p), pushup(p);
        if (l(p) || r(p)) {
            if (!r(p) || dat(l(p)) > dat(r(p))) zig(p), remove(r(p), val);
            else zag(p), remove(l(p), val);
            pushup(p);
        } else return void(p = 0);
    }
    if (val < val(p)) remove(l(p), val);
    else remove(r(p), val);
    pushup(p);
}
int get_rk(ri p, ri val) {
    if (!p) return 0;
    if (val(p) == val) return siz(l(p)) + 1;
    if (val < val(p)) return get_rk(l(p), val);
    else return siz(l(p)) + cnt(p) + get_rk(r(p), val);
}
int kth(ri p, ri rk) {
    if (siz(l(p)) >= rk) return kth(l(p), rk);
    else if (siz(l(p)) + cnt(p) >= rk) return val(p);
    else return kth(r(p), rk - siz(l(p)) - cnt(p));
}
int get_pre(ri val) {
    ri ans = 1, p = rt;
    while (p) {
        if (val(p) == val) {
            if (l(p)) for (ans = l(p); r(ans); ans = r(ans));
            break;
        }
        if (val(p) < val && val(p) > val(ans)) ans = p;
        if (val < val(p)) p = l(p);
        else p = r(p);
    }
    return val(ans);
}
int get_nxt(ri val) {
    ri ans = 2, p = rt;
    while (p) {
        if (val(p) == val) {
            if (r(p)) for (ans = r(p); l(ans); ans = l(ans));
            break;
        }
        if (val(p) > val && val(p) < val(ans)) ans = p;
        if (val < val(p)) p = l(p);
        else p = r(p);
    }
    return val(ans);
}
void build() {
    New(-inf), New(inf);
    r(rt = 1) = 2;
    pushup(1);
    fr(i, 1, n) insert(rt, a[i]);
}
int main() {
    srand(time(0));
    read(n), read(m);
    fr (i, 1, n) read(a[i]);
    build();
    while (m--) {
        read(op), read(x);
        x ^= last;
        switch (op) {
            case 1: { insert(rt, x); break; }
            case 2: { remove(rt, x); break; }
            case 3: { ans ^= (last = get_rk(rt, x) - 1); break; }
            case 4: { ans ^= (last = kth(rt, x + 1)); break; }
            case 5: { ans ^= (last = get_pre(x)); break; }
            case 6: { ans ^= (last = get_nxt(x)); break; }
        }
    }
    cout << ans;
    return 0;
}
2024/9/17 20:24
加载中...