splay求助
查看原帖
splay求助
369598
邈云汉楼主2021/8/22 15:42
#include<cstdio>
using namespace std;
const int N=1e6+5;
const int INF=2e9;
struct splay_tree{
	int p,son[10],v,cnt,size;
	void init(int _v,int _p)
	{
		v=_v,p=_p;
		size=1,cnt=1;
	}
}tr[N];
int root,idx;
void pushup(int x)
{
	tr[x].size=tr[tr[x].son[0]].size+tr[tr[x].son[1]].size+tr[x].cnt;
}
void rotate(int x)
{
	int y=tr[x].p,z=tr[y].p;
	int k=(tr[y].son[1]==x);
	tr[z].son[tr[z].son[1]==y]=x,tr[x].p=z;
	tr[y].son[k]=tr[x].son[k^1],tr[tr[x].son[k^1]].p=y;
	tr[x].son[k^1]=y,tr[y].p=x;
	pushup(y),pushup(x);
}
void splay(int x,int p)
{
	while(tr[x].p!=p)
	{
		int y=tr[x].p,z=tr[y].p;
		if(z!=p)
		{
			if((tr[y].son[1]==x)^(tr[z].son[1]==y))rotate(x);
			else rotate(y);
		}
		rotate(x);
	}
	if(!p)root=x;
}
void Insert(int x)
{
	int u=root,p=0;
	while(u)
	{
		if(tr[u].v==x)
		{
			tr[u].cnt++;
			pushup(u);
			splay(u,0);
			goto next_;
		}
		p=u,u=tr[u].son[x>tr[u].v];
	}
	u=++idx;
	if(p)tr[p].son[x>tr[p].v]=u;
	tr[u].init(x,p);
	splay(u,0);
	next_:;
}
void Remove(int x)
{
	int u=root,p=0;
	int k;
	while(u)
	{
		if(tr[u].v==x)break;
		p=u,u=tr[u].son[x>tr[u].v];
		k=(x>tr[p].v);
	}
	if(!u)return;
	else if(tr[u].cnt==1)
	{
		if(!tr[u].son[0]&&!tr[u].son[1])tr[p].son[k]=0,tr[u].p=0,pushup(p),splay(p,0);
		else if(!tr[u].son[1])tr[p].son[k]=tr[u].son[0],tr[tr[u].son[0]].p=p,tr[u].p=0,pushup(p),splay(p,0);
		else if(!tr[u].son[0])tr[p].son[k]=tr[u].son[1],tr[tr[u].son[1]].p=p,tr[u].p=0,pushup(p),splay(p,0);
		else
		{
			int l=tr[u].son[0],r=tr[u].son[1];
			while(tr[l].son[1])l=tr[l].son[1];
			while(tr[r].son[0])r=tr[r].son[0];
			splay(l,0),splay(r,root);
			tr[r].son[0]=0,tr[tr[r].son[0]].p=0,pushup(r),pushup(root);
		}
	}
	else tr[u].cnt--,pushup(u),splay(u,0);
}
int get_rank(int p,int x)
{
	if(p==0)return 0;
	if(x==tr[p].v)return tr[tr[p].son[0]].size+1;
	else if(x<tr[p].v)return get_rank(tr[p].son[0],x);
	else return get_rank(tr[p].son[1],x)+tr[tr[p].son[0]].size+tr[p].cnt;
}
int find_rank(int p,int x)
{
	if(p==0)return INF;
	if(tr[tr[p].son[0]].size>=x)return find_rank(tr[p].son[0],x);
	if(tr[tr[p].son[0]].size+tr[p].cnt>=x)return tr[p].v;
	return find_rank(tr[p].son[1],x-tr[tr[p].son[0]].size-tr[p].cnt);
}
int get_pre(int x)
{
	int ans=1;
	int u=root,p=0;
	while(u)
	{
		if(tr[u].v==x)break;
		p=u,u=tr[u].son[x>tr[u].v];
		if(tr[p].v<x&&tr[p].v>tr[ans].v)ans=p;
	}
	if(!u||!tr[u].son[0])return tr[ans].v;
	u=tr[u].son[0];
	while(tr[u].son[1])u=tr[u].son[1];
	return tr[u].v;
}
int get_next(int x)
{
	int ans=2;
	int u=root,p=0;
	while(u)
	{
		if(tr[u].v==x)break;
		p=u,u=tr[u].son[x>tr[u].v];
		if(tr[p].v>x&&tr[p].v<tr[ans].v)ans=p;
	}
	if(!u||!tr[u].son[1])return tr[ans].v;
	u=tr[u].son[1];
	while(tr[u].son[0])u=tr[u].son[0];
	return tr[u].v;
}
int main()
{
	Insert(-INF),Insert(INF);
	int n;
	scanf("%d",&n);
	while(n--)
	{
		int op,x;
		scanf("%d%d",&op,&x);
		if(op==1)Insert(x);
		if(op==2)Remove(x);
		if(op==3)printf("%d\n",get_rank(root,x)-1);
		if(op==4)printf("%d\n",find_rank(root,x+1));
		if(op==5)printf("%d\n",get_pre(x));
		if(op==6)printf("%d\n",get_next(x));
	}
	return 0;
}
2021/8/22 15:42
加载中...