大佬快来看,蒟蒻求助了!(Splay写挂了)
查看原帖
大佬快来看,蒟蒻求助了!(Splay写挂了)
99506
_LHF_楼主2020/5/5 12:07

蒟蒻最近开始学习Splay树,写得还没什么问题,可是一旦交上去——TLE!请问是代码的常数问题,还是代码本身有错。

#include<cstdio>
#define N 200010
using namespace std;
struct splay_tree{
	int ch[2],fa,size,cnt,id;
} tr[N];
int root,len;
void update(int x)
{
	tr[x].size=tr[tr[x].ch[1]].size+tr[tr[x].ch[0]].size+tr[x].cnt;
}
void rotate(int x)
{
	int y=tr[x].fa;
	int z=tr[y].fa,k=(tr[y].ch[1]==x);
	tr[z].ch[tr[z].ch[1]==y]=x;
	tr[x].fa=z;
	tr[y].ch[k]=tr[x].ch[k^1];
	tr[tr[x].ch[k^1]].fa=y;
	tr[x].ch[k^1]=y;
	tr[y].fa=x;
	update(y);
	update(x);
}
void splay(int x,int g)
{
	while(tr[x].fa!=g)
	{
		int y=tr[x].fa;
		int z=tr[y].fa;
		if(z!=g)
		{
			if((tr[z].ch[0]==y)^(tr[y].ch[0]==x)) rotate(x);
			else rotate(y);
		}
		rotate(x);
	}
	if(g==0) root=x;
}
void find(int x)
{
	int u=root;
	if(!u) return;
	while(tr[u].ch[x>tr[u].id]&&tr[u].id!=x)
	{
		u=tr[u].ch[x>tr[u].id];
	}
	splay(u,0);
}
void insert(int x)
{
	int f=0,c=root;
	while(c&&tr[c].id!=x)
	{
		f=c;
		c=tr[c].ch[x>tr[c].id];
	}
	if(c) tr[c].cnt++;
	else
	{
		c=++len;
		if(f) tr[f].ch[x>tr[f].id]=c;
		tr[c].id=x;
		tr[c].cnt=tr[c].size=1;
		tr[c].fa=f;
	}
	splay(c,0);
}
int serial(int x)
{
	find(x);
	return tr[tr[root].ch[0]].size+1;
}
int rank(int x)
{
	int u=root,v;
	if(x>len) return 0;
	while(true)
	{
		v=tr[u].ch[0];
		if(tr[v].size>=x)
		{
			u=v;
		}
		else if(tr[u].cnt+tr[v].size<x)
		{
			x-=tr[u].cnt+tr[v].size;
			u=tr[u].ch[1];
		}
		else break;
	}
	splay(u,0);
	return tr[u].id;
}
int pre(int x)
{
	find(x);
	if(tr[root].id<x) return root; 
	int u=tr[root].ch[0];
	while(tr[u].ch[1])
	{
		u=tr[u].ch[1];
	}
	return u;
}
int suc(int x)
{
	find(x);
	if(tr[root].id>x) return root; 
	int u=tr[root].ch[1];
	while(tr[u].ch[0])
	{
		u=tr[u].ch[0];
	}
	return u;
}
void del(int x)
{
	int p=pre(x),s=suc(x);
	splay(p,0);
	splay(s,p);
	x=tr[s].ch[0];
	if(tr[x].cnt>1)
	{
		tr[x].cnt--;
		splay(x,0);
	}
	else tr[s].ch[0]=0,update(s);
}
int T,op,x;
int main()
{ 
	scanf("%d",&T);
	while(T--)
	{
		scanf("%d%d",&op,&x);
		if(op==1) insert(x);
		else if(op==2) del(x);
		else if(op==3) printf("%d\n",serial(x));
		else if(op==4) printf("%d\n",rank(x));
		else if(op==5) printf("%d\n",tr[pre(x)].id);
		else if(op==6) printf("%d\n",tr[suc(x)].id);
	}
}

欢迎大佬指点一二。

2020/5/5 12:07
加载中...