10pts 求调，只 AC 了最后一个点
查看原帖
10pts求调,A最后一个点
530570
Firrel_qaq（楼主） 2025/7/2 16:51
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cmath>
#define int long long
using namespace std;
// Adjacency-list edge: `to` is the edge's endpoint and `nxt` the index of the
// next edge leaving the same vertex (0 ends the list, see head[]).
struct node{
	int nxt,to;
}mp[400005]; // main() stores both directions of every tree edge
// n: vertex count, m: query count, r: root vertex, P: modulus,
// tot: number of edges stored in mp, t: DFS-order counter used by dfs1.
int n,m,r,P,tot,t;
// Per-vertex arrays indexed by vertex id (rdfn is indexed by DFS position):
//   a      - vertex value (reduced mod P on input)
//   head   - index in mp of the first edge leaving the vertex (0 = none)
//   f      - parent in the rooted tree (0 for the root)
//   dep    - depth; the root has depth 1
//   dfn    - position of the vertex in dfs1's heavy-child-first order
//   rdfn   - inverse of dfn: the vertex stored at a given DFS position
//   top    - head vertex of the heavy chain containing the vertex
//   cntson - subtree size, counting the vertex itself
//   hyson  - heavy child (first child with the largest subtree), 0 for a leaf
int a[200005],head[200005],f[200005],dep[200005],dfn[200005],rdfn[200005],top[200005],cntson[200005],hyson[200005];
// Segment-tree node covering DFS positions [l, r]: `sum` is the range sum
// mod P and `lazy` a pending addition not yet pushed to the children.
// build(1, n, 1) numbers children 2p / 2p+1, so every index used is below 4n.
// n is bounded by the 200005-sized per-vertex arrays, hence 4 * 200005 + a
// little slack is enough; the former 4000005 reserved ~128 MB of long longs.
struct tree{
	int l,r,sum,lazy;
}tr[800025];
// Append the directed edge u -> v at the front of u's adjacency list.
void add(int u,int v){
	tot++;
	mp[tot].nxt = head[u];
	mp[tot].to = v;
	head[u] = tot;
}
// Push node p's pending addition `lazy` down to its two children (2p, 2p+1).
// Each child's sum grows by (child length) * lazy and its own lazy tag by
// lazy, both reduced mod P. The previous version wrote `(x += y) % P;`, which
// computes the remainder and discards it, so sums and tags were never reduced
// and could overflow after enough updates.
void pushdown(int p){
	const int lz = tr[p].lazy;
	if(lz == 0) return ; // nothing pending; children are already reduced mod P
	for(int c = p * 2;c <= p * 2 + 1;c++){
		(tr[c].sum += (tr[c].r - tr[c].l + 1) * lz) %= P;
		(tr[c].lazy += lz) %= P;
	}
	tr[p].lazy = 0;
}
// First DFS of the heavy-light decomposition. Records the parent and subtree
// size of x, sets the depth of each child, and picks as heavy child the first
// child whose subtree is strictly the largest. dep[x] must already be set.
void dfs(int x,int fa){
	f[x] = fa;
	cntson[x] = 1;
	int best = 0;
	for(int e = head[x];e != 0;e = mp[e].nxt){
		const int v = mp[e].to;
		if(v != fa && !f[v]){
			dep[v] = dep[x] + 1;
			dfs(v,x);
			cntson[x] += cntson[v];
			if(best < cntson[v]){
				best = cntson[v];
				hyson[x] = v;
			}
		}
	}
}
// Second DFS: assigns DFS positions, visiting the heavy child first so that
// every heavy chain occupies a contiguous dfn range. tt is the head of x's chain;
// each light child starts a new chain headed by itself.
void dfs1(int x,int tt,int fa){
	top[x] = tt;
	dfn[x] = ++t;
	rdfn[t] = x;
	const int heavy = hyson[x];
	if(!heavy) return ;
	dfs1(heavy,tt,x);
	for(int e = head[x];e;e = mp[e].nxt){
		const int v = mp[e].to;
		if(v != fa && v != heavy) dfs1(v,v,x);
	}
}
// Build segment-tree node p over DFS positions [l, r]. The leaf at position i
// holds the value of vertex rdfn[i] mod P; an inner node holds the sum of its
// children mod P.
void build(int l,int r,int p){
	tr[p].l = l;
	tr[p].r = r;
	if(l != r){
		const int mid = (l + r) / 2;
		build(l,mid,p * 2);
		build(mid + 1,r,p * 2 + 1);
		tr[p].sum = (tr[p * 2].sum + tr[p * 2 + 1].sum) % P;
	}
	else{
		tr[p].sum = a[rdfn[l]] % P;
	}
}
void change(int p,int l,int r,int k){
	if(l <= tr[p].l && tr[p].r <= r){
		(tr[p].lazy += k) %= P,(tr[p].sum += (tr[p].r - tr[p].l + 1) * k) %= P;
		return ;
	}
	pushdown(p);
	int mid = (tr[p].l + tr[p].r) / 2;
	if(l <= mid) change(p * 2,l,r,k);
	if(r > mid) change(p * 2 + 1,l,r,k);
	tr[p].sum = (tr[p * 2].sum + tr[p * 2 + 1].sum) % P;
	return ;
}
// Return the sum mod P of the DFS positions in [l, r] covered by node p,
// pushing pending tags down on the way.
int found(int p,int l,int r){
	if(l <= tr[p].l && tr[p].r <= r) return tr[p].sum;
	pushdown(p);
	const int mid = (tr[p].l + tr[p].r) / 2;
	int total = 0;
	if(l <= mid) total = (total + found(p * 2,l,r)) % P;
	if(mid < r) total = (total + found(p * 2 + 1,l,r)) % P;
	return total % P;
}
// Add k to every vertex on the tree path between x and y.
// While x and y are on different heavy chains, the endpoint whose chain HEAD
// is deeper updates its chain segment [dfn[top], dfn[endpoint]] and jumps to
// the head's parent. (Comparing dep[x] instead of dep[top[x]] can climb past
// the LCA.) Once both are on one chain, the shallower endpoint has the smaller
// dfn, because dfs1 numbers a chain top-down. So the last segment is
// [dfn[shallower], dfn[deeper]]. The previous swap produced a reversed range,
// which change() treats as empty, so this segment was silently skipped.
void plus1(int x,int y,int k){
	while(top[x] != top[y]){
		if(dep[top[x]] < dep[top[y]]) swap(x,y);
		change(1,dfn[top[x]],dfn[x],k);
		x = f[top[x]];
	}
	if(dep[x] > dep[y]) swap(x,y); // x := shallower endpoint
	change(1,dfn[x],dfn[y],k);
}
// Return the sum mod P of the vertex values on the tree path between x and y.
// While x and y are on different heavy chains, the endpoint whose chain HEAD
// is deeper adds the sum of [dfn[top], dfn[endpoint]] and jumps to the head's
// parent. (Comparing dep[x] instead of dep[top[x]] can climb past the LCA.)
// On the common chain the shallower endpoint has the smaller dfn, so the final
// range is [dfn[shallower], dfn[deeper]]. The previous swap produced a
// reversed range, for which found() returns 0, dropping that segment's sum.
int found1(int x,int y){
	int ans = 0;
	while(top[x] != top[y]){
		if(dep[top[x]] < dep[top[y]]) swap(x,y);
		ans = (ans + found(1,dfn[top[x]],dfn[x])) % P;
		x = f[top[x]];
	}
	if(dep[x] > dep[y]) swap(x,y); // x := shallower endpoint
	ans = (ans + found(1,dfn[x],dfn[y])) % P;
	return ans % P;
}
// Read n, m, root r and modulus P, the vertex values and the n-1 tree edges,
// build the decomposition and segment tree, then answer m operations:
//   1 x y z : add z on path x-y        2 x y : print path x-y sum
//   3 x z   : add z on subtree of x    4 x   : print subtree sum of x
// A subtree is the contiguous DFS range [dfn[x], dfn[x] + cntson[x] - 1].
signed main(){
	scanf("%lld%lld%lld%lld",&n,&m,&r,&P);
	for(int v = 1;v <= n;v++){
		scanf("%lld",&a[v]);
		a[v] %= P;
	}
	for(int e = 1;e < n;e++){
		int u,v;
		scanf("%lld%lld",&u,&v);
		add(u,v);
		add(v,u);
	}
	dep[r] = 1;
	dfs(r,0);
	dfs1(r,r,0);
	build(1,n,1); // dfn[r] == 1 after dfs1, so node 1 is the segment-tree root
	int step,x,y,z;
	for(int q = 0;q < m;q++){
		scanf("%lld",&step);
		switch(step){
			case 1:
				scanf("%lld%lld%lld",&x,&y,&z);
				plus1(x,y,z);
				break;
			case 2:
				scanf("%lld%lld",&x,&y);
				printf("%lld\n",found1(x,y) % P);
				break;
			case 3:
				scanf("%lld%lld",&x,&z);
				change(1,dfn[x],dfn[x] + cntson[x] - 1,z);
				break;
			case 4:
				scanf("%lld",&x);
				printf("%lld\n",found(1,dfn[x],dfn[x] + cntson[x] - 1) % P);
				break;
			default:
				break;
		}
	}
	return 0;
}
2025/7/2 16:51
加载中...