58 求调
查看原帖
58 求调
766573
chenbs楼主2025/2/8 15:53

58 求调

#include <bits/stdc++.h>
using namespace std;
mt19937 rnd(time(0));
class BST {
  public:
	struct node {
		int l, r, val, dat, cnt, sz;
	} a[100005];
	int tot, rt, inf = 1 << 30;
	// int find(int val) {
	// 	int p = rt;
	// 	while (p) {
	// 		if (val == a[p].val)
	// 			break;
	// 		p = val < a[p].val ? a[p].l : a[p].r;
	// 	}
	// 	return p;
	// }
	int newnode(int val) {
		a[++tot].val = val;
		a[tot].cnt = a[tot].sz = 1;
		a[tot].dat = rnd();
		return tot;
	}
	void update(int p) { a[p].sz = a[a[p].l].sz + a[a[p].r].sz + a[p].cnt; }
	void build() {
		newnode(-inf), newnode(inf);
		rt = 1, a[1].r = 2;
        update(rt);
	}
	void zig(int &p) {
		int q = a[p].l;
		a[p].l = a[q].r, a[q].r = p;
		p = q;
        update(p), update(a[p].r);
	}
	void zag(int &p) {
		int q = a[p].r;
		a[p].r = a[q].l, a[q].l = p;
		p = q;
        update(p), update(a[p].l);
	}
	void insert(int &p, int val) {
		if (p == 0) {
			p = newnode(val);
			return;
		}
		if (val < a[p].val) {
			insert(a[p].l, val);
			if (a[p].dat < a[a[p].l].dat)
				zig(p);
		} else {
			insert(a[p].r, val);
			if (a[p].dat < a[a[p].r].dat)
				zag(p);
		}
        update(p);
	}
	int getprev(int val) {
		int ans = 1;
		int p = rt;
		while (p) {
			if (val == a[p].val) {
				if (a[p].l) {
					p = a[p].l;
					while (a[p].r)
						p = a[p].r;
					ans = p;
				}
				break;
			}
			if (a[p].val < val && a[p].val > a[ans].val)
				ans = p;
			p = val < a[p].val ? a[p].l : a[p].r;
		}
		return a[ans].val;
	}
	int getnext(int val) {
		int ans = 2;
		int p = rt;
		while (p) {
			if (val == a[p].val) {
				if (a[p].r) {
					p = a[p].r;
					while (a[p].l)
						p = a[p].l;
					ans = p;
				}
				break;
			}
			if (a[p].val > val && a[p].val < a[ans].val)
				ans = p;
			p = val < a[p].val ? a[p].l : a[p].r;
		}
		return a[ans].val;
	}
	void remove(int &p, int val) {
        if(p==0) return;
		if (val == a[p].val) {
			if (a[p].cnt > 1) {
				a[p].cnt--, update(p);
				return;
			}
			if (a[p].l || a[p].r) {
				if (a[p].r == 0 || a[a[p].l].dat > a[a[p].r].dat)
					zig(p), remove(a[p].r, val);
				else
					zag(p), remove(a[p].l, val);
				update(p);
			} else
				p = 0;
			return;
		}
		remove(val < a[p].val ? a[p].l : a[p].r, val);
		update(p);
	}
	int getrank(int val) {
		int p = rt, ans = 0;
		while (p) {
			if (val < a[p].val) {
				p = a[p].l;
			} else if (val == a[p].val) {
				ans += a[a[p].l].sz;
				break;
			} else {
				ans += a[a[p].l].sz + a[p].cnt;
				p = a[p].r;
			}
		}
		return ans;
	}
	int getval(int rank) {
		int p = rt;
		while (p) {
			if (rank >= a[a[p].l].sz + a[p].cnt) {
				rank -= a[a[p].l].sz + a[p].cnt;
                p = a[p].r;
			} else if (rank >= a[a[p].l].sz){
				return a[p].val;
			} else {
                p = a[p].l;
            }
		}
	}
} t;
int n;
int main() {
	t.build();
	cin >> n;
	for (int i = 1; i <= n; i++) {
		int op, x;
		cin >> op >> x;
		if (op == 1)
			t.insert(t.rt, x);
		else if (op == 2)
			t.remove(t.rt, x);
		else if (op == 3)
			cout << t.getrank(x) << '\n';
		else if (op == 4)
			cout << t.getval(x) << '\n';
		else if (op == 5)
			cout << t.getprev(x) << '\n';
		else
			cout << t.getnext(x) << '\n';
	}
	// t.insert(t.rt, 123);
	// t.insert(t.rt, 122);
	// t.insert(t.rt, 124);
	// for (int i = 1; i <= 5; i++) {
	// 	printf("%d %d %d\n", t.a[i].l, t.a[i].r, t.a[i].val);
	// }
	// cout << t.getnext(122);
}
2025/2/8 15:53
加载中...