萌新初学Splay求助
查看原帖
萌新初学Splay求助
339846
RuntimeErr（楼主） 2021/1/9 11:36

直角三角形，只有 60 分，测试点 6~10 WA 了。

#include<cstdio>
// Size of the static node pool. newnode() is the only allocator and nodes are
// never reused, and each operation allocates at most one node, so this is
// presumably sized for up to ~1e5 operations (TODO confirm against the limits).
// Index 0 is reserved as the null sentinel; real node ids start at 1.
const int N = 100000 + 10;

// ident(x, fa): 1 if x is the right child of fa, 0 otherwise.
#define ident(x, fa) (spl[(fa)].ch[1] == (x))
// Child / parent accessors for node p; each expands to an assignable lvalue.
#define ls(p) spl[(p)].ch[0]
#define rs(p) spl[(p)].ch[1]
#define Fa(p) spl[(p)].fa

// One splay-tree node, stored in the pool spl[]. A link value of 0 means "none".
struct Splay {
    int ch[2], fa;       // children (0 = left, 1 = right) and parent
    int cnt, size, val;  // copies of val in this node, total copies in subtree, key
} spl[N];

// n: number of operations still to read; cnt: last allocated node id;
// rt: current root id (0 when the tree is empty).
int n, cnt, rt;

// Recompute the subtree multiplicity of p: its own copy count plus the sizes
// of both children (the sentinel node 0 contributes size 0).
inline void update(int p){
    int total=spl[p].cnt;
    total+=spl[ls(p)].size;
    total+=spl[rs(p)].size;
    spl[p].size=total;
}
// Link p as child s (0 = left, 1 = right) of fa, and record fa as p's parent.
// rotate() may pass p == 0 (an empty subtree) or fa == 0 (x becomes the root).
// Node 0 is the null sentinel, so it is never given children or a parent.
// That keeps spl[0] clean, and a stray walk starting at node 0 stops at once
// instead of following stale links into the live tree.
inline void connect(int p,int fa,int s){
    if(fa)spl[fa].ch[s]=p;
    if(p)Fa(p)=fa;
}
inline void newnode(int &now,int fa,int val){
    spl[now=++cnt].val=val;
    Fa(now)=fa;
    spl[now].size=spl[now].cnt=1;
}
// Rotate x one level up, over its parent.
// x's subtree on the side facing the parent (the keys between x and the
// parent) moves to the parent. The parent then becomes x's child on the
// opposite side, and x takes the parent's place under the grandparent.
// Sizes are refreshed bottom-up: the parent first, then x.
inline void rotate(int x){
    int parent=Fa(x);
    int grand=Fa(parent);
    int side=ident(x,parent);           // which child x is of parent
    int parentSide=ident(parent,grand); // which child parent is of grand
    int inner=spl[x].ch[side^1];        // subtree handed over to parent
    connect(inner,parent,side);
    connect(x,grand,parentSide);
    connect(parent,x,side^1);
    update(parent);
    update(x);
}
// Rotate x upward until its parent is `top`. With top == 0, x becomes the
// root and rt is updated. The zig-zig case (x and its parent on the same side)
// rotates the parent first. The zig-zag case rotates x twice.
inline void splay(int x,int top){
    if(top==0)rt=x;
    for(int f=Fa(x);f!=top;f=Fa(x)){
        int ff=Fa(f);
        if(ff!=top){
            if(ident(x,f)!=ident(f,ff))rotate(x);
            else rotate(f);
        }
        rotate(x);
    }
}
// Remove one copy of the key held by node x.
// x is first splayed to the root. If it holds several copies, only its count
// drops. Otherwise x is unlinked and its subtrees are joined under x's
// in-order successor (or the left subtree alone becomes the tree).
inline void delnode(int x){
    splay(x,0);
    if(spl[x].cnt>1){
        spl[x].cnt--;
        update(x); // the root's size must reflect the decremented count
        return;
    }
    if(rs(x)){
        // Splay the minimum of the right subtree directly under x. It has no
        // left child, so x's left subtree can be hung there.
        int p=rs(x);
        while(ls(p))p=ls(p);
        splay(p,x);
        connect(ls(x),p,0);
        rt=p,Fa(p)=0;
        update(rt);
    }else{
        // No right subtree: the (possibly empty) left subtree becomes the tree.
        rt=ls(x),Fa(ls(x))=0;
    }
}
void ins(int val,int &now=rt,int fa=0){
    if(!now)newnode(now,fa,val),splay(now,0);
    else if(spl[now].val>val)ins(val,ls(now),now);
    else if(spl[now].val<val)ins(val,rs(now),now);
    else spl[now].cnt++,splay(now,0);
}
// Remove one copy of val, searching from `now` (the root by default).
// Does nothing if val is not stored, instead of falling into the sentinel.
void del(int val,int now=rt){
    while(now&&spl[now].val!=val)
        now=val<spl[now].val?ls(now):rs(now);
    if(now)delnode(now);
}
inline int rank(int x){
    int now=rt;
    while(now){
        if(x<spl[now].val)now=ls(now);
        else if(x^spl[now].val)now=rs(now);
        else break;
    }
    splay(now,0);return spl[ls(now)].size+1;
}
inline int num(int rk){
    int now=rt;
    while(now){
        if(spl[ls(now)].size+1==rk)break;
        if(spl[ls(now)].size>=rk)now=ls(now);
        else {
            rk-=spl[ls(now)].size+spl[now].cnt;
            now=rs(now);
        }
    }
    return spl[now].val;
}
// Largest key strictly smaller than the root's key: the rightmost node of the
// root's left subtree. main() inserts the query value first so it is the root.
// NOTE(review): assumes that left subtree is non-empty (a predecessor exists).
inline int pre(){
    int p=ls(rt);
    for(;rs(p);p=rs(p)){}
    return spl[p].val;
}
// Smallest key strictly greater than the root's key: the leftmost node of the
// root's right subtree. main() inserts the query value first so it is the root.
// NOTE(review): assumes that right subtree is non-empty (a successor exists).
inline int nxt(){
    int p=rs(rt);
    for(;ls(p);p=ls(p)){}
    return spl[p].val;
}
// Read n operations "op x" and apply them:
//   1: insert x              2: delete one copy of x
//   3: print rank of x       4: print value of rank x
//   5: print predecessor     any other op: print successor
// Ops 5/6 insert x temporarily so it sits at the root, query, then remove it.
// Stops early if the input ends or is malformed, instead of reusing stale op/x.
int main(){
    if(scanf("%d",&n)!=1)return 0;
    int op,x;
    while(n--){
        if(scanf("%d%d",&op,&x)!=2)break;
        switch(op){
        case 1: ins(x); break;
        case 2: del(x); break;
        case 3: printf("%d\n",rank(x)); break;
        case 4: printf("%d\n",num(x)); break;
        case 5: ins(x); printf("%d\n",pre()); del(x); break;
        default: ins(x); printf("%d\n",nxt()); del(x); break;
        }
    }
    return 0;
}
2021/1/9 11:36
加载中...