10pt,求改
查看原帖
10pt,求改
109270
seven_sin (楼主 / original poster) 2021/9/13 21:10

rt（如题 / as the title says: the code below scores only 10 points — please help fix it）

#include "iostream"
#include "cstdio"
#include "cstring"
#include "cmath"
#include "algorithm"
#include "vector"
using namespace std;
// Capacity of every per-node array (+ slack).
// NOTE(review): assumes this is the HLD template problem with N <= 1e5;
// the previous value 50010 is too small for that limit — confirm against
// the problem's constraints.
const int maxn=100010;
// Adjacency lists of the undirected input tree (nodes are 1-indexed).
vector<int >edg[maxn];
// Heavy-light decomposition state (indexed by node unless noted):
//   top  - head of the heavy chain containing the node
//   son  - heavy child (stays 0 for a leaf)
//   size - subtree size. NOTE: under C++17 with `using namespace std`, an
//          unqualified `size` is ambiguous with std::size; refer to it as
//          `::size` (qualified lookup finds only this global).
//   val  - indexed by DFS position: the node placed at that position
//   id   - DFS position of the node
//   fa   - parent; deep - depth (root = 1, 0 means not yet visited)
//   a    - initial node values as read from input
// root, mod, n: input root, modulus, node count; cnt: DFS position counter.
int root,mod,n,top[maxn],son[maxn],size[maxn],val[maxn],id[maxn],fa[maxn],deep[maxn],a[maxn],cnt;
// Segment tree over DFS order supporting range add and range sum modulo `mod`.
//
// Invariant: csum[k] is the exact sum (reduced into [0, mod)) of node k's
// range, INCLUDING every addition ever applied to k; lazy[k] is an addition
// already counted in csum[k] but not yet propagated to k's children.
// Values are stored as long long so len * tag (< 1e5 * 2^31) cannot overflow.
struct node{
	long long csum[maxn<<2],lazy[maxn<<2];
	// Left / right child indices in the implicit heap layout.
	inline int tl(int x){return x<<1;}
	inline int tr(int x){return x<<1|1;}
	// Reduce an arbitrary value into [0, mod) (also handles negative input).
	inline long long norm(long long v){v%=mod;return v<0?v+mod:v;}
	// Recompute a node's sum from its two children.
	inline void push_up(int x){csum[x]=(csum[tl(x)]+csum[tr(x)])%mod;}
	// Add c (already in [0, mod)) to each of the len elements covered by node k.
	inline void add_tag(int k,int len,long long c){
		csum[k]=(csum[k]+(long long)len*c)%mod;
		lazy[k]=(lazy[k]+c)%mod;
	}
	// Propagate node k's pending tag to its children; [left, right] is k's range.
	inline void push_down(int k,int left,int right){
		if(!lazy[k])return;
		int mid=(left+right)>>1;
		add_tag(tl(k),mid-left+1,lazy[k]);
		add_tag(tr(k),right-mid,lazy[k]);
		lazy[k]=0;
	}
	// Build node k over DFS positions [l, r] from a[val[pos]].
	inline void bulid(int k,int l,int r){
		lazy[k]=0;
		if(l==r){csum[k]=norm(a[val[l]]);return;}
		int mid=(l+r)>>1;
		bulid(tl(k),l,mid);bulid(tr(k),mid+1,r);
		push_up(k);
	}
	// Sum over query range [l, r] (must lie inside node k's range [left, right]),
	// returned in [0, mod).
	inline int query_sum(int k,int l,int r,int left,int right){
		if(l==left&&r==right)return (int)csum[k];
		push_down(k,left,right);
		int mid=(left+right)>>1;
		if(r<=mid)return query_sum(tl(k),l,r,left,mid);
		if(l>mid)return query_sum(tr(k),l,r,mid+1,right);
		long long s=(long long)query_sum(tl(k),l,mid,left,mid)+query_sum(tr(k),mid+1,r,mid+1,right);
		return (int)(s%mod);
	}
	// Add c to every position in [l, r] (inside node k's range [left, right]).
	inline void update(int k,int l,int r,int left,int right,int c){
		if(l==left&&r==right){add_tag(k,right-left+1,norm(c));return;}
		push_down(k,left,right);
		int mid=(left+right)>>1;
		if(r<=mid)update(tl(k),l,r,left,mid,c);
		else if(l>mid)update(tr(k),l,r,mid+1,right,c);
		else{update(tl(k),l,mid,left,mid,c);update(tr(k),mid+1,r,mid+1,right,c);}
		// Children changed: refresh this node's sum so ancestors stay correct.
		push_up(k);
	}
}tree;
inline void dfs1(int k){
	size[k]=1;
	for(int i=0;i<edg[k].size();i++){
	int u=edg[k][i];if(deep[u])continue;deep[u]=deep[k]+1;fa[u]=k;
	dfs1(u);
	size[k]+=size[u];if(!son[k]||size[son[k]]<size[u])son[k]=u;	
	}
}
inline void dfs2(int k,int tp){
	cnt++;id[k]=cnt;val[cnt]=k;top[k]=tp;
	if(son[k])dfs2(son[k],tp);
	for(int i=0;i<edg[k].size();i++){
	int u=edg[k][i];if(u==fa[k]||u==son[k])continue;	
	dfs2(u,u);	
	}
}
inline int qsum(int x,int y){
	int ans=0;
	while(top[x]!=top[y]){
	if(deep[top[x]]<deep[top[y]])swap(x,y);	
	ans+=tree.query_sum(1,id[top[x]],id[x],1,n); 
	x=fa[top[x]];
	}
	if(deep[x]>deep[y])swap(x,y);
	ans+=tree.query_sum(1,id[x],id[y],1,n);
	return ans;
}
// Add z to every node on the tree path between x and y.
// While the endpoints lie on different chains, lift the one whose chain head
// is deeper, updating the DFS range [head, endpoint]; once they share a
// chain, update the remaining contiguous range between them.
inline void qrange(int x,int y,int z){
	for(;top[x]!=top[y];x=fa[top[x]]){
		if(deep[top[y]]>deep[top[x]])
			swap(x,y);
		tree.update(1,id[top[x]],id[x],1,n,z);
	}
	if(deep[y]<deep[x])
		swap(x,y);
	tree.update(1,id[x],id[y],1,n,z);
}
inline int tsum(int x){int ans=tree.query_sum(1,id[x],id[x]+size[x]-1,1,n);return ans;}
inline int trange(int x,int y){tree.update(1,id[x],id[x]+size[x]-1,1,n,y);}
// Reads: n q root mod, then n node values, then n-1 undirected edges,
// then q operations:
//   1 u v w : add w to path u-v        2 u v : print path u-v sum
//   3 u v   : add v to subtree of u    4 u   : print subtree-of-u sum
int main()
{
	int q;
	scanf("%d%d%d%d",&n,&q,&root,&mod);
	for(int i=1;i<=n;++i)
		scanf("%d",&a[i]);
	for(int e=1;e<n;++e){
		int from,to;
		scanf("%d%d",&from,&to);
		edg[from].push_back(to);
		edg[to].push_back(from);
	}
	deep[root]=1;
	dfs1(root);
	dfs2(root,root);
	tree.bulid(1,1,n);
	while(q--!=0){
		int opt,u,v,w;
		scanf("%d",&opt);
		switch(opt){
		case 1:
			scanf("%d%d%d",&u,&v,&w);
			qrange(u,v,w);
			break;
		case 2:
			scanf("%d%d",&u,&v);
			printf("%d\n",qsum(u,v));
			break;
		case 3:
			scanf("%d%d",&u,&v);
			trange(u,v);
			break;
		case 4:
			scanf("%d",&u);
			printf("%d\n",tsum(u));
			break;
		default:
			break;
		}
	}
	return 0;
}
2021/9/13 21:10
加载中...