求大佬看看树链剖分模板哪里错了
  • 板块灌水区
  • 楼主大雾山上
  • 当前回复20
  • 已保存回复20
  • 发布时间2019/5/15 22:52
  • 上次更新2024/8/9 20:05:22
查看原帖
求大佬看看树链剖分模板哪里错了
118733
大雾山上楼主2019/5/15 22:52
#include<bits/stdc++.h>
using namespace std;
const int M=1000001;
long long n,m,k,tot,cnt,MOD,a[M];
int head[M],size[M],fa[M],dep[M];
int son[M],seg[M],rev[M],top[M];
struct Edge
{
	int next,to;
}edge[1000001];
struct Tree
{
	int l,r;
	long long add,sum;
	#define l(x) tree[x].l
	#define r(x) tree[x].r
	#define add(x) tree[x].add
	#define sum(x) tree[x].sum
}tree[4000001];
void addd(int from,int to)
{
	edge[++tot].next=head[from];
	edge[tot].to=to;
	head[from]=tot;
}
void build(int p,int l,int r)
{
	l(p)=l;
	r(p)=r;
	if(l==r) 
	{
		sum(p)=a[rev[l]];
		return;
	}
	int mid=(l+r)/2;
	build(p*2,l,mid);
	build(p*2+1,mid+1,r);
	sum(p)=sum(p*2)+sum(p*2+1);
}
void spread(int p)
{
	if(add(p))
	{
		sum(p*2)+=(long long)add(p)*(r(p*2)-l(p*2)+1);
		sum(p*2+1)+=(long long)add(p)*(r(p*2+1)-l(p*2+1)+1);
		add(p*2)+=add(p);
		add(p*2+1)+=add(p);
		add(p)=0;
	}
}
void change(int l,int r,int p,int d)
{
	if(l<=l(p)&&r>=r(p))
	{
		sum(p)+=d*(r(p)-l(p)+1);
		add(p)+=d;
		return;
	}
	spread(p);
	int mid=(l(p)+r(p))/2;
	if(l<=mid) change(p*2,l,r,d);
	if(r>mid) change(p*2+1,l,r,d);
	sum(p)=sum(p*2)+sum(p*2+1);
}
long long ask(int l,int r,int p)
{
	if(l<=l(p)&&r>=r(p)) return sum(p);
	spread(p);
	int mid=(l(p)+r(p))/2;
	long long ans=0;
	if(l<=mid) ans+=ask(p*2,l,r);
	if(r>mid) ans+=ask(p*2+1,l,r);
	return ans%MOD;
}
void dfs1(int u,int f)
{
	dep[u]=dep[f]+1;
	fa[u]=f;
	size[u]=1;
	for(int i=head[u];i;i=edge[i].next)
	{
		int v=edge[i].to;
		if(v!=f)
		{
			dfs1(v,u);
			size[u]+=size[v];
			if(size[v]>size[son[u]])
			son[u]=v;
		}
	}
}
void dfs2(int u,int f)
{
	if(son[u])
	{
		seg[son[u]]=++cnt;
		rev[cnt]=son[u];
		top[son[u]]=top[u];
		dfs2(son[u],u);
	}
	for(int i=head[u];i;i=edge[i].next)
	{
		int v=edge[i].to;
		if(!top[v])
		{
			seg[v]=++cnt;
			rev[cnt]=v;
			top[v]=v;
			dfs2(v,u);
		}
	}
}
void get_init(int x,int y,int d)
{
	int fx=top[x],fy=top[y];
	while(fx!=fy)
	{
		if(dep[fx]<dep[fy]) 
		{
			swap(fx,fy);
			swap(x,y);
		}
		change(cnt,seg[fx],seg[fy],d);
		x=fa[x];
		fx=top[x];
	}
	if(dep[x]>dep[y])
	swap(x,y);
	change(cnt,seg[x],seg[y],d);
	return;
}
int get_sum(int x,int y)
{
	int ans=0;
	int fx=top[x],fy=top[y];
	while(fx!=fy)
	{
		if(dep[fx]<dep[fy]) 
		{
			swap(fx,fy);
			swap(x,y);
		}
		ans+=ask(cnt,seg[fx],seg[x]);
		x=fa[x];
		fx=top[x];
	}
	if(dep[x]>dep[y])
	swap(x,y);
	ans+=ask(cnt,seg[x],seg[y]);
	return ans%MOD;
}
int main()
{
	cin>>n>>m>>k>>MOD;
	for(int i=1;i<=n;i++) 
	cin>>a[i];
	for(int i=1;i<=n-1;i++)
	{
		int x,y;
		cin>>x>>y;
		addd(x,y);
		addd(y,x);
	}
	dfs1(k,k);
	cnt=seg[1]=top[1]=rev[1]=1;
	dfs2(k,k);
	build(1,1,cnt);
	for(int i=1;i<=m;i++)
	{
		int op;
		cin>>op;
		if(op==1)
		{
			int x,y,z;
			cin>>x>>y>>z;
			get_init(x,y,z);
		}
		if(op==2)
		{
			int x,y;
			cin>>x>>y;
			cout<<get_sum(x,y)<<endl;
		}
		if(op==3)
		{
			int x,y;
			cin>>x>>y;
			change(1,seg[x],seg[x]+size[x]-1,y);
		}
		if(op==4)
		{
			int x;
			cin>>x;
			cout<<ask(1,seg[x],seg[x]+size[x]-1)<<endl;
		}
	}
	return 0;
}
2019/5/15 22:52
加载中...