疑问(玄关)
查看原帖
疑问(玄关)
1036707
chzhh_111楼主2025/2/6 08:37

为什么我的 myfind 函数写成这样子就会 TLE,而下面那样子就可以 AC

int myfind(int x,int v)
{
	if(!lc(x)&&!rc(x)) return x;//如果它是叶子节点就返回
	down(x);
	int size=size(lc(x)),cnt=cnt(x);
	if(v<=size) return myfind(lc(x),v);
	else if(v>size+cnt) return myfind(rc(x),v-size-cnt);
	else return x;
}
int myfind(int x,int v)
{
	//if(!lc(x)&&!rc(x)) return x;
	down(x);
	int size=size(lc(x)),cnt=cnt(x);
	if(v<=size) return myfind(lc(x),v);
	else if(v>size+cnt) return myfind(rc(x),v-size-cnt);
	else return x;
}

这两个代码唯一的区别就在于有没有特判叶子节点。

顺便附上一整个代码:

#include<bits/stdc++.h>
#define int long long
#define lc(x) tr[x].ch[0]
#define rc(x) tr[x].ch[1]
#define fa(x) tr[x].fa
#define val(x) tr[x].val
#define pos(x) tr[x].pos
#define size(x) tr[x].size
#define cnt(x) tr[x].cnt
#define check(x) (rc(fa(x))==x)
#define pushup(x) size(x)=size(lc(x))+size(rc(x))+cnt(x)
using namespace std;
const int N=1e5+100;
int n,m,tot,root;
struct tree{
	int fa,ch[2],val,cnt,size;bool pos;
}tr[N];
void rotate(int x)
{
	int y=fa(x),z=fa(y);
	int k=check(x),w=tr[x].ch[k^1];
	if(z) tr[z].ch[check(y)]=x;
	fa(x)=z;
	tr[y].ch[k]=w,fa(w)=y;
	tr[x].ch[k^1]=y,fa(y)=x;
	pushup(x),pushup(y);
}
void splay(int x,int target)
{
	while(fa(x)!=target)
	{
		int y=fa(x),z=fa(y);
		if(z!=target)
		  if(check(x)==check(y)) rotate(y);
		    else rotate(x);
		rotate(x);
	}
	if(!target) root=x;
}
int insert(int x,int v,int y)
{
	if(!x)
	{
		tr[++tot]=(tree){y,{0,0},v,1,1,0};
		if(y) tr[y].ch[v>val(y)]=tot;
		return tot;
	}
	int val=val(x);
	size(x)++;
	if(v<val) return insert(lc(x),v,x);
	else if(v>val) return insert(rc(x),v,x);
	else if(v==val)
	{
		cnt(x)++;
		return x;
	}
}
void down(int x)
{
	bool pos=pos(x);
	if(!pos) return;
	swap(lc(x),rc(x));
	pos(lc(x))^=1,pos(rc(x))^=1;
	pos(x)=0;
}
int myfind(int x,int v)
{
	//if(!lc(x)&&!rc(x)) return x;
	down(x);
	int size=size(lc(x)),cnt=cnt(x);
	if(v<=size) return myfind(lc(x),v);
	else if(v>size+cnt) return myfind(rc(x),v-size-cnt);
	else return x;
}
void print(int x)
{
	if(!x) return;
	down(x);
	print(lc(x));
	if(x>1&&x<n+2) printf("%lld ",val(x)-1);
	print(rc(x));
}
signed main()
{
	scanf("%lld%lld",&n,&m);
	for(int i=1;i<=n+2;i++)
	{
		insert(root,i,0);
		splay(i,0);
	}
	for(int i=1;i<=m;i++)
	{
		int l,r;
		scanf("%lld%lld",&l,&r);
		int L=myfind(root,l),R=myfind(root,r+2);
		splay(L,0);
		splay(R,L);
		pos(lc(rc(root)))^=1;
	}
	print(root);
	return 0;
}
2025/2/6 08:37
加载中...