RE求助
查看原帖
RE求助
394991
Sharing666楼主2021/7/12 19:32
#include<bits/stdc++.h>
using namespace std;
#define int long long

int n,m,lc,rc,cnt,num,res,ans,w[200002],val[200002];
int sz[200002],top[200002],son[200002],fa[200002],id[200002];
int dep[200002],lazy[400002],sum[400002],head[200002];

struct node{
	int lc,rc;
}a[400002];

struct edge{
	int to,nxt;
}e[200002];

void addedge(int A,int B) {
	e[++cnt].to=B;
	e[cnt].nxt=head[A];
	head[A]=cnt;
}

void dfs1(int u,int last) {
	dep[u]=dep[last]+1;
	fa[u]=last;
	sz[u]=1;
	int ma=-1;
	for(int i=head[u];i;i=e[i].nxt) {
		int v=e[i].to;
		if(v==last) continue;
		dfs1(v,u);
		sz[u]+=sz[v];
		if(sz[v]>ma) {
			ma=sz[v];
			son[u]=v;
		}
	}
}

void dfs2(int u,int rt) {
	id[u]=++num;
	top[u]=rt;
	if(!son[u]) return ;
	dfs2(son[u],rt);
	for(int i=head[u];i;i=e[i].nxt) {
		int v=e[i].to;
		if(v==fa[u] || v==son[u]) continue;
		dfs2(v,v);
	}
}

void pushup(int num) {
	a[num].lc=a[num*2].lc;
	a[num].rc=a[num*2+1].rc;
	sum[num]=sum[num*2]+sum[num*2+1];
	if(a[num*2].rc==a[num*2+1].lc) sum[num]--;
}

void build(int num,int l,int r) {
	sum[num]=0;
	if(l==r) {
		sum[num]=1;
		a[num].lc=a[num].rc=val[l];
		return;
	}
	int mid=(l+r)/2;
	build(num*2,l,mid);
	build(num*2+1,mid+1,r);
	pushup(num);
}

void pushdown(int num,int l,int r) {
	if(!lazy[num]) return ;
	lazy[num*2]=lazy[num*2+1]=sum[num*2]=sum[num*2+1]=1;
	a[num*2].lc=a[num*2].rc=a[num*2+1].lc=a[num*2+1].rc=a[num].lc;
	lazy[num]=0;
}

void update(int num,int l,int r,int ll,int rr,int x) {
	if(l>=ll && r<=rr) {
		sum[num]=lazy[num]=1;
		a[num].lc=a[num].rc=x;
		return ;
	}
	pushdown(num,l,r);
	int mid=(l+r)/2;
	if(ll<=mid) update(num*2,l,mid,ll,rr,x);
	if(rr>mid) update(num*2+1,mid+1,r,ll,rr,x);
	pushup(num);
}

int query(int num,int l,int r,int ll,int rr) {
	if(l==ll) lc=a[num].lc;
	if(r==rr) rc=a[num].rc;
	if(l>=ll && r<=rr) return sum[num];
	pushdown(num,l,r);
	res=0;
	int mid=(l+r)/2;
	if(rr<=mid) return query(num*2,l,mid,ll,rr);
	if(ll>mid) return query(num*2+1,mid+1,r,ll,rr);
	res+=query(num*2,l,mid,ll,rr);
	res+=query(num*2+1,mid+1,r,ll,rr);
	if(a[num*2].rc==a[num*2+1].lc) res--;
	return res;
}

void range1(int x,int y) {
	ans=0;
	int LC=0,RC=0;
	while(top[x]!=top[y]) {
		if(dep[top[x]]<dep[top[y]]) {
			swap(x,y);
			swap(LC,RC);
		}
		res=0;
		ans+=query(1,1,n,id[top[x]],id[x]);
		if(LC==rc) ans--;
		LC=lc;
		x=fa[top[x]];
	}
	if(dep[x]>dep[y]) {
		swap(x,y);
		swap(LC,RC);
	}
	ans+=query(1,1,n,id[y],id[x]);
	if(LC==rc) ans--;
	if(RC==lc) ans--;
}

void range2(int x,int y,int z) {
	while(top[x]!=top[y]) {
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		update(1,1,n,id[top[x]],id[x],z);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	update(1,1,n,id[x],id[y],z);
}

signed main() {
	scanf("%lld%lld",&n,&m);
	for(int i=1;i<=n;i++) scanf("%lld",&w[i]);
	for(int i=1;i<n;i++) {
		int u,v;
		scanf("%lld%lld",&u,&v);
		addedge(u,v);
		addedge(v,u);
	}
	dfs1(1,0);
	dfs2(1,1);
	for(int i=1;i<=n;i++) val[id[i]]=w[i];
	build(1,1,n);
	while(m--) {
		char op;
		int x,y,z;
		cin>>op;
		if(op=='C') {
			scanf("%lld%lld%lld",&x,&y,&z);
			range2(x,y,z);
		} else if(op=='Q'){
			scanf("%lld%lld",&x,&y);
			range1(x,y);
			printf("%lld\n",ans);
		}
	}
	return 0;
}
2021/7/12 19:32
加载中...