先放代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
ll read(){
ll x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
void write(ll x){
if(x<0) putchar('-'),x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
const int N=1e5+10;
int n,op,x;
#define pl a[p].son[0]
#define pr a[p].son[1]
struct Splay{
int root;
int tot,size;
int INF;
int nc[N];
struct Tree{
int fa,son[2];
int val;
int size,cnt;
}a[N],e;
Splay(){
INF=0x7fffffff;
root=tot=0;
e.fa=e.son[0]=e.son[1]=0;
e.val=e.size=e.cnt=0;
for(int i=1;i<N;i++)
a[nc[i]=i]=e;
}
int get_type(int p){
return a[a[p].fa].son[1]==p;
}
void pushup(int p){
a[p].size=a[pl].size+a[pr].size+a[p].cnt;
}
void connect(int p,int fp,int type){
a[p].fa=fp;
a[fp].son[type]=p;
}
void rotate(int p){
int fp=a[p].fa,type_p=get_type(p);
int ffp=a[fp].fa,type_fp=get_type(fp);
connect(a[p].son[type_p^1],fp,type_p);
connect(fp,p,type_p^1);
connect(p,ffp,type_fp);
pushup(fp);
pushup(p);
}
void splay(int p,int to){
//int to=a[t].fa;
int fp;
while(a[p].fa^to)
{
fp=a[p].fa;
if(a[fp].fa==to)
rotate(p);
else if(get_type(p)^get_type(fp))
rotate(p),rotate(p);
else
rotate(fp),rotate(p);
}
if(!to)
root=p;
}
int get_new(int val,int fp,int type){
int p=nc[++tot];
a[p]=e;
a[p].val=val;
a[p].size=a[p].cnt=1;
connect(p,fp,type);
return p;
}
void insert(int val){
int p=root;
if(!size){
size++;
root=get_new(val,0,1);
return;
}
size++;
while(1){
a[p].size++;
if(val==a[p].val){
a[p].cnt++;
break;
}
if(!a[p].son[val>a[p].val]){
p=get_new(val,p,val>a[p].val);
break;
}
p=a[p].son[val>a[p].val];
}
splay(p,0);
}
int get(int val){
int p=root;
while(a[p].val^val&&a[p].son[val>a[p].val])
p=a[p].son[val>a[p].val];
return p;
}
int get_pre(int val){
splay(get(val),0);
if(val>a[root].val)
return a[root].val;
int p=a[root].son[0];
while(pr)
p=pr;
return a[p].val;
}
int get_next(int val){
splay(get(val),0);
if(val<a[root].val)
return a[root].val;
int p=a[root].son[1];
while(pl)
p=pl;
return a[p].val;
}
void remove(int val){
int p=get(val);
if(a[p].val^val)
return;
size--;
splay(p,0);
if(a[p].cnt>1){
a[p].cnt--;
a[p].size--;
return;
}
int del=p;
if(!pl){
a[root=pr].fa=0;
nc[tot--]=del;
return;
}
int tl=pl,tr=pr;
while(a[tl].son[1])
tl=a[tl].son[1];
splay(tl,p);
connect(tr,tl,1);
connect(tl,0,1);
pushup(root=tl);
nc[tot--]=del;
}
int get_rank(int val){
splay(get(get_next(val-1)),0);
return a[a[root].son[0]].size+1;
}
int get_val(int rank){
if(rank>tot)
return INF;
int p=root;
while(1){
if(rank>a[pl].size&&rank<=a[pl].size+a[p].cnt)
break;
if(rank<=a[pl].size)
p=pl;
else{
rank-=a[pl].size+a[p].cnt;
p=pr;
}
}
splay(p,0);
return a[p].val;
}
void dfs(int p){
if(!p)
return;
dfs(pl);
cout<<p<<" "<<a[p].val<<" "<<a[p].son[0]<<" "<<a[p].son[1]<<endl;
dfs(pr);
}
}tree;
int main()
{
n=read();
while(n--){
op=read();x=read();
switch(op){
case 1:tree.insert(x);break;
case 2:tree.remove(x);break;
case 3:write(tree.get_rank(x));putchar('\n');break;
case 4:write(tree.get_val(x));putchar('\n');break;
case 5:write(tree.get_pre(x));putchar('\n');break;
case 6:write(tree.get_next(x));putchar('\n');break;
}
}
return 0;
}
WA了,错在 remove 函数,好像是回收内存出了问题。
void remove(int val){
int p=get(val);
if(a[p].val^val)
return;
size--;
splay(p,0);
if(a[p].cnt>1){
a[p].cnt--;
a[p].size--;
return;
}
int del=p;
if(!pl){
a[root=pr].fa=0;
nc[tot--]=del;
return;
}
int tl=pl,tr=pr;
while(a[tl].son[1])
tl=a[tl].son[1];
splay(tl,p);
connect(tr,tl,1);
connect(tl,0,1);
pushup(root=tl);
nc[tot--]=del;
}
把倒数第二行的 nc[tot--]=del;
去掉就AC了。
求助如何改?