rt,这里用的是oiwiki里的合并实现平衡维护,在倒数第二个点mle了。
#include<iostream>
#include<cstdio>
#include<cstring>
#define lc(u) (tr[u].lc)
#define rc(u) (tr[u].rc)
#define W(u) (tr[u].leaf)
#define val(u) (tr[u].val)
#define del(u) (pts[++cur]=(u))
#define pd(w1,w2) (min(w1,w2)>=alp*((w1)+(w2)))
//#define pd(w1,w2) ((w1)<=3*(w2)||(w2)<=3*(w1))
using namespace std;
const int N=1e5+10;
const double alp=0.292;
struct Node{int lc,rc,leaf,val;}tr[(N<<1)+20];
static int pts[N];
int tot,cur,n,m,ans,rt;
inline int max(int x,int y){return x>y?x:y;}
inline int min(int x,int y){return x<y?x:y;}
void fastio(){
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
}
inline int New(int v){
int u=cur?pts[cur--]:(++tot);
tr[u]={0,0,1,v};
return u;
}
inline void pushup(int u){
W(u)=W(lc(u))+W(rc(u));
val(u)=max(val(lc(u)),val(rc(u)));
}
int merge(int x,int y){
if(!x||!y) return x^y;
if(pd(W(x),W(y))){
int u=New(0);
lc(u)=x;rc(u)=y;
return pushup(u),u;
}
if(W(x)>=W(y)){
if(pd(W(lc(x)),W(rc(x))+W(y)))
return rc(x)=merge(rc(x),y),pushup(x),x;
lc(x)=merge(lc(x),lc(rc(x)));
del(rc(x));rc(x)=merge(rc(rc(x)),y);
return pushup(x),x;
}
if(pd(W(x)+W(lc(y)),W(rc(y))))
return lc(y)=merge(x,lc(y)),pushup(y),y;
rc(y)=merge(rc(lc(y)),rc(y));
del(lc(y));lc(y)=merge(x,lc(lc(y)));
return pushup(y),y;
}
inline void maintain(int &u){
if(W(u)==1)return;
if(pd(W(lc(u)),W(rc(u))))return;
int x=lc(u),y=rc(u);
del(u);u=merge(x,y);
}
void insert(int &u,int v){
if(!rt){rt=New(v);return;}
if(W(u)==1){
lc(u)=New(val(u)),rc(u)=New(v);
if(val(u)>v)lc(u)^=rc(u)^=lc(u)^=rc(u);
return pushup(u);
}
if(val(lc(u))>=v)insert(lc(u),v);
else insert(rc(u),v);
pushup(u);maintain(u);
}
void erase(int &u,int v) {
if(!u)return;
if (val(lc(u))>=v){
if(W(lc(u))==1)del(u),u=rc(u);
else erase(lc(u),v),pushup(u),maintain(u);
}
else {
if(W(rc(u))==1)del(u),u=lc(u);
else erase(rc(u),v),pushup(u),maintain(u);
}
}
inline int Rank(int v){
int u=rt,ans=0;
while(W(u)!=1){
if(val(lc(u))>=v)u=lc(u);
else ans+=W(lc(u)),u=rc(u);
}
return ans+1;
}
inline int kth(int k){
int u=rt;
while(W(u)!=1){
if(W(lc(u))>=k)u=lc(u);
else k-=W(lc(u)),u=rc(u);
}
return val(u);
}
inline int pre(int v){
int u=rt;
while(W(u)!=1){
if(val(lc(u))>=v)u=lc(u);
else ans=val(lc(u)),u=rc(u);
}
if(val(u)>=v)return ans;
else return val(u);
}
inline int nxt(int v){
int u=rt;
while(W(u)!=1){
if(val(lc(u))<=v)u=rc(u);
else ans=val(rc(u)),u=lc(u);
}
if(val(u)<=v)return ans;
else return val(u);
}
signed main(){
fastio();
cin>>m;
int ch,x;
for(int i=1;i<=m;i++){
cin>>ch>>x;
if(ch==1)insert(rt,x);
if(ch==2)erase(rt,x);
if(ch==3)cout<<Rank(x)<<"\n";
if(ch==4)cout<<kth(x)<<"\n";
if(ch==5)cout<<pre(x)<<"\n";
if(ch==6)cout<<nxt(x)<<"\n";
}
return 0;
}
感谢各位大佬的帮助。