萌新刚学splay,求助
查看原帖
萌新刚学splay,求助
373822
Sola_楼主2021/9/28 17:36

样例输出5,点1,4过了,其他WA了

#include<bits/stdc++.h>
using namespace std;
const int N=1100009,INF=0x7fffffff;
int n,m,tmp,rt,id;
long long last,ans;
long long t[N][2];
int fa[N],siz[N],cnt[N],val[N];

int ask_dir(int u){
	return t[fa[u]][1]==u;
}

void connect(int u,int f,int p){
	if(u!=0) fa[u]=f;
	if(f!=0) t[f][p]=u;
}

void update(int u){
	siz[u]=siz[t[u][0]]+siz[t[u][1]]+cnt[u];
}

void rotate(int u){
	int f=fa[u],gf=fa[f];
	int dir=ask_dir(u),fdir=ask_dir(f);
	int ano_son=t[u][!dir];
	connect(ano_son,f,dir);
	connect(u,gf,fdir);
	connect(f,u,!dir);
	update(f);
	update(u);
}

void splay(int u,int end){
	for(int useless;fa[u]!=end;rotate(u))
		if(fa[fa[u]]!=end&&ask_dir(fa[u])==ask_dir(u))
			rotate(fa[u]);
	if(end==0) rt=u;
}

void insert(int x){
	int u=rt;
	if(!rt){
		rt=++id;
		val[id]=x;
		siz[id]=cnt[id]=1;
		return;
	}
	while(val[u]!=x){
		siz[u]++;
		if(x<=val[u]){
			if(t[u][0]==0){
				val[++id]=x;
				connect(id,u,0);
			}
			u=t[u][0];
		}
		else if(x>val[u]){
			if(t[u][1]==0){
				val[++id]=x;
				connect(id,u,1);
			}
			u=t[u][1];
		}
	}
	siz[u]++,cnt[u]++;
	splay(u,0);
}

void ask_sort(int x){
	int u=rt;
	if(!u) return;
	while(t[u][x>val[u]]&&x!=val[u])
		u=t[u][x>val[u]];
	splay(u,0);
}

int ask_pre_nxt(int x,int typ){
	ask_sort(x);
	int u=rt;
	if((val[u]<x&&!typ)||(val[u]>x&&typ)) return u;
	u=t[u][typ];
	while(t[u][typ^1]) u=t[u][typ^1];
	return u;
}

void cut_off(int x){
	int pre=ask_pre_nxt(x,0);
	int nxt=ask_pre_nxt(x,1);
	splay(pre,0);
	splay(nxt,pre);
	int mission=t[nxt][0];
	if(cnt[mission]>1){
		cnt[mission]--;
		splay(mission,0);
	}
	else
		t[nxt][0]=0;
}

int ask_k_val(int x){
	int u=rt;
	if(siz[u]<x){
		while(t[u][1])
			u=t[u][1];
		return u;
	}
	while(1){
		int dir=t[u][0];
		if(x>siz[dir]+cnt[u]){
			x-=siz[dir]+cnt[u];
			u=t[u][1];
		}
		else if(x<=siz[dir])
			u=t[u][0];
		else
			return val[u];
	}
}

int main(){
	ios::sync_with_stdio(false);
	insert(-INF);
	insert(+INF);
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		cin>>tmp;
		insert(tmp);
	}
	for(int i=1;i<=m;i++){
		int opt,x;
		cin>>opt>>x;
		x^=last;
		if(opt==1)
			insert(x);
		else if(opt==2)
			cut_off(x);
		else if(opt==3){
			ask_sort(x);
			last=siz[t[rt][0]];
			ans^=last;
		}
		else if(opt==4){
			last=ask_k_val(x+1);
			ans^=last;
		}
		else if(opt==5){
			last=ask_pre_nxt(x,0);
			ans^=val[last];
		}
		else if(opt==6){
			last=ask_pre_nxt(x,1);
			ans^=val[last];
		}
	}
	cout<<ans<<endl;
	return 0;
} 
2021/9/28 17:36
加载中...