求助 splay
查看原帖
求助 splay
134476
wyx__楼主2021/2/23 17:14

所有测试点都是 MLE 或 RE。

我在本地测试了第一个测试点，输出没有问题，也看不出哪里会导致 MLE；同样写法的代码已经通过了 P3369。

代码:

#include<bits/stdc++.h>
using namespace std;
// Splay-tree storage kept as parallel arrays indexed by node id; id 0 stands for "no node".
//   val  - key stored at the node
//   lc   - left child id,  rc - right child id,  fa - parent id
//   size - number of stored values in the node's subtree (children sizes + cnt)
//   cnt  - how many copies of the key the node holds
// rt is the current root id and tot the number of node ids handed out so far.
// n and m are the two counts read at the start of main; lastans is the previous
// query result (used to decode the next argument) and ans the XOR of all query results.
int val[1100005],lc[1100005],rc[1100005],fa[1100005],size[1100005],cnt[1100005],rt,tot,n,m,ans,lastans;
void update(int x){
	size[x]=size[lc[x]]+size[rc[x]]+cnt[x];
}
// Rotate node x one level up, above its parent, keeping the in-order sequence intact.
// The grandparent (if any) adopts x in the slot the parent used to occupy, and x's
// inner subtree moves over to the parent. Sizes of the two moved nodes are rebuilt.
void rotate(int x) {
	const int parent=fa[x];
	const int grand=fa[parent];
	const bool isLeft=(lc[parent]==x);
	// Subtree of x that lies between x and its parent in key order.
	const int inner=isLeft?rc[x]:lc[x];

	// Hook x under the grandparent where the parent used to be.
	if(grand){
		if(lc[grand]==parent)lc[grand]=x;
		else rc[grand]=x;
	}
	fa[x]=grand;

	// The parent becomes x's child on the opposite side and takes over the inner subtree.
	if(isLeft){
		rc[x]=parent;
		lc[parent]=inner;
	}else{
		lc[x]=parent;
		rc[parent]=inner;
	}
	fa[parent]=x;
	if(inner)fa[inner]=parent;

	// The parent is now below x, so refresh it first.
	update(parent);
	update(x);
}
// Tell which side of its parent node x hangs on: true for right child, false otherwise.
bool wrt(int x){
	const int parent=fa[x];
	return rc[parent]==x;
}
// Rotate x upwards until its parent is `target`; with target == 0 x becomes the root
// and rt is updated accordingly.
void splay(int x,int target){
	for(int p=fa[x];p!=target;p=fa[x]){
		if(fa[p]!=target){
			// Same side as the parent (zig-zig): lift the parent first.
			// Opposite sides (zig-zag): lift x twice.
			rotate(wrt(x)==wrt(p)?p:x);
		}
		rotate(x);
	}
	if(!target)rt=x;
}
int find(int v){
	int x=rt;
	while(x){
		if(val[x]==v)break;
		if(val[x]<v)x=rc[x];
		else x=lc[x];
	}
	if(x)splay(x,0);
	return x;
}
int insert(int v){
	int x=rt,y=0,dir;
	while(x){
		y=x;
		size[x]++;
		if(v==val[x])break;
		if(val[x]>v)x=lc[x],dir=0;
		else x=rc[x],dir=1;
	}
	if(x)cnt[x]++;
	else{
		x=++tot;
		fa[x]=y,size[x]++,cnt[x]++,val[x]=v;
		if(y)(dir==0?lc[y]:rc[y])=x;
		splay(x,0);
	}	
}
// Merge two detached subtrees rooted at x and y, where x's keys precede y's.
// The largest node of x is splayed to the top (it then has no right child) and y is
// attached as its right subtree; rt ends up pointing at that node.
void join(int x,int y){
	fa[y]=0;
	fa[x]=0;
	// Walk down the right spine to the maximum of the left tree.
	int top=x;
	for(;rc[top];top=rc[top]);
	splay(top,0);
	fa[y]=top;
	rc[top]=y;
	update(top);
}
void Delete(int x){
	cnt[x]--,size[x]--;
	if(cnt[x]==0){
		splay(x,0);
		if(!lc[x]||!rc[x])fa[rt=lc[x]+rc[x]]=0;
		else join(lc[x],rc[x]);
		lc[x]=rc[x]=0; 
	}
}
int val_to_rank(int v){
	int x=rt,ans=1;
	while(x){
		if(val[x]==v){
			ans+=size[lc[x]];
			splay(x,0);
			break;
		}
		if(v<val[x])x=lc[x];
		else {
			ans+=size[lc[x]]+cnt[x];
			x=rc[x];
		}
	}
	return ans;
}
int rank_to_val(int v){
	int x=rt;
	while(1){
		int temp1=size[x]-size[rc[x]],temp2=size[lc[x]];
		if(v>temp2&&v<=temp1)break;
		if(v<=temp2)x=lc[x];
		else {
			v-=temp1;
			x=rc[x];
		}
	}
	splay(x,0);
	return val[x];
}
int lower(int v){
	int x=rt,res=-1e9;
	while(x){
		if(val[x]<v&&val[x]>res)res=val[x];
		if(v>val[x])x=rc[x];
		else x=lc[x];
	}
	return res;
}
int upper(int v){
	int x=rt,res=1e9;
	while(x){
		if(val[x]>v&&val[x]<res)res=val[x];
		if(v<val[x])x=lc[x];
		else x=rc[x];
	}
	return res;
}
void debug(){
	for(int i=1;i<=tot;i++)
		cout<<"node:"<<i<<" val:"<<val[i]<<" fa:"<<fa[i]<<" lc:"<<lc[i]<<" rc:"<<rc[i]<<" cnt:"<<cnt[i]<<" size:"<<size[i]<<endl;
}
// Reads n initial keys and m operations, runs them against the splay tree and prints
// the XOR of the results of all operations with opt > 2.
// Each operation argument is XOR-decoded with the previous query result.
int main() {
	// For local runs stdin/stdout can be redirected here, e.g.
	// freopen("P6136_1.in","r",stdin); freopen("P6136_1.out","w",stdout);
	cin>>n>>m;

	// Initial contents of the tree.
	for(int k=0;k<n;++k){
		int key;
		scanf("%d",&key);
		insert(key);
	}

	for(int k=0;k<m;++k){
		int opt,x;
		scanf("%d%d",&opt,&x);
		x^=lastans;
		switch(opt){
			case 1: insert(x); break;
			case 2: Delete(find(x)); break;
			case 3: lastans=val_to_rank(x); break;
			case 4: lastans=rank_to_val(x); break;
			case 5: lastans=lower(x); break;
			case 6: lastans=upper(x); break;
			default: break;
		}
		// Only the query operations contribute to the final checksum.
		if(opt>2)ans^=lastans;
	}
	cout<<ans;
}
2021/2/23 17:14
加载中...