splay平衡树求助
  • 板块学术版
  • 楼主一Iris一
  • 当前回复1
  • 已保存回复1
  • 发布时间2021/1/14 21:48
  • 上次更新2023/11/5 04:49:59
查看原帖
splay平衡树求助
307042
一Iris一楼主2021/1/14 21:48
#include<iostream>
#include<cstdio>
#include<cstring>

using namespace std;

// Every `int` below is really `long long`; main is therefore declared `signed main`.
#define int long long
// Sentinel magnitude. Parenthesized so that -INF expands to -(1<<30) rather than
// (-1)<<30 (left-shifting a negative value is undefined behavior before C++20).
#define INF (1<<30)

// Reads the next integer from stdin into x.
// Non-digit characters are skipped; the result is negated only when the
// character immediately before the first digit is '-'.
// If EOF is reached before any digit, x is left as 0. The previous version
// stored getchar() in a `char` and never tested for EOF, so it looped forever
// at end of input.
template<typename T>
inline void read(T &x)
{
	x = 0;
	int f = 1;
	int c = getchar();   // int, not char, so EOF stays distinguishable
	while (c != EOF && (c < '0' || c > '9'))
	{
		f = (c == '-') ? -1 : 1;
		c = getchar();
	}
	while (c >= '0' && c <= '9')
	{
		x = x * 10 + (c - '0');
		c = getchar();
	}
	x *= f;
}

// One splay-tree node. Equal keys share a single node; cnt is its multiplicity.
struct node{
	int key = 0;      // stored value
	int sum = 0;      // total multiplicity in this subtree (recomputed by hb)
	int cnt = 0;      // number of copies of key held by this node
	node *ls = NULL;  // left child (smaller keys)
	node *rs = NULL;  // right child (larger keys)
	node *fa = NULL;  // parent; NULL at the root
};

// Static node pool. New() hands out the next unused slot; slots are never
// recycled. There is no bounds check, so this assumes fewer than 233333
// allocations in total (including the two sentinels created in main).
node mem[233333];
node *pool = mem;  // next free slot in mem
node *aux;         // scratch: node returned by the most recent insert in main
node *rot;         // current root of the splay tree
node *New(){node *fresh = pool; ++pool; return fresh;}

// Recomputes a->sum as a's own count plus the sums of its children.
// The children's sums must already be correct.
void hb(node *a)
{
	a->sum = a->cnt
	       + (a->ls ? a->ls->sum : 0)
	       + (a->rs ? a->rs->sum : 0);
}

// Right rotation. x must be the left child of y = x->fa.
// Afterwards y is x's right child and x takes y's old place under y's parent.
void zig(node *x)
{
	node *y = x->fa;
	node *g = y->fa;
	node *mid = x->rs;

	// x's right subtree becomes y's left subtree.
	y->ls = mid;
	if (mid) mid->fa = y;

	// y moves below x.
	x->rs = y;
	y->fa = x;

	// x replaces y as g's child (if y was not the root).
	x->fa = g;
	if (g)
	{
		if (g->ls == y) g->ls = x;
		else g->rs = x;
	}

	// y is now below x, so refresh y first.
	hb(y);
	hb(x);
}

// Left rotation. x must be the right child of y = x->fa.
// Afterwards y is x's left child and x takes y's old place under y's parent.
void zag(node *x)
{
	node *y = x->fa;
	node *g = y->fa;
	node *mid = x->ls;

	// x's left subtree becomes y's right subtree.
	y->rs = mid;
	if (mid) mid->fa = y;

	// y moves below x.
	x->ls = y;
	y->fa = x;

	// x replaces y as g's child (if y was not the root).
	x->fa = g;
	if (g)
	{
		if (g->ls == y) g->ls = x;
		else g->rs = x;
	}

	// y is now below x, so refresh y first.
	hb(y);
	hb(x);
}

// Rotates x upward until it occupies S's former position, i.e. until x's
// parent is S's original parent. S must be x itself or an ancestor of x.
// The caller must update rot when S was the root.
inline void splay(node *x,node *S)
{
	if (x == S) return;
	node *const stop = S->fa;   // x is finished once its parent is this node
	for (;;)
	{
		node *p = x->fa;
		node *g = p->fa;
		const bool leftChild = (p->ls == x);

		if (g == stop)
		{
			// Only one level left: a single rotation finishes the job.
			if (leftChild) zig(x);
			else zag(x);
			return;
		}

		if (leftChild)
		{
			if (g->ls == p) { zig(p); zig(x); }  // same-side: rotate parent first
			else            { zig(x); zag(x); }  // opposite sides: rotate x twice
		}
		else
		{
			if (g->rs == p) { zag(p); zag(x); }  // same-side: rotate parent first
			else            { zag(x); zig(x); }  // opposite sides: rotate x twice
		}

		if (x->fa == stop) return;
	}
}

// Allocates a fresh leaf holding one copy of val and returns it.
// The previous version never executed `return s;`. Falling off the end of a
// value-returning function is undefined behavior, and optimized builds may
// hand back garbage or trap.
node *make(int val)
{
	node *s = New();
	s->key = val;
	s->sum = 1;
	s->cnt = 1;
	s->ls = s->rs = s->fa = NULL;  // caller links fa when attaching the node
	return s;
}

// Inserts one copy of x into the subtree rooted at op.
// Returns the node that now holds x: either an existing node whose cnt was
// bumped, or a freshly attached leaf. Subtree sums on the path back up to op
// are refreshed.
// Fixes:
//  - When op->key == x, the old version returned an uninitialized pointer.
//    main then splayed that garbage pointer and crashed on duplicate inserts.
//  - The local variable no longer shadows the function splay.
//  - The recursion is replaced by a loop, so degenerate trees cannot blow the stack.
inline node *insert(int x,node *op)
{
	node *cur = op;
	for (;;)
	{
		if (cur->key == x)
		{
			cur->cnt++;
			break;
		}
		node *&next = (cur->key < x) ? cur->rs : cur->ls;
		if (next == NULL)
		{
			next = make(x);
			next->fa = cur;
			cur = next;
			break;
		}
		cur = next;
	}
	// Re-aggregate sums from the touched node up to (and including) op.
	for (node *up = cur; up != op->fa; up = up->fa)
		hb(up);
	return cur;
}

// Returns the node with key x in the subtree rooted at op, or NULL if x is
// absent. The old recursive version dereferenced NULL on a miss and had a
// control path with no return statement.
inline node *Find(int x,node *op)
{
	while (op != NULL && op->key != x)
		op = (op->key < x) ? op->rs : op->ls;
	return op;
}

inline int qRank(int x)
{
	node *op = Find(x,rot);
	splay(op,rot);
	rot = op;
	if(op->ls) return op->ls->sum+1;
	else return 1; 
}

// Returns the key of the x-th smallest element (1-based, counting
// multiplicities) in the subtree rooted at sp. x must be in range; otherwise
// a NULL child is dereferenced.
inline int rankq(int x,node *sp)
{
	for (;;)
	{
		int leftSize = sp->ls ? sp->ls->sum : 0;
		if (x <= leftSize)
		{
			sp = sp->ls;                 // target lies in the left subtree
		}
		else if (x <= leftSize + sp->cnt)
		{
			return sp->key;              // target is one of sp's copies
		}
		else
		{
			x -= leftSize + sp->cnt;     // skip left subtree and sp itself
			sp = sp->rs;
		}
	}
}

// Rightmost (largest-key) node in the subtree rooted at s.
inline node *mAx(node *s){while (s->rs) s = s->rs; return s;}
// Leftmost (smallest-key) node in the subtree rooted at s.
inline node *mIn(node *s){while (s->ls) s = s->ls; return s;}

// Removes one copy of x from the tree.
// If x is absent, the tree is left unchanged (previously this crashed).
// Otherwise x is splayed to the root. If it has several copies, cnt is
// decremented. If not, the root is replaced by the maximum of its left
// subtree, with the old right subtree attached beneath it.
inline void Delete(int x)
{
	node *op = Find(x,rot);
	if (op == NULL) return;
	splay(op,rot);
	rot = op;
	if (rot->cnt > 1) {rot->cnt--;hb(rot);return;}
	if (!rot->ls)
	{
		rot = rot->rs;
		if (rot) rot->fa = NULL;   // tree may become empty
		return;
	}
	// Bring the predecessor to the top of the left subtree; it has no right child.
	node *s = mAx(rot->ls);
	splay(s,rot->ls);
	s->rs = rot->rs;
	if (s->rs) s->rs->fa = s;
	s->fa = NULL;
	hb(s);
	rot = s;
}

inline int pre(int x)
{
	node op = *rot;
	int ans = -INF;
	while(23333)
	{
		if(op.key < x) 
		{
			ans = max(op.key , ans);
			if(op.rs!=NULL) op = *op.rs;
			else break;
		}
		else 
		{
			if(op.ls!=NULL)op = *op.ls;
			else break;
		}
	}
	return ans;	
}

inline int nxt(int x)
{
	node op = *rot;
	int ans = INF;
	while(2333)
	{
		if(op.key > x) 
		{
			ans = min(op.key , ans);
			if(op.ls!=NULL) op = *op.ls;
			else break;
		}
		else 
		{
			if(op.rs!=NULL)op = *op.rs;
			else break;
		}
	}
	return ans;	
}

// Debug helper: prints the keys of the subtree rooted at s in in-order
// (ascending) sequence, separated by spaces. Each node is printed once,
// regardless of cnt.
inline void dfs(node *s)
{
	if (node *left = s->ls) dfs(left);
	std::cout << s->key << " ";
	if (node *right = s->rs) dfs(right);
}

signed main()
{
//	freopen("a.in","r",stdin);

	// Two sentinels with cnt = sum = 0: INF at the root, -INF as its left child.
	// They never contribute to ranks or sizes. The -INF sentinel guarantees that
	// any real value splayed to the root has a non-empty left subtree.
	rot = make(INF);
	rot->sum = rot->cnt = 0;
	insert(-INF, rot);
	rot->ls->sum = rot->ls->cnt = 0;
	hb(rot);

	int n;
	read(n);

	for (int i = 1; i <= n; ++i)
	{
		int op, val;
		read(op);
		read(val);
		switch (op)
		{
		case 1:   // insert, then splay the touched node to the root
			aux = insert(val, rot);
			splay(aux, rot);
			rot = aux;
			break;
		case 2: Delete(val); break;                        // remove one copy
		case 3: printf("%lld\n", qRank(val)); break;        // rank of val
		case 4: printf("%lld\n", rankq(val, rot)); break;   // val-th smallest
		case 5: printf("%lld\n", pre(val)); break;          // predecessor
		case 6: printf("%lld\n", nxt(val)); break;          // successor
		}
	}
}

想知道第 6、7、8、9、10 个测试点为什么 RE，求助，谢谢。

2021/1/14 21:48
加载中...