刚学SBT 8pts 求助
查看原帖
刚学SBT 8pts 求助
469066
zzxLLL楼主2021/10/6 20:26

记录:https://www.luogu.com.cn/record/59327178 过了第一个点 其他MLE WA 代码:

#include<bits/stdc++.h>
using namespace std;
const int M=1e5+10;
struct node{
	int lc,rc,val,num,size;
}tr[M];
int root,cntp;
int newval(int val){
	cntp++;
	tr[cntp].lc=tr[cntp].rc=0;
	tr[cntp].val=val;
	tr[cntp].size=1;
	return cntp;
}
void update(int k){
	tr[k].size=tr[tr[k].lc].size+tr[tr[k].rc].size+1;
}
void left_rotate(int &p){
	int q=tr[p].rc;
	tr[p].rc=tr[q].lc;
	tr[q].lc=p;
	tr[q].size=tr[p].size;
	update(p);
	p=q;
}
void right_rotate(int &p){
	int q=tr[p].lc;
	tr[p].lc=tr[q].rc;
	tr[q].rc=p;
	tr[q].size=tr[p].size;
	update(p);
	p=q;
}
void maintain(int &p,bool flag){
	if(!p) return;
	if(!flag){
		if(tr[tr[tr[p].lc].lc].size>tr[tr[p].rc].size) right_rotate(p);//LL 
		else if(tr[tr[tr[p].lc].rc].size>tr[tr[p].rc].size) left_rotate(tr[p].lc),right_rotate(p);//LR
		else return;
	}else{
		if(tr[tr[tr[p].rc].rc].size>tr[tr[p].lc].size) left_rotate(p);//RR
		else if(tr[tr[tr[p].rc].lc].size>tr[tr[p].lc].size) right_rotate(tr[p].rc),left_rotate(p);//RL
		else return;
	}
	maintain(tr[p].lc,false);
	maintain(tr[p].rc,true);
	maintain(p,false);
	maintain(p,true);
}
void insert(int &p,int val){
	if(!p){
		p=newval(val);
		return;
	}
	tr[p].size++;
	if(tr[p].val>val) insert(tr[p].lc,val);
	else insert(tr[p].rc,val);
	maintain(p,tr[p].val<=val);
}
void remove(int &p,int val){
	if(!p) return;
	tr[p].size--;
	if(tr[p].val==val){
		if(!tr[p].lc||tr[p].rc) p=tr[p].lc+tr[p].rc;
		else{
			int cur=tr[p].rc;
			while(tr[cur].lc) cur=tr[cur].lc;
			tr[p].val=tr[cur].val;
			remove(tr[p].rc,tr[cur].val);
		}
	}else if(val<tr[p].val) remove(tr[p].lc,val);
	else remove(tr[p].rc,val);
}
int getpre(int &p,int q,int val){
	if(!p) return tr[q].val;
	if(tr[p].val<val) return getpre(tr[p].rc,p,val);
	return getpre(tr[p].lc,q,val);
}
int getsuf(int &p,int q,int val){
	if(!p) return tr[q].val;
	if(tr[p].val>val) return getsuf(tr[p].lc,p,val);
	return getpre(tr[p].rc,q,val);
}
int getrank(int &p,int val){
	if(val<tr[p].val) return getrank(tr[p].lc,val);
	else if(val>tr[p].val) return getrank(tr[p].rc,val);
	return tr[tr[p].lc].size+1;
}
int getval(int &p,int k){
	int ls=tr[tr[p].lc].size+1;
	if(ls==k) return tr[p].val;
	else if(ls<k) return getval(tr[p].rc,k-ls);
	return getval(tr[p].lc,k);
}
void print(int p){
	if(tr[p].lc) print(tr[p].lc);
	printf("%d ",tr[p].val);
	if(tr[p].rc) print(tr[p].rc);
}
int n;
int main(){
	scanf("%d",&n);
	for(int i=1,opt,x;i<=n;i++){
		scanf("%d%d",&opt,&x);
		switch(opt){
			case 1:insert(root,x);break;
			case 2:remove(root,x);break;
			case 3:printf("%d\n",getrank(root,x));break;
			case 4:printf("%d\n",getval(root,x));break;
			case 5:printf("%d\n",getpre(root,0,x));break;
			case 6:printf("%d\n",getsuf(root,0,x));break;
			default:break;
		}
		/*
		putchar('\n');
		print(root);
		putchar('\n');
		printf("root:%d\n",root);
		*/
	}
	return 0;
}

跪求大佬帮忙

2021/10/6 20:26
加载中...