Splay 16pts求助
查看原帖
Splay 16pts求助
126871
yzh_Error404Error楼主2022/11/22 17:08

WA了两个点,剩下的T了

#include<bits/stdc++.h>
using namespace std;
const int MAXN=4e5+5;
struct node
{
	int s[2],fa,val;
	int siz,cnt;
	inline void init(int _val,int _fa)
	{
		s[0]=s[1]=0;
		fa=_fa,val=_val;
		siz=1,cnt=1;
	}
}t[MAXN];
int rt,tot;
inline void pushup(int p)
{
	t[p].siz=t[t[p].s[0]].siz+t[t[p].s[1]].siz+t[p].cnt;
}
inline void rot(int x)
{
	int y=t[x].fa,z=t[y].fa;
	int k=(t[y].s[1]==x),_k=(t[z].s[1]==y);
	t[z].s[_k]=x,t[x].fa=z;
	t[y].s[k]=t[x].s[k^1],t[t[x].s[k^1]].fa=y;
	t[x].s[k^1]=y,t[y].fa=x;
	pushup(y),pushup(x);
}
inline void splay(int x,int f)
{
//	cout<<x<<" "<<f<<endl;
	while(t[x].fa!=f)
	{
		int y=t[x].fa,z=t[y].fa;
		int k=(t[y].s[1]==x),_k=(t[z].s[1]==y);
		if(z!=f)
		{
			if(k==_k)rot(y);
			else rot(x);
		}
		rot(x);
	}
	if(!f)rt=x;
}
inline void insert(int v)
{
	int x=rt,f=0;
	while(x)
	{
		if(t[x].val==v)
		{
			t[x].cnt++;
			splay(x,0);
			return;
		}
		f=x;
		x=t[x].s[v>t[x].val];
	}
	if(!x)x=++tot,t[x].init(v,f);
	if(f)t[f].s[v>t[f].val]=x;
	splay(x,0);
}
inline void find(int v)
{
	int x=rt;
	while(x)
	{
		if(t[x].val==v)
		{
			splay(x,0);
			return;
		}
		x=t[x].s[v>t[x].val];
	}
}
inline int get_k(int k)
{
	int x=rt;
	while(x)
	{
		if(t[t[x].s[0]].siz>k)x=t[x].s[0];
		else if(t[t[x].s[0]].siz+1==k)return t[x].val;
		else k-=t[t[x].s[0]].siz+1,x=t[x].s[1];
	}
    return x;
}
inline int get_id(int v)
{
	int x=rt;
	while(x)
	{
//		cout<<x<<" "<<t[x].val<<endl;
		if(t[x].val==v)return t[t[x].s[0]].siz;
		x=t[x].s[v>t[x].val];
	}
    return x;
}
inline int pre(int v)
{
	find(v);
	int x=rt;
//	cout<<x<<" "<<t[x].val<<endl;
	x=t[x].s[0];
//	cout<<x<<" "<<t[x].val<<endl;
	while(t[x].s[1])x=t[x].s[1];
	return x;
}
inline int bac(int v)
{
	find(v);
	int x=rt;
	x=t[x].s[1];
	while(t[x].s[0])x=t[x].s[0];
	return x;
}
inline void del(int v)
{
	int pr=pre(v),ba=bac(v);
//	cout<<pr<<" "<<ba<<endl;
	splay(pr,0),splay(ba,pr);
	if(t[t[ba].s[0]].cnt>1)t[t[ba].s[0]].cnt--;
	else t[ba].s[0]=0;
	pushup(ba),pushup(pr);
}
int n;
inline void dfs(int x)
{
	if(t[x].s[0])dfs(t[x].s[0]);
	printf("%d %d %d\n",x,t[x].val,t[x].cnt);
	if(t[x].s[1])dfs(t[x].s[1]);
}
int main()
{
	scanf("%d",&n);
	insert(-2147483647);
	insert(2147483647);
	for(register int i=1;i<=n;i++)
	{
		int op,x;
		scanf("%d%d",&op,&x);
		if(op==1)insert(x);
		if(op==2)del(x);
		if(op==3)printf("%d\n",get_id(x));
		if(op==4)printf("%d\n",get_k(x+1));
		if(op==5)insert(x),printf("%d\n",t[pre(x)].val),del(x);
		if(op==6)insert(x),printf("%d\n",t[bac(x)].val),del(x);
//		dfs(rt);
	}
	return 0;
} 
/*
10
1 106465
4 1
1 317721
1 460929
1 644985
3 644985
1 89851
6 81968
1 492737
5 493598
*/
2022/11/22 17:08
加载中...