求助整体二分时的值域
查看原帖
求助整体二分时的值域
125901
FxorG楼主2021/5/21 23:31

RT 这是我的代码

#include <bits/stdc++.h>

#define N (int)(3e5+5)
#define ls (cur<<1)
#define rs (ls|1)

using namespace std;

struct node {
	int x,y,val,id;
}a[N<<1];

struct edge {
	int nex,to;
}e[N<<1];

bool vis[N];
int head[N],cnt,tot,fa[N],top[N],dep[N],son[N],id[N],sz[N];
int n,m,pre[N],lsh[N],ans[N],l_cnt;

int rd() {
	int f=1,sum=0; char ch=getchar();
	while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
	while(isdigit(ch)) {sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
	return sum*f;
}

void add(int x,int y) {
	e[++cnt]=edge{head[x],y};
	head[x]=cnt;
}

void dfs1(int x,int faa) {
	dep[x]=dep[faa]+1; fa[x]=faa; sz[x]=1;
	for(int i=head[x];i;i=e[i].nex) {
		int y=e[i].to;
		if(y==faa) continue;
		dfs1(y,x); sz[x]+=sz[y];
		if(sz[y]>sz[son[x]]) son[x]=y;
	}
}

void dfs2(int x,int tp) {
	top[x]=tp; id[x]=++tot;
	if(son[x]) dfs2(son[x],tp);
	for(int i=head[x];i;i=e[i].nex) {
		int y=e[i].to;
		if(y==fa[x]||y==son[x]) continue;
		dfs2(y,y);
	}
}

int sum[N<<2];
void add(int cur,int l,int r,int x,int val) {
	if(l==r) {
		sum[cur]+=val;
		return;
	}
	int mid=(l+r)>>1;
	if(x<=mid) add(ls,l,mid,x,val);
	else add(rs,mid+1,r,x,val);
	sum[cur]=sum[ls]+sum[rs];
}

int query(int cur,int l,int r,int cl,int cr) {
	if(cl>cr) return 0;
	if(cl<=l&&r<=cr) return sum[cur];
	int mid=(l+r)>>1,res=0;
	if(cl<=mid) res=query(ls,l,mid,cl,cr);
	if(cr>mid) res+=query(rs,mid+1,r,cl,cr);
	return res;
}

int qry(int x,int y) {
	int res=0;
	while(top[x]!=top[y]) {
		if(dep[top[x]]<dep[top[y]]) swap(x,y);//cout<<x<<" "<<y<<endl;
		res+=query(1,1,n,id[top[x]],id[x]);
		x=fa[top[x]];
	}
	if(id[x]>id[y]) swap(x,y);
	return res+query(1,1,n,id[x],id[y]);
}

node q1[N],q2[N];
void solve(int L,int R,int l,int r) {
	if(l>r) return;
	if(L==R) {
		for(int i=l;i<=r;i++) {
			if(a[i].id==-1) continue;
			ans[a[i].id]=lsh[L];
		//	cout<<lsh[L]<<" "<<a[i].val<<endl;
		}
		//cout<<L<<" "<<l<<" "<<r<<endl;
		return;
	}
	int mid=(L+R)>>1,cnt1=0,cnt2=0;
	for(int i=l;i<=r;i++) {
		if(a[i].id==-1) {
			if(a[i].val<=mid) q1[++cnt1]=a[i];
			else q2[++cnt2]=a[i],add(1,1,n,id[a[i].x],a[i].y);
		} else {
			int res=qry(a[i].x,a[i].y);
			if(res>=a[i].val) q2[++cnt2]=a[i];
			else a[i].val-=res,q1[++cnt1]=a[i];
		}
	}
	for(int i=1;i<=cnt2;i++) if(q2[i].id==-1) add(1,1,n,id[q2[i].x],-q2[i].y);
	for(int i=1;i<=cnt1;i++) a[i+l-1]=q1[i];
	for(int i=1;i<=cnt2;i++) a[i+l+cnt1-1]=q2[i];
	solve(L,mid,l,l+cnt1-1); solve(mid+1,R,l+cnt1,r);
}

int main() {
	n=rd(); m=rd();
	int tot1=0,x,y;
	for(int i=1;i<=n;i++) pre[i]=rd(),a[++tot1]=node{i,1,pre[i],-1},lsh[++l_cnt]=pre[i];
	for(int i=1;i<n;i++) x=rd(),y=rd(),add(x,y),add(y,x);
	dfs1(1,0); dfs2(1,1);	
	for(int i=1;i<=m;i++) {
		x=rd();
		if(x==0) {
			x=rd(); y=rd();
			a[++tot1]=node{x,-1,pre[x],-1};
			a[++tot1]=node{x,1,y,-1};
			lsh[++l_cnt]=pre[x]=y;
		} else {
			a[++tot1]=node{rd(),rd(),x,i};
			vis[i]=1;
		}
	}
	sort(lsh+1,lsh+1+l_cnt);// l_cnt=unique(lsh+1,lsh+1+l_cnt)-lsh-1;
	for(int i=1;i<=tot1;i++) {
		if(a[i].id==-1) a[i].val=lower_bound(lsh+1,lsh+1+l_cnt,a[i].val)-lsh;
	} 
	solve(0,l_cnt,1,tot1);
	for(int i=1;i<=m;i++) {
		if(vis[i]) ans[i]?printf("%d\n",ans[i]):puts("invalid request!");
	}
	return 0;
}
solve(0,l_cnt,1,tot1);

为什么调用值域范围为[0,l_cnt]而不是[1,l_cnt],离散化出来之后不是从1开始的吗? qwq

2021/5/21 23:31
加载中...