为什么我的 myfind
函数写成这样子就会 TLE
,而下面那样子就可以 AC
。
int myfind(int x,int v)
{
if(!lc(x)&&!rc(x)) return x;//如果它是叶子节点就返回
down(x);
int size=size(lc(x)),cnt=cnt(x);
if(v<=size) return myfind(lc(x),v);
else if(v>size+cnt) return myfind(rc(x),v-size-cnt);
else return x;
}
int myfind(int x,int v)
{
//if(!lc(x)&&!rc(x)) return x;
down(x);
int size=size(lc(x)),cnt=cnt(x);
if(v<=size) return myfind(lc(x),v);
else if(v>size+cnt) return myfind(rc(x),v-size-cnt);
else return x;
}
这两个代码唯一的区别就在于有没有特判叶子节点。
顺便附上一整个代码:
#include<bits/stdc++.h>
#define int long long
#define lc(x) tr[x].ch[0]
#define rc(x) tr[x].ch[1]
#define fa(x) tr[x].fa
#define val(x) tr[x].val
#define pos(x) tr[x].pos
#define size(x) tr[x].size
#define cnt(x) tr[x].cnt
#define check(x) (rc(fa(x))==x)
#define pushup(x) size(x)=size(lc(x))+size(rc(x))+cnt(x)
using namespace std;
const int N=1e5+100;
int n,m,tot,root;
struct tree{
int fa,ch[2],val,cnt,size;bool pos;
}tr[N];
void rotate(int x)
{
int y=fa(x),z=fa(y);
int k=check(x),w=tr[x].ch[k^1];
if(z) tr[z].ch[check(y)]=x;
fa(x)=z;
tr[y].ch[k]=w,fa(w)=y;
tr[x].ch[k^1]=y,fa(y)=x;
pushup(x),pushup(y);
}
void splay(int x,int target)
{
while(fa(x)!=target)
{
int y=fa(x),z=fa(y);
if(z!=target)
if(check(x)==check(y)) rotate(y);
else rotate(x);
rotate(x);
}
if(!target) root=x;
}
int insert(int x,int v,int y)
{
if(!x)
{
tr[++tot]=(tree){y,{0,0},v,1,1,0};
if(y) tr[y].ch[v>val(y)]=tot;
return tot;
}
int val=val(x);
size(x)++;
if(v<val) return insert(lc(x),v,x);
else if(v>val) return insert(rc(x),v,x);
else if(v==val)
{
cnt(x)++;
return x;
}
}
void down(int x)
{
bool pos=pos(x);
if(!pos) return;
swap(lc(x),rc(x));
pos(lc(x))^=1,pos(rc(x))^=1;
pos(x)=0;
}
int myfind(int x,int v)
{
//if(!lc(x)&&!rc(x)) return x;
down(x);
int size=size(lc(x)),cnt=cnt(x);
if(v<=size) return myfind(lc(x),v);
else if(v>size+cnt) return myfind(rc(x),v-size-cnt);
else return x;
}
void print(int x)
{
if(!x) return;
down(x);
print(lc(x));
if(x>1&&x<n+2) printf("%lld ",val(x)-1);
print(rc(x));
}
signed main()
{
scanf("%lld%lld",&n,&m);
for(int i=1;i<=n+2;i++)
{
insert(root,i,0);
splay(i,0);
}
for(int i=1;i<=m;i++)
{
int l,r;
scanf("%lld%lld",&l,&r);
int L=myfind(root,l),R=myfind(root,r+2);
splay(L,0);
splay(R,L);
pos(lc(rc(root)))^=1;
}
print(root);
return 0;
}