dfs 部分已经检查过，应该没有问题。
树状数组的区间修改、区间查询，以及树链剖分，都是我第一次写，
问题应该出在这两部分。
谢谢各位大佬！
#include <iostream>
#include <cstdio>
using namespace std;
constexpr int N = 3000000;  // capacity of every per-node and per-edge array
int n, m, r, p;             // node count, operation count, root node, modulus (read in main)
int w[N];                   // initial value of node i, stored at w[i] for i = 1..n
struct edge
{
int to,next;
}e[N];
int cnt,head[N];
void add(int x,int y)
{
cnt++;
e[cnt].to=y;
e[cnt].next=head[x];
head[x]=cnt;
}
///树状数组部分
int c1[N],c2[N];
int updata(int x,int k)
{
int xx=x;
while(x<=n)
{
c1[x]+=k;
c2[x]+=k*(xx-1);
x+=x&(-x);
}
}
int getsum(int x)
{
int ans=0,xx=x;
while(x)
{
ans+=xx*c1[x]-c2[x];
x-=x&(-x);
}
return ans;
}
int up(int x,int y,int k)
{
if(x>y)swap(x,y);
updata(x,k);
updata(y+1,-k);
}
int get(int x,int y)
{
if(x>y)swap(x,y);
return getsum(y) - getsum(x-1);
}
/// DFS part

// Per-node heavy-light decomposition data, indexed by node id:
//   son - heavy child (0 if the node has no children)
//   siz - subtree size
//   top - topmost node of the heavy chain containing the node
//   dep - depth (main starts the root at depth 1)
//   fa  - parent (main passes 0 for the root)
//   dfn - position in the DFS order produced by dfs2
int son[N], siz[N], top[N], dep[N], fa[N], dfn[N];
int dfss = 0;  // last dfn value handed out by dfs2

// First DFS: records parent, depth and subtree size for every node under u,
// and picks each node's heavy son (child with the largest subtree; on ties
// the first such child in adjacency-list order wins).
void dfs1(int u, int f, int de)
{
    fa[u] = f;
    dep[u] = de;
    siz[u] = 1;
    int best = -1;  // size of the heaviest child seen so far
    for (int id = head[u]; id != 0; id = e[id].next)
    {
        const int child = e[id].to;
        if (child == f)
            continue;  // skip the edge back to the parent
        dfs1(child, u, de + 1);
        siz[u] += siz[child];
        if (siz[child] > best)
        {
            best = siz[child];
            son[u] = child;
        }
    }
}
void dfs2(int u,int tf)
{
dfn[u]=++dfss;
c1[dfss]=w[u];
top[u]=tf;
if(son[u]==0)return ;
dfs2(son[u],tf);
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].to;
if(dfn[v]==0)
{
dfs2(v,v);
}
}
}
/// Heavy-light decomposition queries

// Add k (reduced modulo p) to every node on the path between u and v.
// Declared void: the previous `int` return type never returned a value,
// which is undefined behaviour in C++.
void line_add(int u,int v,int k)
{
    k%=p;
    while(top[u]!=top[v])
    {
        // Move the endpoint whose chain top is deeper, covering its whole chain segment.
        if(dep[top[u]]<dep[top[v]])
        {
            swap(u,v);
        }
        up(dfn[top[u]],dfn[u],k);
        u=fa[top[u]];
    }
    // Both endpoints are on one chain now; the shallower one has the smaller dfn.
    if(dep[u]>dep[v])
    {
        swap(u,v);
    }
    up(dfn[u],dfn[v],k);
}
// Sum of node values on the path between u and v, modulo p, returned in [0, p).
int line_find(int u,int v)
{
    // long long: two residues just below p (p may be close to INT_MAX)
    // overflow int when added together.
    long long ans=0;
    while(top[u]!=top[v])
    {
        // Move the endpoint whose chain top is deeper, summing its whole chain segment.
        if(dep[top[u]]<dep[top[v]])
        {
            swap(u,v);
        }
        ans=(ans+get(dfn[top[u]],dfn[u]))%p;
        u=fa[top[u]];
    }
    // Both endpoints are on one chain now; the shallower one has the smaller dfn.
    if(dep[u]>dep[v])
    {
        swap(u,v);
    }
    ans=(ans+get(dfn[u],dfn[v]))%p;
    if(ans<0) ans+=p;   // defensive: stay in [0, p) even if get() yields a negative value
    return static_cast<int>(ans);
}
int tree_find(int u)
{
return get(dfn[u],dfn[u]+siz[u]-1);
}
// Add k (reduced modulo p) to every node in u's subtree.
// The subtree occupies the contiguous dfn range [dfn[u], dfn[u] + siz[u] - 1].
void tree_add(int u,int k)
{
    const int first=dfn[u];
    const int last=first+siz[u]-1;
    up(first,last,k%p);
}
///main部分
int main()
{
int i,j,k;
cin>>n>>m>>r>>p;
for(i=1;i<=n;i++)
{
cin>>w[i];
}
for(i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
add(x,y);
add(y,x);
}
dfs1(r,0,1);
/*
for(i=1;i<=n;i++)
{
printf("siz%d son%d fa%d dep%d\n",siz[i],son[i],fa[i],dep[i]);
}
*/
dfs2(r,r);
/*
for(i=1;i<=n;i++)
{
printf("top%d dfn%d\n",top[i],dfn[i]);
}
*/
for(i=1;i<=m;i++)
{
int b;
cin>>b;
if(b==1)
{
int x,y,z;
cin>>x>>y>>z;
line_add(x,y,z);
}
if(b==2)
{
int x,y;
cin>>x>>y;
cout<<line_find(x,y)<<endl;
}
if(b==3)
{
int x,y;
cin>>x>>y;
tree_add(x,y);
}
if(b==4)
{
int x;
cin>>x;
cout<<tree_find(x)<<endl;
}
}
return 0;
}