Splay 只拿了 60 分，有 RE 也有 WA，救命啊
查看原帖
Splay 只拿了 60 分，有 RE 也有 WA，救命啊
398312
国王的新账号楼主2021/8/23 23:18

测试点 1、7、8 WA（答案错误），测试点 3 RE（运行时错误），求助呜呜呜呜

#include<bits/stdc++.h>
using namespace std;
// Capacity of the node pool; node id 0 acts as the null child/parent.
const int maxn = 1e6 + 5;

int rt;   // root node id of the splay tree (0 when the tree is empty)
int num;  // id of the most recently allocated node (ids are never reused)
int n;    // number of events read by main
int ans;  // running answer, reduced modulo 1000000 by main
int an;   // net balance: +1 for every k==1 event, -1 for every k==0 event

int val[maxn];     // key stored in each node
int sz[maxn];      // subtree size counting multiplicities (maintained by pushup)
int son[maxn][2];  // children: [0] is the left child, [1] is the right child
int fa[maxn];      // parent node id (0 for the root)
int cnt[maxn];     // how many copies of val[x] node x represents
// Recompute the subtree size of node x from its two children plus its own
// multiplicity. Children that are 0 contribute sz[0].
void pushup(int x)
{
	const int left = son[x][0];
	const int right = son[x][1];
	sz[x] = cnt[x] + sz[left] + sz[right];
}
// Allocate a fresh leaf node holding key x (multiplicity 1) whose parent is f.
// The new id is left in the global num; the caller must hook it into
// son[f][...] (or make it the root).
void edit(int x, int f)
{
	const int id = ++num;
	son[id][0] = 0;
	son[id][1] = 0;
	val[id] = x;
	cnt[id] = 1;
	sz[id] = 1;
	fa[id] = f;
}
void clear(int x)
{
	son[x][0]=son[x][1]=val[x]=sz[x]=cnt[x]=val[x]=0;
}
// Return true when x is the right child of its parent, false otherwise.
bool get(int x)
{
	const int parent = fa[x];
	return son[parent][1] == x;
}
// Single rotation that lifts x above its parent, keeping in-order order.
// x's inner subtree (the child on the side opposite to x's own side) is
// handed to the old parent, and the old parent becomes x's child on that side.
// Sizes are refreshed bottom-up: the old parent first, then x.
void rotate(int x)
{
	const int parent = fa[x];
	const int grand = fa[parent];
	const bool dir = get(x);
	const int moved = son[x][!dir];

	son[parent][dir] = moved;
	fa[moved] = parent;  // also runs when moved == 0, as before
	son[x][!dir] = parent;
	fa[parent] = x;
	fa[x] = grand;
	if (grand != 0) {
		son[grand][son[grand][1] == parent] = x;
	}

	pushup(parent);
	pushup(x);
}
// Rotate x upward until its parent is `tag`.
// With two levels to go, the step depends on whether x and its parent are
// on the same side:
//   - same side: rotate the parent first, then x;
//   - different sides: rotate x twice.
// When tag is 0, x ends up as the new root.
void splay(int x, int tag)
{
	while (fa[x] != tag) {
		const int y = fa[x];
		if (fa[y] != tag) {
			const bool zigzag = get(x) != get(y);
			rotate(zigzag ? x : y);
		}
		rotate(x);
	}
	if (!tag) rt = x;
}
// Insert one copy of key x.
// An existing node with key x just has its multiplicity bumped; otherwise a
// new leaf is created. The touched node is splayed to the root, except when
// x already sits at the root.
void insert(int x)
{
	if (!rt) {
		// Empty tree: the new node becomes the root.
		edit(x, 0);
		rt = num;
		return;
	}
	if (val[rt] == x) {
		++cnt[rt];
		pushup(rt);
		return;
	}
	int cur = rt;
	for (;;) {
		const int parent = cur;
		const int dir = val[parent] < x;
		cur = son[parent][dir];
		if (!cur) {
			// Fell off the tree: attach a new leaf here.
			edit(x, parent);
			son[parent][dir] = num;
			pushup(parent);
			splay(num, 0);
			return;
		}
		if (val[cur] == x) {
			++cnt[cur];
			pushup(cur);
			pushup(parent);
			splay(cur, 0);
			return;
		}
	}
}
// Key of the root's in-order predecessor: the largest key in the root's left
// subtree. Returns -1 when that subtree is empty.
// NOTE(review): -1 doubles as the "none" marker, so this assumes keys are never -1.
int lower()
{
	int node = son[rt][0];
	if (node == 0) return -1;
	for (; son[node][1] != 0; node = son[node][1]) {
	}
	return val[node];
}
// Key of the root's in-order successor: the smallest key in the root's right
// subtree. Returns -1 when that subtree is empty.
// NOTE(review): -1 doubles as the "none" marker, so this assumes keys are never -1.
int upper()
{
	int node = son[rt][1];
	if (node == 0) return -1;
	for (; son[node][0] != 0; node = son[node][0]) {
	}
	return val[node];
}
// Descend from the root looking for key x.
// Returns the id of the node holding x, or 0 if x is not in the tree.
// No splaying happens here.
int find(int x)
{
	int node = rt;
	while (node != 0 && val[node] != x) {
		node = son[node][x > val[node]];
	}
	return node;
}
// Remove one copy of key x from the tree (no-op if x is absent).
// The node holding x is splayed to the root first. If it stores several
// copies, its count is decremented. Otherwise the root is unlinked:
//   - a leaf root empties the tree;
//   - a root with one child is replaced by that child;
//   - a root with two children is handled by splaying the in-order
//     predecessor node to the root. The old root then becomes its right
//     child with an empty left subtree, so it can be spliced out directly.
void del(int x)
{
	const int target = find(x);
	if (target == 0) return;  // key not present; splay(0, 0) would corrupt state
	splay(target, 0);

	if (cnt[rt] > 1) {
		--cnt[rt];
		pushup(rt);
		return;
	}

	const int old = rt;
	const int left = son[old][0];
	const int right = son[old][1];

	if (!left && !right) {
		clear(old);
		rt = 0;
		return;
	}
	if (!left || !right) {
		rt = left ? left : right;
		fa[rt] = 0;
		clear(old);
		return;
	}

	// Locate the predecessor NODE. The original passed lower()'s key value to
	// splay(), which treats its argument as a node id: out-of-range for large
	// keys (the RE) or the wrong node for small ones (the WA).
	int pred = left;
	while (son[pred][1]) pred = son[pred][1];
	splay(pred, 0);

	// Now old == son[rt][1] and son[old][0] == 0: splice old's right subtree in.
	son[rt][1] = son[old][1];
	fa[son[old][1]] = rt;
	pushup(rt);
	clear(old);
}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		int k,x;
		scanf("%d%d",&k,&x);
		if(an==0)	insert(x);
		else if((k==1&&an>0)||(k==0&&an<0))	insert(x);
		else{
			insert(x);
			splay(find(x),0);
			if(cnt[rt]>1){
				del(x);
				del(x);
			}else{
				int lw=lower(),up=upper();
				del(x);
				int mlw=x-lw,mup=up-x;
				if(lw==-1){
					ans+=mup;
					ans%=1000000;
					del(up);
				}else if(up==-1){
					ans+=mlw;
					ans%=1000000;
					del(lw);
				}else if(mlw<=mup){
					ans+=mlw;
					ans%=1000000;
					del(lw);
				}else if(mlw>mup){
					ans+=mup;
					ans%=1000000;
					del(up);
				}
			} 
		}
		if(k==0)	an--;
		else if(k==1)	an++;
	}
	printf("%d",ans);
	return 0;
}
2021/8/23 23:18
加载中...