感觉没什么问题,对着书看了一遍,没看出来,求大佬帮忙!
#include <iostream>
#include <cstdio>
#include <cstdlib>
#define re register
#define Google namespace
#define Cookies std
using Google Cookies;
inline int read()
{
re int x = 0, f = 0;
re char c = getchar();
while (c < '0' || c > '9')
{
f |= c == '-';
c = getchar();
}
while (c >= '0' && c <= '9')
{
x = (x << 3) + (x << 1) + (c ^ 48);
c = getchar();
}
return f ? -x : x;
}
const int MAXN = 1e5 + 5;
const int INF = 0x7fffffff;
struct treap
{
int l, r, val, dat, cnt, siz;
}t[MAXN];
int rt, tot;
int add(int val)
{
t[++tot] = treap{0, 0, val, rand(), 1, 1};
return tot;
}
void pushup(int pos)
{
t[pos].siz = t[t[pos].l].siz + t[t[pos].r].siz + t[pos].cnt;
}
void build()
{
add(-INF), add(INF);
rt = 1;
t[1].r = 2;
pushup(rt);
}
void zig(int &pos)
{
int lson = t[pos].l;
t[pos].l = t[lson].r, t[lson].r = pos;
pos = lson;
}
void zag(int &pos)
{
int rson = t[pos].r;
t[pos].r = t[rson].l, t[rson].l = pos;
pos = rson;
}
void insert(int &pos, int x)
{
if (!pos)
{
pos = add(x);
return;
}
if (x == t[pos].val)
{
t[pos].cnt++;
pushup(pos);
return;
}
else if (x < t[pos].val)
{
insert(t[pos].l, x);
if (t[pos].dat < t[t[pos].l].dat) //不满足,右旋
{
zig(pos);
}
}
else
{
insert(t[pos].r, x);
if (t[pos].dat < t[t[pos].r].dat)
{
zag(pos);
}
}
pushup(pos);
}
void remove(int &pos, int x)
{
if (!pos)
{
return;
}
if (x == t[pos].val)
{
if (t[pos].cnt >= 2)
{
t[pos].cnt--; //减少副本数
pushup(pos);
return;
}
else if (t[pos].l || t[pos].r) //向下旋转
{
if (!t[pos].r || t[t[pos].l].dat > t[t[pos].r].dat)
{
zig(pos);
remove(t[pos].r, x);
}
else
{
zag(pos);
remove(t[pos].l, x);
}
pushup(pos);
}
else
{
pos = 0; //叶节点删除
}
return;
}
else if (x < t[pos].val)
{
remove(t[pos].l, x);
}
else
{
remove(t[pos].r, x);
}
pushup(pos);
}
int getrank(int pos, int x)
{
if (!pos)
{
return 0;
}
if (x == t[pos].val)
{
return t[t[pos].l].siz + 1;
}
else if (x < t[pos].val)
{
return getrank(t[pos].l, x);
}
else
{
return getrank(t[pos].r, x) + t[t[pos].l].siz + t[pos].cnt;
}
}
int getval(int pos, int x)
{
if (!pos)
{
return INF;
}
if (x <= t[t[pos].l].siz)
{
return getval(t[pos].l, x);
}
else if (x <= t[t[pos].l].siz + t[pos].cnt)
{
return t[pos].val;
}
else
{
return getval(t[pos].r, x - t[t[pos].l].siz - t[pos].cnt);
}
}
int pre(int x)
{
int ans = 1, pos = rt;
while (pos)
{
if (x == t[pos].val)
{
if (t[pos].l)
{
pos = t[pos].l;
while (t[pos].r)
{
pos = t[pos].r; //左子树往上走
}
ans = pos;
}
break;
}
if (x > t[pos].val && t[pos].val > t[ans].val)
{
ans = pos;
}
if (x < t[pos].val)
{
pos = t[pos].l;
}
else
{
pos = t[pos].r;
}
}
return t[ans].val;
}
int nxt(int x)
{
int ans = 2, pos = rt;
while (pos)
{
if (x == t[pos].val)
{
if (t[pos].r)
{
pos = t[pos].r;
while (t[pos].l)
{
pos = t[pos].l;
}
ans = pos;
}
break;
}
if (x < t[pos].val && t[pos].val < t[ans].val)
{
ans = pos;
}
if (x < t[pos].val)
{
pos = t[pos].l;
}
else
{
pos = t[pos].r;
}
}
return t[ans].val;
}
int main()
{
build();
int n = read();
while (n--)
{
int op = read(), x = read();
switch (op)
{
case 1:
insert(rt, x);
break;
case 2:
remove(rt, x);
break;
case 3:
printf("%d\n", getrank(rt, x) - 1);
break;
case 4:
printf("%d\n", getval(rt, x + 1));
break;
case 5:
printf("%d\n", pre(x));
break;
case 6:
printf("%d\n", nxt(x));
break;
}
}
return 0;
}