把线段树换到模板题上正确,所以应该不是线段树的锅。
在错误数据下各组树上信息与从前写的满分代码相同,所以应该不是 dfs 的锅。
求助
#include"iostream"
#include"cstdio"
#include"cmath"
using namespace std;
#define MAXN 100005
#define read(x) scanf("%d",&x)
#define ll long long
int n,m;
int u,v;
int ty,x,y,z;
int p,root;
struct node
{
int to,nxt;
}e[MAXN<<1];
int head[MAXN],cnt=0;
int tot[MAXN],f[MAXN],top[MAXN],dep[MAXN];
int son[MAXN],id[MAXN],rt=0;
struct Tree
{
int l,r;
ll sum,lazy;
}a[MAXN<<2];
int b[MAXN],t[MAXN];
void add(int u,int v){e[++cnt].to=v,e[cnt].nxt=head[u],head[u]=cnt;}
int dfs1(int cur,int fa)
{
int maxn=0;
tot[cur]=1,f[cur]=fa,dep[cur]=dep[f[cur]]+1;
for(int i=head[cur];i;i=e[i].nxt)
{
int j=e[i].to;
if(j==fa) continue;
int op=dfs1(j,cur);
if(op>maxn) son[cur]=j,maxn=op;
tot[cur]+=op;
}
return tot[cur];
}
void dfs2(int cur,int topf)
{
top[cur]=topf,id[cur]=++rt;
if(son[cur]) dfs2(son[cur],topf);
for(int i=head[cur];i;i=e[i].nxt)
{
int j=e[i].to;
if(j==f[cur]||j==son[cur]) continue;
dfs2(j,j);
}
return;
}
inline void update(int k){a[k].sum=a[k<<1].sum+a[k<<1|1].sum;}
void build(int k,int l,int r)
{
a[k].l=l,a[k].r=r;
if(l==r){a[k].sum=1ll*t[l];return;}
int mid=(l+r)>>1;
build(k<<1,l,mid),build(k<<1|1,mid+1,r);
update(k);
}
void lazydown(int k)
{
if(a[k].l==a[k].r){a[k].lazy=0;return;}
a[k<<1].sum=(a[k<<1].sum+1ll*(a[k<<1].r-a[k<<1].l+1)%p*a[k].lazy%p)%p;
a[k<<1|1].sum=(a[k<<1|1].sum+1ll*(a[k<<1|1].r-a[k<<1|1].l+1)%p*a[k].lazy%p)%p;
a[k<<1].lazy=(a[k<<1].lazy+a[k].lazy)%p;
a[k<<1|1].lazy=(a[k<<1|1].lazy+a[k].lazy)%p;
a[k].lazy=0;
}
void modify(int k,int l,int r,int x)
{
if(a[k].l==l&&a[k].r==r)
{
a[k].sum=(a[k].sum+1ll*(a[k].r-a[k].l+1)%p*(ll)x%p)%p;
a[k].lazy=(a[k].lazy+(ll)x)%p;
return;
}
if(a[k].lazy) lazydown(k);
int mid=a[k<<1].r;
if(r<=mid) modify(k<<1,l,r,x);
else if(l>mid) modify(k<<1|1,l,r,x);
else modify(k<<1,l,mid,x),modify(k<<1|1,mid+1,r,x);
update(k);
}
int query(int k,int l,int r)
{
if(a[k].l==l&&a[k].r==r) return a[k].sum%p;
if(a[k].lazy) lazydown(k);
int mid=a[k<<1].r;
if(r<=mid) return query(k<<1,l,r);
else if(l>mid) return query(k<<1|1,l,r);
else return (query(k<<1,l,mid)+query(k<<1|1,mid+1,r))%p;
}
void mod_chain(int l,int r,int x)
{
while(top[l]!=top[r])
{
if(dep[top[l]]<dep[top[r]]) swap(l,r);
modify(1,id[top[l]],id[l],x);
l=f[top[l]];
}
if(dep[l]>dep[r]) swap(l,r);
modify(1,id[l],id[r],x);
}
int que_chain(int l,int r)
{
ll ans=0;
while(top[l]!=top[r])
{
if(dep[top[l]]<dep[top[r]]) swap(l,r);
ans=(ans+(ll)query(1,id[top[l]],id[l]))%p;
l=f[top[l]];
}
if(dep[l]>dep[r]) swap(l,r);
ans=(ans+(ll)query(1,id[l],id[r]))%p;
return ans%p;
}
void mod_son(int l,int y){modify(1,id[l],id[l]+tot[l]-1,x);}
int que_son(int l){return query(1,id[l],id[l]+tot[l]-1)%p;}
int main()
{
read(n),read(m),read(root),read(p);
for(int i=1;i<=n;i++) read(b[i]);
for(int i=1;i<n;i++) read(u),read(v),add(u,v),add(v,u);
dfs1(root,0);
dfs2(root,root);
for(int i=1;i<=n;i++) t[id[i]]=b[i];
build(1,1,n);
for(int i=1;i<=m;i++)
{
read(ty);
if(ty==1) read(x),read(y),read(z),mod_chain(x,y,z%p);
else if(ty==2) read(x),read(y),printf("%d\n",que_chain(x,y));
else if(ty==3) read(x),read(y),mod_son(x,y%p);
else read(x),printf("%d\n",que_son(x));
}
return 0;
}