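// Reported problem: build() appears to run fine, but things go wrong as soon as insert() is called.
// If any comments are missing, point them out in the replies and I will add them later.
// I'm still a beginner and could not track the bug down while debugging.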
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
using namespace std;
// N: capacity bound (not referenced by the visible code).
// INF: magnitude of the two sentinel keys (-INF / +INF) and the "no answer" value of the queries.
constexpr int N = 100010;
constexpr int INF = 100000000;
// Treap node: ordered as a BST by `key` and as a max-heap by the random priority `val`.
// Every member has a default initializer so a freshly allocated node never carries
// garbage child pointers (all the null checks on l / r depend on this).
struct Node
{
    Node *l = nullptr, *r = nullptr; // children; nullptr means an empty subtree
    int cnt = 0, size = 0;           // multiplicity of `key` / total multiplicity of the subtree
    int key = 0, val = 0;            // BST key / heap priority
}*root;                              // tree root (a global, so it starts as nullptr)
int n = 0; // number of operations; filled in by read(n) in main()
// Fast integer reader: stores the next integer found on stdin into x.
// Non-digit characters before the number are skipped; a '-' seen while skipping
// makes the result negative. If input ends before any digit, x is set to 0.
template<typename T> inline void read(T &x)
{
    x = 0;
    T f = 1;
    int ch = getchar(); // int, not char, so EOF can be told apart from a real character
    while (ch != EOF && !isdigit(ch)) // stop at EOF instead of spinning forever
    {
        if (ch == '-') f = -1;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch))
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    x *= f; // apply the sign (the original computed f but never used it)
}
// Allocates a single-node treap holding `key` with multiplicity 1 and a random priority.
// The node is value-initialized and its children are set to nullptr explicitly, so the
// null checks elsewhere (pushup, queries) see real empty subtrees instead of garbage.
Node* get_node(int key)
{
    Node *p = new Node();
    p->l = p->r = nullptr;
    p->key = key;
    p->val = rand();
    p->cnt = p->size = 1;
    return p;
}
// Recomputes p->size (total multiplicity of p's subtree) from its children.
// A missing child (nullptr) contributes 0 instead of being dereferenced.
void pushup(Node *p)
{
    int left = p->l ? p->l->size : 0;
    int right = p->r ? p->r->size : 0;
    p->size = left + right + p->cnt;
}
// Left rotation at p: p's right child q becomes the subtree root and p becomes q's left child.
// p must have a right child. It is taken by reference so the caller's link (root, or a
// parent's l / r) is redirected to q. With a by-value parameter, `p = q` was lost and
// the tree kept pointing at the demoted node.
void zag(Node *&p)
{
    Node *q = p->r;
    p->r = q->l;
    q->l = p;
    p = q;
    pushup(p->l); // the old root is now below, so update it first
    pushup(p);
}
// Right rotation at p: p's left child q becomes the subtree root and p becomes q's right child.
// p must have a left child. It is taken by reference so the caller's link (root, or a
// parent's l / r) is redirected to q. With a by-value parameter, `p = q` was lost.
void zig(Node *&p)
{
    Node *q = p->l;
    p->l = q->r;
    q->r = p;
    p = q;
    pushup(p->r); // the old root is now below, so update it first
    pushup(p);
}
void build() //建树并设置哨兵节点
{
root = get_node(-INF);
Node *q = get_node(INF);
root->r = q;
if (q->val > root->val)
zag(root);
pushup(root);
}
// Inserts one occurrence of `key` into the treap rooted at p.
// p is taken by reference so a newly created leaf is actually linked into its parent
// (or becomes root), and rotations below can replace the subtree root in place.
void insert(Node *&p, int key)
{
    if (!p) // reached an empty subtree: hang a new node here
    {
        p = get_node(key);
        return; // a fresh leaf already has size 1; nothing to recompute
    }
    if (p->key == key)
    {
        p->cnt++; // key already present: just bump its multiplicity
    }
    else if (key < p->key) // goes to the left subtree
    {
        insert(p->l, key);
        if (p->l->val > p->val) zig(p); // child has higher priority: rotate it up
    }
    else // goes to the right subtree
    {
        insert(p->r, key);
        if (p->r->val > p->val) zag(p);
    }
    pushup(p); // sizes along the insertion path grew by one
}
// Removes one occurrence of `key` from the treap rooted at p (no-op if the key is absent).
// p is taken by reference so a deleted leaf is really unlinked from its parent and
// rotations replace the subtree root in place.
void remove(Node *&p, int key)
{
    if (!p) return;
    if (p->key == key)
    {
        if (p->cnt > 1)
        {
            p->cnt--;
        }
        else if (p->l || p->r)
        {
            // Rotate the child with the higher priority up (keeps heap order), then keep
            // deleting in the subtree the key was pushed down into. When there is no
            // left child, the right one must exist; check p->r before reading p->r->val.
            if (!p->l || (p->r && p->r->val > p->l->val))
            {
                zag(p);
                remove(p->l, key);
            }
            else
            {
                zig(p);
                remove(p->r, key);
            }
        }
        else
        {
            // Leaf: free it and unlink it. There is no subtree left to pushup.
            delete p;
            p = nullptr;
            return;
        }
    }
    else if (p->key > key)
    {
        remove(p->l, key);
    }
    else
    {
        remove(p->r, key);
    }
    pushup(p);
}
// Returns 1 + (total multiplicity of keys strictly less than `key`) in subtree p.
// If `key` is present this is its rank; if absent it is the rank it would have.
// The original returned INF at an empty subtree, so an absent key produced garbage.
int get_rank_by_key(Node *p, int key)
{
    if (!p) return 1;
    int lsize = p->l ? p->l->size : 0; // null child counts as an empty subtree
    if (p->key == key) return lsize + 1;
    if (p->key > key) return get_rank_by_key(p->l, key);
    return lsize + p->cnt + get_rank_by_key(p->r, key); // left subtree and p are all < key
}
// Returns the key holding 1-based position `rank` (counting multiplicities) in subtree p,
// or INF when rank falls outside [1, size].
int get_key_by_rank(Node *p, int rank)
{
    if (!p) return INF;
    int lsize = p->l ? p->l->size : 0; // null child counts as an empty subtree
    if (lsize >= rank) return get_key_by_rank(p->l, rank);
    if (lsize + p->cnt >= rank) return p->key;
    // Skip the left subtree and p's copies. The answer is a key from the right subtree,
    // returned as-is (the original wrongly added lsize + cnt to that key).
    return get_key_by_rank(p->r, rank - lsize - p->cnt);
}
// Largest key strictly less than x in subtree p, or -INF if there is none.
// Walks down one path: a key < x is a candidate and we look right for a larger one;
// otherwise we look left.
int get_prev(Node *p, int x)
{
    int best = -INF;
    while (p)
    {
        if (p->key < x)
        {
            best = max(best, p->key);
            p = p->r;
        }
        else
        {
            p = p->l;
        }
    }
    return best;
}
// Smallest key strictly greater than x in subtree p, or INF if there is none.
// Walks down one path: a key > x is a candidate and we look left for a smaller one;
// otherwise we look right.
int get_nex(Node *p, int x)
{
    int best = INF;
    while (p)
    {
        if (p->key > x)
        {
            best = min(best, p->key);
            p = p->l;
        }
        else
        {
            p = p->r;
        }
    }
    return best;
}
// Reads n operations "opt x" and applies each to the treap:
// 1 insert, 2 remove, 3 rank of x, 4 key at rank x, 5 predecessor, 6 successor.
int main()
{
    read(n);
    build();
    while (n--)
    {
        int opt, x;
        read(opt);
        read(x);
        if (opt == 1)
            insert(root, x);
        else if (opt == 2)
            remove(root, x);
        else if (opt == 3)
            printf("%d\n", get_rank_by_key(root, x) - 1); // -1: don't count the -INF sentinel
        else if (opt == 4)
            printf("%d\n", get_key_by_rank(root, x + 1)); // +1: skip past the -INF sentinel
        else if (opt == 5)
            printf("%d\n", get_prev(root, x));
        else if (opt == 6)
            printf("%d\n", get_nex(root, x));
    }
    return 0;
}