Splay求助
查看原帖
Splay求助
365110
xuanyuan_Niubi楼主2021/4/17 10:48

我的想法很清奇,把0全部换成-1之后,最大子段和就是最多连续1的个数,然后原来区间和的话就是区间长度减去现在区间和再除以二后加上现在区间和。 我感觉应该是push_down的问题,每次4操作就会有问题。有人知道这种应该怎么改吗?我是直接把P2042搬过来,以为改一下就过了

#include<cstdio>
#include<iostream>
#include<cstring>
#define swap(a,b) (a^=b^=a^=b)
using namespace std;
const int M=5e5+5;
const int INF=0x3f3f3f3f;
inline int read(){
	char c=getchar();int x=0,f=1;
	for(;c<'0'||c>'9';c=getchar())if(c=='-')f=-1;
	for(;c<='9'&&c>='0';c=getchar())x=(x<<1)+(x<<3)+(c^48);
	return x*f;
}
inline int max(int a,int b){return a>b?a:b;}
int root,tot1,tot2,n,m,a[M],s[M];
//===============================================================
struct Splay_tree{
	int ch[2],val,tag,size,maxn,ls,rs,sum,laz,fa;//ls,rs,maxn都是与维护连续子段最大值有关的 
}t[M];
inline void push_up(int x){
	t[x].size=t[t[x].ch[0]].size+t[t[x].ch[1]].size+1;
	t[x].sum=t[t[x].ch[0]].sum+t[t[x].ch[1]].sum+t[x].val;
	t[x].ls=max(t[t[x].ch[0]].ls,t[t[x].ch[1]].ls+t[t[x].ch[0]].sum+t[x].val);
	t[x].rs=max(t[t[x].ch[1]].rs,t[t[x].ch[0]].rs+t[t[x].ch[1]].sum+t[x].val);
	t[x].maxn=max(max(t[t[x].ch[1]].maxn,t[t[x].ch[0]].maxn),t[t[x].ch[0]].rs+t[t[x].ch[1]].ls+t[x].val);
}
inline void push_down(int x){
	int l=t[x].ch[0],r=t[x].ch[1];
	if(t[x].laz){//维护最大子段 
		t[x].laz=0;
		if(l)t[l].val=t[x].val,t[l].sum=t[l].val*t[l].size,t[l].laz=1;
		if(r)t[r].val=t[x].val,t[r].sum=t[r].val*t[r].size,t[r].laz=1; 
		if(t[x].val>=0){
			if(l)t[l].ls=t[l].rs=t[l].maxn=t[l].sum;
			if(r)t[r].ls=t[r].rs=t[r].maxn=t[r].sum;
		}
		else{
			if(l)t[l].ls=t[l].rs=0,t[l].maxn=t[x].sum;
			if(r)t[r].ls=t[r].rs=0,t[r].maxn=t[x].sum;
		}
	}
	if(t[x].tag){
		t[l].tag^=1,t[r].tag^=1,t[x].tag=0;
		t[x].val=-t[x].val,t[x].sum=-t[x].sum;//反转后sum和val都应该变成相反数 
	}
}
inline void rotate(int x){//基本的旋转和换根 
	int y=t[x].fa,z=t[y].fa,k=(t[y].ch[1]==x);
	t[z].ch[t[z].ch[1]==y]=x,t[x].fa=z;
	t[y].ch[k]=t[x].ch[k^1],t[t[x].ch[k^1]].fa=y;
	t[x].ch[k^1]=y,t[y].fa=x;
	push_up(y),push_up(x);
}
inline void Splay(int x,int goal){
	while(t[x].fa!=goal){
		int y=t[x].fa,z=t[y].fa;
		if(z!=goal){
			(t[y].ch[1]==x)^(t[z].ch[1]==y)?rotate(x):rotate(y);
		}
		rotate(x);
	}
	if(!goal)root=x;
}
inline int find_kth(int x){
	int u=root;
	while(1){
		push_down(u);int y=t[u].ch[0];
		if(x>t[y].size+1){
			x-=t[y].size+1;u=t[u].ch[1];
		}
		else if(x<=t[y].size)u=y;
		else return u;
	}
}
//============================================================
inline int build(int l,int r,int ff){
	if(l>r)return 0;
	int mid=l+r>>1,u=tot2?s[tot2--]:++tot1;
	t[u].fa=ff,t[u].val=a[mid],t[u].tag=t[u].laz=0;
	t[u].ch[0]=build(l,mid-1,u),t[u].ch[1]=build(mid+1,r,u);
	push_up(u);return u;
}
inline void insert(int x,int v){
	int l=find_kth(x+1),r=find_kth(x+2);//找到在树上的位置 
	Splay(l,0),Splay(r,l);
	for(int i=1;i<=v;i++)a[i]=read(); 
	t[r].ch[0]=build(1,v,r),n+=v;//将整颗新加的树插在右儿子的左儿子上,总节点数更新 
	push_up(r),push_up(l);
}
inline void eraser(int x){
	if(!x)return ;
	s[++tot2]=x;//把删除了的放在一个桶里面 
	eraser(t[x].ch[0]),eraser(t[x].ch[1]);
}
inline void reserve(int l,int r){
	l=find_kth(l),r=find_kth(r+2);//找到区间端点在树上的位置 
	Splay(l,0),Splay(r,l);//换根,方便反转 
	int k=t[r].ch[0];
	if(!t[k].laz){//如果还没有翻过 
		t[k].tag^=1;//那么就翻 
		push_up(r),push_up(l);
	}
}
inline void Delete(int l,int r){
	n-=r-l+1;//总数更新,减去减去了的 
	l=find_kth(l),r=find_kth(r+2);//找到位置 
	Splay(l,0),Splay(r,l);
	eraser(t[r].ch[0]);t[r].ch[0]=0;//把这一段区间一下删完
	push_up(r),push_up(l);
}
//============================================================
inline void make_same(int l,int r,int x){
	l=find_kth(l),r=find_kth(r+2);//同样是找到位置 
	Splay(l,0),Splay(r,l);
	int k=t[r].ch[0];//k就是应该修改的区间对应的树 
	t[k].val=x,t[k].sum=t[k].size*x,t[k].laz=1; //改区间 
	if(x<=0)t[k].ls=t[k].rs=0,t[k].maxn=x;
	else t[k].ls=t[k].rs=t[k].maxn=t[k].sum;
	push_up(r),push_up(l);
}
inline void write(int x){//将区间输出,从左到右 
	if(x){
		push_down(x);
		write(t[x].ch[0]);
		if(t[x].val!=-INF&&t[x].val!=INF)printf("%d ",t[x].val);
		write(t[x].ch[1]);
	}
}
inline int query_sum(int l,int r){//找区间的和,已经处理出来了,就是这一段对应子树的sum 
	l=find_kth(l),r=find_kth(r+2);
	Splay(l,0),Splay(r,l);
	return t[t[r].ch[0]].sum;
}
inline int query_max(int l,int r){//对应子树的maxn就是最大的子段和 
	l=find_kth(l),r=find_kth(r+2);
	Splay(l,0),Splay(r,l);
	return t[t[r].ch[0]].maxn;
}
int main(){
	n=read(),m=read();
	t[0].maxn=a[1]=-INF,a[n+2]=INF;
	for(int i=1;i<=n;i++){
		a[i+1]=read()?1:-1;//如果是0就是-1 这样连续子段最大值就是最多的连续1的个数 
	}
	root=build(1,n+2,0);
	for(int i=1;i<=m;i++){
		int opt=read(),l=read()+1,r=read()+1;
		if(opt==0)make_same(l,r,-1);
		else if(opt==1)make_same(l,r,1);
		else if(opt==2)reserve(l,r);
		else if(opt==3){
			int len=r-l+1,sum=query_sum(l,r);//想象一下,1的个数就是总个数减去1比-1多的个数除以二后再加上1比-1多的个数 
			printf("%d\n",(len-sum)/2+sum);
		}
		else printf("%d\n",query_max(l,r));
//		write(root);printf("\n"); 
	}
	return 0;
}
2021/4/17 10:48
加载中...