TLE #13(玄关)
查看原帖
TLE #13(玄关)
1036707
chzhh_111楼主2025/2/6 14:16
#include<bits/stdc++.h>
#define int long long
#define lc(x) tr[x].ch[0]
#define rc(x) tr[x].ch[1]
#define fa(x) tr[x].fa
#define val(x) tr[x].val
#define size(x) tr[x].size
#define cnt(x) tr[x].cnt
#define check(x) (rc(fa(x))==x)
#define pushup(x) size(x)=size(lc(x))+size(rc(x))+cnt(x)
using namespace std;
// Capacity of the node pool; index 0 is used as the null sentinel.
const int N=100000+100;
// n: number of operations read by main(); tot: nodes allocated so far;
// root: index of the current root (0 when the tree is empty).
int n,tot,root;

// One splay-tree node. Field order matters: insert() builds nodes with
// aggregate initialisation in the order {fa, ch, val, cnt, size}.
struct tree{
	int fa;		// parent index (0 = none)
	int ch[2];	// ch[0] = left child, ch[1] = right child (0 = none)
	int val;	// key stored in this node
	int cnt;	// number of copies of val
	int size;	// total number of copies stored in this subtree
}tr[N];
// Rotate x one level up over its parent y, preserving the in-order sequence.
// After the rotation y is a child of x, so sizes must be recomputed bottom-up:
// y first, then x (the previous order used y's stale size when updating x).
void rotate(int x)
{
	int y=fa(x),z=fa(y);
	int k=check(x);			// which side of y x hangs on
	int w=tr[x].ch[k^1];		// x's inner subtree, which moves over to y
	if(z) tr[z].ch[check(y)]=x;	// x takes y's place under z
	fa(x)=z;
	tr[y].ch[k]=w;
	if(w) fa(w)=y;			// never write into the null sentinel tr[0]
	tr[x].ch[k^1]=y,fa(y)=x;
	pushup(y);			// y is now below x: update it first
	pushup(x);
}
// Splay x upward until its parent is `target`. When x and its parent lie on
// the same side (zig-zig) the parent is rotated first; otherwise (zig-zag)
// x is rotated twice. Passing target 0 makes x the new root.
void splay(int x,int target)
{
	for(int y;(y=fa(x))!=target;rotate(x))
	{
		int z=fa(y);
		if(z!=target) rotate(check(x)==check(y)?y:x);
	}
	if(target==0) root=x;
}
int find1(int x,int v)
{
	int val=val(x);
	if(v<val) return find1(lc(x),v);
	if(v==val) return x;
	if(v>val) return find1(rc(x),v);
}
int find2(int x,int v,int s)
{
	if(!x) return s+1;
	int val=val(x);
	if(v<val) return find2(lc(x),v,s);
	if(v==val) return s+size(lc(x))+1;
	if(v>val) return find2(rc(x),v,s+size(lc(x))+cnt(x));
}
int find3(int x,int v)
{
	int size=size(lc(x)),cnt=cnt(x);
	if(v<=size) return find3(lc(x),v);
	if(v>size&&v<=size+cnt) return x;
	if(v>size+cnt) return find3(rc(x),v-size-cnt);
}
int find4(int x,int v,int node)
{
	if(!x) return node;
	int val=val(x);
	if(v<=val) return find4(lc(x),v,node);
	else return find4(rc(x),v,x);
}
int find5(int x,int v,int node)
{
	if(!x) return node;
	int val=val(x);
	if(v<val) return find5(lc(x),v,x);
	return find5(rc(x),v,node);
}
// Insert one copy of key v below x, whose parent is y (callers pass root, 0).
// Every node on the search path has its size bumped. An existing key only
// gets its counter incremented; otherwise a fresh node is allocated and linked
// under the last node visited. Returns the index of the node holding v
// (main() splays it afterwards).
int insert(int x,int v,int y)
{
	while(x)
	{
		size(x)++;
		if(v==val(x)) { cnt(x)++; return x; }
		y=x;
		x=tr[x].ch[v>val(x)];
	}
	tr[++tot]=tree{y,{0,0},v,1,1};
	if(y) tr[y].ch[v>val(y)]=tot;
	return tot;
}
void join(int x,int y)
{
	fa(x)=fa(y)=0;
	while(rc(x)) x=rc(x);
	splay(x,0);
	rc(x)=y,fa(y)=x;
}
// Remove one copy of key x. The node holding it is splayed to the root. If it
// stores several copies, only its counters are decremented. Otherwise the node
// is unlinked: a missing child lets the other one become the root directly,
// else the two subtrees are merged with join().
// If find1 reports the key as absent (returns 0), nothing is removed.
void del(int x)
{
	x=find1(root,x);
	if(!x) return;		// never splay the null sentinel
	splay(x,0);
	if(cnt(x)>1)
	{
		cnt(x)--,size(x)--;
		return;
	}
	if(!lc(x)||!rc(x)) fa(root=lc(x)+rc(x))=0;	// at most one child: promote it
	else join(lc(x),rc(x));
	lc(x)=rc(x)=0;
}
// Reads n operations "opt x" and dispatches them:
// 1 insert x, 2 delete x, 3 rank of x, 4 x-th smallest key,
// 5 predecessor of x, 6 successor of x. Queries print one value per line.
signed main()
{
	scanf("%lld",&n);
	for(int done=0;done<n;done++)
	{
		int opt,x;
		scanf("%lld%lld",&opt,&x);
		switch(opt)
		{
			case 1: splay(insert(root,x,0),0); break;
			case 2: del(x); break;
			case 3: printf("%lld\n",find2(root,x,0)); break;
			case 4: printf("%lld\n",val(find3(root,x))); break;
			case 5: printf("%lld\n",val(find4(root,x,0))); break;
			case 6: printf("%lld\n",val(find5(root,x,0))); break;
		}
	}
	return 0;
}
2025/2/6 14:16
加载中...