#include <bits/stdc++.h>
using namespace std;
// Reads the next integer from stdin into x.
// Characters before the first digit are skipped; any '-' seen while skipping
// makes the result negative. The character is kept as an int so EOF can be
// recognised: at end of input the skip loop stops and x is left as 0 instead
// of looping forever. Digits are accumulated in long long so that
// -2147483648 can be read without signed overflow.
inline void read(int& x)
{
    long long value = 0;
    int sign = 1;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    x = static_cast<int>(sign * value);
}
// Writes x to stdout in decimal (leading '-' for negatives), without a newline.
// The magnitude is computed in unsigned arithmetic so INT_MIN does not
// overflow on negation. Digits are generated least-significant first into a
// small buffer and then emitted in reverse.
inline void write(int x)
{
    unsigned int magnitude = static_cast<unsigned int>(x);
    if (x < 0)
    {
        putchar('-');
        magnitude = 0u - magnitude;
    }
    char digits[16];
    int len = 0;
    do
    {
        digits[len++] = static_cast<char>('0' + magnitude % 10u);
        magnitude /= 10u;
    } while (magnitude != 0u);
    while (len > 0) putchar(digits[--len]);
}
// Node storage for the scapegoat tree. Node 0 is the empty sentinel: it is
// read as the "missing child", so all of its counters must stay 0.
// Nodes are allocated as 1..num in creation order.
const int kMaxNodes = 100001;

int a[kMaxNodes];           // scratch buffer: live nodes in in-order, filled by flatten()
int line;                   // number of entries currently stored in a[]
// Balance factor. A subtree is rebuilt when one child holds >= alpha of its
// nodes, or when <= alpha of its nodes are still alive; alpha must therefore
// lie strictly between 0.5 and 1.
double alpha = 0.75;
int val[kMaxNodes];         // key stored in the node
int cnt[kMaxNodes];         // multiplicity of val; 0 means the node is lazily deleted
// NOTE: with `using namespace std;` under C++17, an unqualified `size` is
// ambiguous with std::size; refer to this array as `::size`.
int size[kMaxNodes];        // values in the subtree, counting duplicates
int size_fact[kMaxNodes];   // nodes in the subtree, including deleted ones
int size_delet[kMaxNodes];  // nodes in the subtree that are still alive (cnt > 0)
int lc[kMaxNodes];          // left child (0 = none)
int rc[kMaxNodes];          // right child (0 = none)
int num;                    // number of nodes allocated so far
int root;                   // root of the tree (0 = empty tree)
inline void update_size(int o)
{
size[o] = size[lc[o]] + size[rc[o]] + cnt[o];
size_fact[o] = size_fact[lc[o]] + size_fact[rc[o]] + 1;
size_delet[o] = size_delet[lc[o]] + size_delet[rc[o]] + (cnt[o] > 0);
}
// Appends the nodes of the subtree rooted at o to a[] in in-order (sorted by
// val), skipping nodes whose cnt is 0; `line` is advanced for every node kept.
// The left subtree is recursed into, while the right spine is walked by the
// loop itself.
inline void flatten(int o)
{
    for (int cur = o; cur != 0; cur = rc[cur])
    {
        flatten(lc[cur]);
        if (cnt[cur] != 0) a[++line] = cur;
    }
}
// Builds a perfectly balanced subtree from the nodes a[L..R] (already in
// in-order) and stores its root in o. The middle element becomes the root and
// the two halves are built recursively.
// An empty range (L > R) produces an empty subtree (o = 0). Without this
// guard, build(1, 0, o) would pick a[0] == 0 and run update_size(0), setting
// size_fact[0] to 1 and corrupting the sentinel node.
void build(int L, int R, int& o)
{
    if (L > R)
    {
        o = 0;
        return;
    }
    const int mid = (L + R) >> 1;
    o = a[mid];
    build(L, mid - 1, lc[o]);   // sets lc[o] = 0 when the left half is empty
    build(mid + 1, R, rc[o]);   // sets rc[o] = 0 when the right half is empty
    update_size(o);
}
// Rebuilds the subtree rooted at o into a perfectly balanced one that holds
// only the live nodes (cnt != 0); lazily deleted nodes are dropped here.
// If no live node remains, the subtree is detached (o = 0) instead of asking
// build() for an empty range.
void rebuild(int& o)
{
    line = 0;
    flatten(o);
    if (line == 0)
    {
        o = 0;
        return;
    }
    build(1, line, o);
}
// Rebuilds the subtree rooted at o (o is the parent's child link or root)
// when it violates the scapegoat invariants:
//   * one child holds at least alpha of the subtree's nodes, or
//   * at most alpha of the subtree's nodes are still alive.
// A subtree in which every node has been deleted is detached directly
// (o = 0): flattening it would yield no live nodes to rebuild from.
// Unlike before, a subtree whose own root is deleted can still be rebuilt.
inline void judge(int& o)
{
    if (!o) return;
    if (size_delet[o] == 0)
    {
        o = 0;
        return;
    }
    const double limit = alpha * size_fact[o];
    const bool unbalanced = max(size_fact[lc[o]], size_fact[rc[o]]) >= limit;
    const bool too_many_deleted = size_delet[o] <= limit;
    if (unbalanced || too_many_deleted) rebuild(o);
}
void insert(int& o, int x)
{
if (!o)
{
o = ++num;
val[o] = x;
cnt[o] = 1;
size[o] = size_fact[o] = size_delet[o] = 1;
return;
}
if (val[o] == x)
++cnt[o];
else if (x < val[o])
insert(lc[o], x);
else if (x > val[o])
insert(rc[o], x);
update_size(o);
judge(o);
return;
}
// Removes one copy of x from the subtree rooted at o by decrementing its
// multiplicity (lazy deletion; the node stays until a rebuild drops it).
// If x is not stored, the search ends at an empty link and nothing changes.
// This replaces the old behaviour of recursing into node 0 forever, or
// decrementing cnt[0] when x == 0. cnt is never driven below 0.
void delet(int& o, int x)
{
    if (!o) return;
    if (x == val[o])
    {
        if (cnt[o] > 0) --cnt[o];
    }
    else if (x < val[o])
        delet(lc[o], x);
    else
        delet(rc[o], x);
    update_size(o);
    judge(o);
}
int query_rank(int o, int x)
{
if (val[o] == x) return size[lc[o]] + 1;
if (x < val[o]) return query_rank(lc[o], x);
if (x > val[o]) return size[lc[o]] + cnt[o] + query_rank(rc[o], x);
}
int query_num(int o, int k)
{
if (size[lc[o]] >= k) return query_num(lc[o], k);
if (size[lc[o]] + cnt[o] < k) return query_num(rc[o], k - size[lc[o]] - cnt[o]);
return val[o];
}
inline int query_pre(int x)
{
int o = root;
int res = -2147483647;
while (o)
{
if (val[o] < x)
{
if (cnt[o]) res = val[o];
o = rc[o];
}
else if (val[o] >= x) o = lc[o];
}
return res;
}
inline int query_sub(int x)
{
int o = root;
int res = 2147483647;
while (o)
{
if (val[o] > x)
{
if (cnt[o]) res = val[o];
o = lc[o];
}
else if (val[o] <= x) o = rc[o];
}
return res;
}
int n;
// Reads the number of operations, then processes each "op x" pair:
//   1 insert x, 2 delete one copy of x, 3 rank of x, 4 value with rank x,
//   5 predecessor of x, 6 successor of x.
// Queries 3-6 print their answer followed by a newline; other op codes are ignored.
signed main(void)
{
    read(n);
    while (n--)
    {
        int op, x;
        read(op);
        read(x);
        if (op == 1)
        {
            insert(root, x);
        }
        else if (op == 2)
        {
            delet(root, x);
        }
        else if (op >= 3 && op <= 6)
        {
            int answer;
            if (op == 3)
                answer = query_rank(root, x);
            else if (op == 4)
                answer = query_num(root, x);
            else if (op == 5)
                answer = query_pre(x);
            else
                answer = query_sub(x);
            write(answer);
            putchar('\n');
        }
    }
    return 0;
}
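// Because duplicate values can be inserted, three size arrays are kept: size[o] counts the values in o's subtree (with multiplicity), size_fact[o] counts all nodes in the subtree (including lazily deleted ones), and size_delet[o] counts the nodes that have not been deleted. Everything else should be fine, I think? Asking for help.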