30pts 求助
查看原帖
30pts 求助
299756
Surget楼主2021/6/30 13:33

求dalao帮调

诚心求教

#include<cstdio>
#include<iostream>
#define ll long long
using namespace std;
int n,m,r=1,tot=1,sum=0;
int val[100012],h[100012],d[100012],w[100012],son[100012],fa[100012],sz[100012],top[100012],id[100012],tree[4*100012],mark[4*100012];
int fa_[500012][20];
bool vis[100012];
struct str
{
	int next,to;
}e[600012*2];
void add(int x,int y)
{
	e[++tot].next=h[x];
	e[tot].to=y;
	h[x]=tot;
}
void dfs1(int now,int f,int dep)
{
	d[now]=dep+1;
	sz[now]=1;
	fa[now]=f;
	for(int i=h[now];i;i=e[i].next)
	{
		int v=e[i].to;
		if(v!=f)
		{
			dfs1(v,now,dep+1);
			sz[now]+=sz[v];
			if(sz[v]>sz[son[now]])
			{
				son[now]=v;
			}
		}
	}
}
void dfs2(int now,int f)
{
	if(now==0)
	{
		return ;
	}
	vis[now]=true;
	top[now]=f;
	id[now]=++sum;
	w[sum]=val[now];
	dfs2(son[now],f);
	for(int i=h[now];i;i=e[i].next)
	{
		int v=e[i].to;
		if(!vis[v])
		{
			dfs2(v,v);
		}
	}
}
void build(int now,int l,int r)
{
	if(l==r)
	{
		tree[now]=w[l];
		return ;
	}
	int mid=(l+r)/2;
	build(now*2,l,mid);
	build(now*2+1,mid+1,r);
	tree[now]=tree[now*2]+tree[now*2+1];
}
void push_down(int l,int r,int now)
{
	if(mark[now])
	{
		int mid=(l+r)/2;
		mark[now*2]+=mark[now];
		mark[now*2+1]+=mark[now];
		tree[now*2]+=((mid-l+1)*mark[now]);
		tree[now*2+1]+=((r-mid)*mark[now]);
		mark[now]=0;
	}
}
void update(int l,int r,int now,int ql,int qr,ll add)
{
	if(ql<=l&&qr>=r)
	{
		mark[now]+=add;
		tree[now]+=(r-l+1)*add;
		return;
	}
	push_down(l,r,now);
	int mid=(l+r)/2;
	if(mid>=ql)
	{
		update(l,mid,now*2,ql,qr,add);
	}
	if(mid<qr)
	{
		update(mid+1,r,now*2+1,ql,qr,add);
	}
	tree[now]=tree[now*2]+tree[now*2+1];
}
ll query(int l,int r,int now,int ql,int qr)
{
	if(ql<=l&&qr>=r) 
	{
		return tree[now];
	}
	push_down(l,r,now);
	int mid=(l+r)/2;
	ll ans=0;
	if(mid+1<=qr)
	{
		ans+=query(mid+1,r,now*2+1,ql,qr);
	}
	if(mid>=ql)
	{
		ans+=query(l,mid,now*2,ql,qr);
	}
	return ans;
}
int lca(int x,int y)
{
	while(top[x]!=top[y])
	{
		int tx=top[x],ty=top[y];
		if(d[tx]<d[ty])
		{
			swap(tx,ty);
			swap(x,y);
		}
		x=fa[tx];
	}
	if(d[x]>d[y])
	{
		swap(x,y);
	}
	return x;
}
int find(int x,int y){
	while(top[x]!=top[y])
	{
		int tx=top[x],ty=top[y];
		if(d[tx]<d[ty])
		{
			swap(tx,ty);
			swap(x,y);
		}
		if(fa[tx]==y)
		{
			return tx;
		}
		x=fa[tx];
	}
	if(d[x]>d[y])
	{
		swap(x,y);
	}
	return son[x];
}
int LCA(int x,int y)
{
	if(d[x]>d[y])
	{
		swap(x,y);
	}
	if(lca(x,y)==x)
	{
		if(id[r]>=id[y]&&id[r]<=id[y]+sz[y]-1)
		{
			return y;
		}
		if(lca(x,r)==x)
		{
			return lca(y,r);
		}
		return x;
	}
	if(id[r]>=id[x]&&id[r]<=id[x]+sz[x]-1)
	{
		return x;
	}
	if(id[r]>=id[y]&&id[r]<=id[y]+sz[y]-1)
	{
		return y;
	}
	if((lca(x,r)==r&&lca(x,y)==lca(y,r))||(lca(y,r)==r&&lca(x,y)==lca(x,r)))
	{
		return r;
	}
	if(lca(x,r)==lca(y,r))
	{
		return lca(x,y);
	}
	if(lca(x,y)!=lca(x,r))
	{
		return lca(x,r);
	}
	return lca(y,r);
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++)
	{
		scanf("%d",&val[i]);
	}
	for(int i=1;i<n;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		add(x,y);
		add(y,x);
	}
	dfs1(r,0,1);
	dfs2(r,r);
	build(1,1,n);
	for(int i=1;i<=m;i++)
	{
		int type;
		scanf("%d",&type);
		if(type==1)
		{
			scanf("%d",&r);
			continue;
		}
		if(type==2)
		{
			int x,y;
			ll z;
			scanf("%d%d%lld",&x,&y,&z);
			int f=LCA(x,y);
			if(f==r)
			{
				update(1,n,1,1,n,z);
				continue;
			}
			int f1=lca(f,r);
			if(f1!=f)
			{
				update(1,n,1,id[f],id[f]+sz[f]-1,z);
				continue;
			}
			int s=find(f,r);
			update(1,n,1,1,n,z);
			update(1,n,1,id[s],id[s]+sz[s]-1,-z);
		}
		if(type==3)
		{
			int x;
			scanf("%d",&x);
			if(x==r)
			{
				printf("%lld\n",query(1,n,1,1,n));
				continue;
			}
			int f=lca(x,r);
			if(x!=f)
			{
				printf("%lld\n",query(1,n,1,id[x],id[x]+sz[x]-1));
				continue;
			}
			int s=find(x,r);
			printf("%lld\n",query(1,n,1,1,n)-query(1,n,1,id[s],id[s]+sz[s]-1));
		}
	}
	return 0;
}
2021/6/30 13:33
加载中...