记录:https://www.luogu.com.cn/record/59327178 过了第一个点 其他MLE WA 代码:
#include<bits/stdc++.h>
using namespace std;
const int M=1e5+10;
struct node{
int lc,rc,val,num,size;
}tr[M];
int root,cntp;
int newval(int val){
cntp++;
tr[cntp].lc=tr[cntp].rc=0;
tr[cntp].val=val;
tr[cntp].size=1;
return cntp;
}
void update(int k){
tr[k].size=tr[tr[k].lc].size+tr[tr[k].rc].size+1;
}
void left_rotate(int &p){
int q=tr[p].rc;
tr[p].rc=tr[q].lc;
tr[q].lc=p;
tr[q].size=tr[p].size;
update(p);
p=q;
}
void right_rotate(int &p){
int q=tr[p].lc;
tr[p].lc=tr[q].rc;
tr[q].rc=p;
tr[q].size=tr[p].size;
update(p);
p=q;
}
void maintain(int &p,bool flag){
if(!p) return;
if(!flag){
if(tr[tr[tr[p].lc].lc].size>tr[tr[p].rc].size) right_rotate(p);//LL
else if(tr[tr[tr[p].lc].rc].size>tr[tr[p].rc].size) left_rotate(tr[p].lc),right_rotate(p);//LR
else return;
}else{
if(tr[tr[tr[p].rc].rc].size>tr[tr[p].lc].size) left_rotate(p);//RR
else if(tr[tr[tr[p].rc].lc].size>tr[tr[p].lc].size) right_rotate(tr[p].rc),left_rotate(p);//RL
else return;
}
maintain(tr[p].lc,false);
maintain(tr[p].rc,true);
maintain(p,false);
maintain(p,true);
}
void insert(int &p,int val){
if(!p){
p=newval(val);
return;
}
tr[p].size++;
if(tr[p].val>val) insert(tr[p].lc,val);
else insert(tr[p].rc,val);
maintain(p,tr[p].val<=val);
}
void remove(int &p,int val){
if(!p) return;
tr[p].size--;
if(tr[p].val==val){
if(!tr[p].lc||tr[p].rc) p=tr[p].lc+tr[p].rc;
else{
int cur=tr[p].rc;
while(tr[cur].lc) cur=tr[cur].lc;
tr[p].val=tr[cur].val;
remove(tr[p].rc,tr[cur].val);
}
}else if(val<tr[p].val) remove(tr[p].lc,val);
else remove(tr[p].rc,val);
}
int getpre(int &p,int q,int val){
if(!p) return tr[q].val;
if(tr[p].val<val) return getpre(tr[p].rc,p,val);
return getpre(tr[p].lc,q,val);
}
int getsuf(int &p,int q,int val){
if(!p) return tr[q].val;
if(tr[p].val>val) return getsuf(tr[p].lc,p,val);
return getpre(tr[p].rc,q,val);
}
int getrank(int &p,int val){
if(val<tr[p].val) return getrank(tr[p].lc,val);
else if(val>tr[p].val) return getrank(tr[p].rc,val);
return tr[tr[p].lc].size+1;
}
int getval(int &p,int k){
int ls=tr[tr[p].lc].size+1;
if(ls==k) return tr[p].val;
else if(ls<k) return getval(tr[p].rc,k-ls);
return getval(tr[p].lc,k);
}
void print(int p){
if(tr[p].lc) print(tr[p].lc);
printf("%d ",tr[p].val);
if(tr[p].rc) print(tr[p].rc);
}
int n;
int main(){
scanf("%d",&n);
for(int i=1,opt,x;i<=n;i++){
scanf("%d%d",&opt,&x);
switch(opt){
case 1:insert(root,x);break;
case 2:remove(root,x);break;
case 3:printf("%d\n",getrank(root,x));break;
case 4:printf("%d\n",getval(root,x));break;
case 5:printf("%d\n",getpre(root,0,x));break;
case 6:printf("%d\n",getsuf(root,0,x));break;
default:break;
}
/*
putchar('\n');
print(root);
putchar('\n');
printf("root:%d\n",root);
*/
}
return 0;
}
跪求大佬帮忙