求dl答疑解惑
查看原帖
求dl答疑解惑
1336631
enter_prise楼主2025/7/30 20:58

为什么求答案的get_t函数里必须用这种带LCA的办法,而用注释里树剖模板的方式去求会炸呢?

#include<bits/stdc++.h>
using namespace std;
const int MAXN=100005;
struct wzy{
	int to;
	int nxt;
}edge[MAXN<<1];
struct xjc{
	int l;
	int r;
	int lval;
	int rval;
	int len;
	int ans;
	int tag;
}t[MAXN<<2];
int ccb,n,m,cnt=0,head[MAXN];
int dep[MAXN],dfn[MAXN],rnk[MAXN],fa[MAXN],size[MAXN],top[MAXN],hson[MAXN];
xjc zero(xjc x){
	x.ans=x.l=x.len=x.lval=x.r=x.rval=x.tag=0;
	return x;
}
void build114514(int a,int b){
	edge[++cnt].to=b;
	edge[cnt].nxt=head[a];
	head[a]=cnt;
	return ;
}
xjc merge(xjc a,xjc b){
	xjc num;
	num.lval=a.lval;
	num.rval=b.rval;
	num.ans=a.ans+b.ans;
	if(a.rval==b.lval) num.ans++;
	num.len=a.len+b.len;
	return num;
}
void pushup(int p){
	t[p]=merge(t[p<<1],t[p<<1|1]);
	t[p].l=t[p<<1].l;
	t[p].r=t[p<<1|1].r;
	t[p].tag=0;
	return ; 
}
void add(int p,int h){
	t[p].ans=t[p].len-1;
	t[p].lval=h;
	t[p].rval=h;
	t[p].tag=h;
	return ;
}
void pushdown(int p){
	if(t[p].tag==0) return ;
	add(p<<1,t[p].tag);
	add(p<<1|1,t[p].tag);
	t[p].tag=0;
	return ;
}
void dfs1(int u,int f){
	dep[u]=dep[f]+1;
	size[u]=1;
	fa[u]=f;
	int maxnum=0;
	for(int i=head[u];i!=-1;i=edge[i].nxt){
		int v=edge[i].to;
		if(v==f) continue;
		dfs1(v,u);
		if(size[v]>maxnum){
			maxnum=size[v];
			hson[u]=v;
		}
		size[u]+=size[v];
	}
	return ;
}
void dfs2(int u,int topf){
	dfn[u]=++cnt;
	rnk[cnt]=u;
	top[u]=topf;
	if(hson[u]==0) return ;
	dfs2(hson[u],topf);
	for(int i=head[u];i!=-1;i=edge[i].nxt){
		int v=edge[i].to;
		if(v==fa[u]||v==hson[u]) continue;
		dfs2(v,v);
	}
	return ;
}
int lca(int x,int y){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		x=fa[top[x]];
	}
	if(dep[x]<dep[y]) return x;
	else return y;
}
void build(int p,int l,int r){
	t[p].l=l;
	t[p].r=r;
	t[p].len=r-l+1;
	t[p].ans=0;
	t[p].tag=0;
	if(l==r){
		t[p].lval=l;
		t[p].rval=l;
		return ;
	}
	int mid=(l+r)>>1;
	build(p<<1,l,mid);
	build(p<<1|1,mid+1,r);
	pushup(p);
	return ;
}
void change(int p,int L,int R,int h){
	if(t[p].l>=L&&t[p].r<=R){
		add(p,h);
		return ;
	}
	pushdown(p);
	int mid=(t[p].l+t[p].r)>>1;
	if(L<=mid) change(p<<1,L,R,h);
	if(R>mid) change(p<<1|1,L,R,h);
	pushup(p);
	return ;
}
xjc get(int p,int L,int R){
	if(t[p].l>=L&&t[p].r<=R){
		return t[p];
	}
	pushdown(p);
	int mid=(t[p].l+t[p].r)>>1;
	if(R<=mid) return get(p<<1,L,R);
	if(L>mid) return get(p<<1|1,L,R);
	return merge(get(p<<1,L,R),get(p<<1|1,L,R));
}
void change_t(int x,int y,int h){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		change(1,dfn[top[x]],dfn[x],h);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	change(1,dfn[x],dfn[y],h);
	return ;
}
int get_t(int x,int y){
	xjc al,al2;
	al=zero(al);
	al2=zero(al2);
	int man=lca(x,y);
	/*while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		al=merge(get(1,dfn[top[x]],dfn[x]),al);
		al=merge(get(1,dfn[fa[top[x]]],dfn[top[x]]),al);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	if(x!=y) al=merge(get(1,dfn[x],dfn[y]),al);*/
	while(top[x]!=top[man]){
		al=merge(get(1,dfn[top[x]],dfn[x]),al);
		x=fa[top[x]];
	}
	al=merge(get(1,dfn[man],dfn[x]),al);
	while(top[y]!=top[man]){
		al2=merge(get(1,dfn[top[y]],dfn[y]),al2);
		y=fa[top[y]];
	}
	al2=merge(get(1,dfn[man],dfn[y]),al2);
	return al.ans+al2.ans;
}
int main(){
	scanf("%d",&ccb);
	for(int z=1;z<=ccb;z++){
		memset(head,-1,sizeof(head));
		memset(edge,0,sizeof(edge));
		//memset(top,0,sizeof(top));
		memset(hson,0,sizeof(hson));
		scanf("%d %d",&n,&m);
		cnt=0;
		for(int i=1,u,v;i<=n-1;i++){
			scanf("%d %d",&u,&v);
			build114514(u,v);
			build114514(v,u);
		}
		dfs1(1,0);
		cnt=0;
		dfs2(1,1);
		build(1,1,n);
		for(int i=1,op,a,b;i<=m;i++){
			scanf("%d %d %d",&op,&a,&b);
			if(op==1){
				change_t(a,b,i+n);
			}
			if(op==2) printf("%d\n",get_t(a,b));
		}
	}
	return 0;
}
2025/7/30 20:58
加载中...