萌新刚学splay 88pts T一个点
查看原帖
萌新刚学splay 88pts T一个点
125901
FxorG楼主2021/4/3 20:48

RT

#include <cstdio>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cmath>

#define N (int)(1e5+5)
#define inf (int)(1e9+7)
#define ls(xx) t[xx].ch[0]
#define rs(xx) t[xx].ch[1]
#define fa(xx) t[xx].fa
#define root t[0].ch[1]
#define ED puts("")

using namespace std;

int rd() {
	int f=1,sum=0; char ch=getchar();
	while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
	while(isdigit(ch)) {sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
	return sum*f;
}

void pr(int x) {
	if(x<0) {putchar('-');x=-x;}
	if(x>9) pr(x/10);
	putchar(x%10+'0');
}

struct node {
	int fa,ch[2],val,cnt,sum;
}t[N];

int tot=0;

void update(int x) {
	t[x].sum=t[ls(x)].sum+t[rs(x)].sum+t[x].cnt;
}

int ident(int x) {
	return t[fa(x)].ch[0]==x?0:1;	
}

void connect(int x,int fa,int how) {
	t[fa].ch[how]=x; t[x].fa=fa;
}

void rotate(int x) {
	int Y=fa(x),R=fa(Y);
	int Yson=ident(x),Rson=ident(Y);
	connect(t[x].ch[Yson^1],Y,Yson);
	connect(Y,x,Yson^1);
	connect(x,R,Rson);
	update(Y); update(x);
}

void splay(int x,int to) {
	to=fa(to);
	while(fa(x)!=to) {
		int y=fa(x);
		if(t[y].fa==to) rotate(x);
		else if(ident(x)==ident(y)) rotate(y),rotate(x);
		else rotate(x),rotate(x);
	}
}

int newnode(int v,int faa) {
	t[++tot].fa=faa;
	t[tot].cnt=t[tot].sum=1;
	t[tot].val=v;
	return tot;
}

void Insert(int x) {
	int now=root;
	if(root==0) root=newnode(x,0);
	else {
		while(1) {
			t[now].sum++;
			if(t[now].val==x) {
				t[now].cnt++; splay(now,root);
				return;
			}
			int nex=x<t[now].val?0:1;
			if(!t[now].ch[nex]) {
				int p=newnode(x,now);
				t[now].ch[nex]=p;
				splay(p,root);
				return;
			}
			now=t[now].ch[nex];
		}
	}
}

int fd(int x) {
	int now=root;
	while(1) {
		if(!now) return 0;
		if(t[now].val==x) {
			splay(now,root);
			return now;
		}
		int nex=x<t[now].val?0:1;
		now=t[now].ch[nex];
	}
}

void del(int x) {
	int pos=fd(x);
	if(!pos) return;
	if(t[pos].cnt>1) {
		t[pos].cnt--; t[pos].sum--;
		return;
	}
	else {
		if(!t[pos].ch[0]&&!t[pos].ch[1]) {
			root=0; return;
		} else if(!t[pos].ch[0]) {
			root=t[pos].ch[1]; t[root].fa=0;
			return;
		} else {
			int lfs=t[pos].ch[0];
			while(t[lfs].ch[1]) lfs=t[lfs].ch[1];
			splay(lfs,t[pos].ch[0]);
			connect(t[pos].ch[1],lfs,1);
			connect(lfs,0,1);
			update(lfs);
		}
	}
}

int rk(int x) {
	int now=root,ret=0;
	while(1) {
		if(t[now].val==x) return ret+t[t[now].ch[0]].sum+1;
		int nex=x<t[now].val?0:1;
		if(nex==1) ret=ret+t[t[now].ch[0]].sum+t[now].cnt;
		now=t[now].ch[nex];
	}
}

int kth(int x) {
	int now=root;
	while(1) {
		int used=t[now].sum-t[t[now].ch[1]].sum;
		if(t[t[now].ch[0]].sum<x&&x<=used) {
			splay(now,root);
			return t[now].val;
		}
		if(x<used) now=t[now].ch[0];
		else now=t[now].ch[1],x-=used;
	}
}

int fpre(int x) {
	int now=root,ret=-inf;
	while(now) {
		if(t[now].val<x) ret=max(ret,t[now].val);
		int nex=x<=t[now].val?0:1;
		now=t[now].ch[nex];
	}
	return ret;
}

int fnex(int x) {
	int now=root,ret=inf;
	while(now) {
		if(t[now].val>x) ret=min(ret,t[now].val);
		int nex=x<t[now].val?0:1;
		now=t[now].ch[nex];
	}
	return ret;
}

signed main() {
	int q=rd();
	int opt,x;
	while(q--) {
		opt=rd(); x=rd();
		if(opt==1) Insert(x);
		else if(opt==2) del(x);
		else if(opt==3) pr(rk(x)),ED;
		else if(opt==4) pr(kth(x)),ED;
		else if(opt==5) pr(fpre(x)),ED;
		else if(opt==6) pr(fnex(x)),ED;
	}
	return 0;
}
2021/4/3 20:48
加载中...