萌新求助：树剖 WA 了八个点，看了半天没看出来。
#include<bits/stdc++.h>
#define SIZE 300000
#define lson (p<<1)
#define rson (p<<1|1)
#define mid (l+r>>1)
using namespace std;
// n: number of tree nodes; m: number of operations (both read in main).
int n,m;
// Forward-star adjacency list filled by Add(). Each undirected edge is added
// twice, so edge indices run 1..2(n-1) and SIZE must exceed that.
int Head[SIZE],Next[SIZE],Ver[SIZE];
// Filled by dfs1: parent, depth (root has depth 1), subtree size, heavy child
// (0 when the node has no children).
int fa[SIZE],Deep[SIZE],Size[SIZE],Son[SIZE];
// Edge counter used by Add(); 0 means "no edge", so edges start at index 1.
int tot=0;
// Filled by dfs2: the topmost node of the heavy chain containing each node.
int Top[SIZE];
// DFS-order counter used by dfs2.
int cnt;
// id[x]: 1-based DFS-order position of node x (the segment-tree leaf index).
int id[SIZE];
// wt[x]: weight of node x (indexed by node number);
// w[i]: weight of the node at DFS position i (the array build() reads).
int wt[SIZE],w[SIZE];
// Segment-tree arrays: node p has children p<<1 and p<<1|1.
// NOTE(review): the recursive layout needs roughly 4*n slots; confirm SIZE covers the max n.
int Max[SIZE],Sum[SIZE];
// Accumulators written by query(): res += segment sums, ves = running max.
// Callers must reset them before each query() call.
int res,ves;
// Appends the directed edge x -> y to the front of x's adjacency list.
void Add(int x,int y){
    const int e=++tot;  // edge indices start at 1; 0 terminates a list
    Ver[e]=y;
    Next[e]=Head[x];
    Head[x]=e;
}
// First HLD pass: records the parent, depth and subtree size of every node
// below x. It also picks the heavy child: the child with the largest subtree,
// where the first one seen wins ties.
// x: current node, f: its parent (0 for the root), deep: depth of x.
void dfs1(int x,int f,int deep){
    fa[x]=f;
    Deep[x]=deep;
    Size[x]=1;
    int bestSize=-1;
    for(int e=Head[x];e!=0;e=Next[e]){
        const int child=Ver[e];
        if(child!=f){
            dfs1(child,x,deep+1);
            Size[x]+=Size[child];
            if(bestSize<Size[child]){
                bestSize=Size[child];
                Son[x]=child;
            }
        }
    }
}
// Second HLD pass: assigns DFS positions, visiting the heavy child first so
// that every heavy chain occupies a contiguous range of positions.
// It also copies node weights into DFS order (w) and records chain tops.
// x: current node, topf: top node of the chain that x belongs to.
void dfs2(int x,int topf){
    Top[x]=topf;
    id[x]=++cnt;
    w[id[x]]=wt[x];
    if(Son[x]==0) return;  // leaf: nothing further to number
    dfs2(Son[x],topf);     // heavy child continues the current chain
    for(int e=Head[x];e;e=Next[e]){
        const int child=Ver[e];
        if(child!=fa[x]&&child!=Son[x]) dfs2(child,child);  // light child starts a new chain
    }
}
// Builds the sum/max segment tree over the DFS-ordered weights w[l..r].
// p is the tree node covering [l, r].
void build(int p,int l,int r){
    if(l!=r){
        const int md=(l+r)>>1;
        build(p<<1,l,md);
        build(p<<1|1,md+1,r);
        Sum[p]=Sum[p<<1]+Sum[p<<1|1];
        Max[p]=max(Max[p<<1],Max[p<<1|1]);
        return;
    }
    Sum[p]=Max[p]=w[l];
}
// Point assignment: sets the leaf at DFS position pos to Val, then recomputes
// sum/max on the way back up. p covers [l, r]; does nothing if pos is outside it.
void change(int p,int l,int r,int Val,int pos){
    if(pos<l||r<pos) return;
    if(l==r){  // l <= pos <= r with l == r means this is the target leaf
        Sum[p]=Max[p]=Val;
        return;
    }
    const int md=(l+r)>>1;
    if(pos<=md) change(p<<1,l,md,Val,pos);
    else change(p<<1|1,md+1,r,Val,pos);
    Sum[p]=Sum[p<<1]+Sum[p<<1|1];
    Max[p]=max(Max[p<<1],Max[p<<1|1]);
}
// Range query over DFS positions [L, R]. It adds the range sum into the global
// res and folds the range max into the global ves; the caller initializes both.
// p covers [l, r].
void query(int p,int l,int r,int L,int R){
    if(L>l||r>R){  // [l, r] only partially overlaps [L, R]: recurse
        const int md=(l+r)>>1;
        if(L<=md) query(p<<1,l,md,L,R);
        if(md<R) query(p<<1|1,md+1,r,L,R);
        return;
    }
    res+=Sum[p];
    ves=max(ves,Max[p]);
}
// Returns the sum of node weights on the tree path x..y (both inclusive).
// At each step it jumps from the node whose chain top is deeper, summing that
// chain segment, until both nodes are on the same chain.
int QSUM(int x,int y){
    int total=0;
    for(;Top[x]!=Top[y];x=fa[Top[x]]){
        res=0;
        if(Deep[Top[x]]<Deep[Top[y]]) swap(x,y);
        query(1,1,n,id[Top[x]],id[x]);
        total+=res;
    }
    // Same chain now: the shallower node has the smaller DFS position.
    if(Deep[y]<Deep[x]) swap(x,y);
    res=0;
    query(1,1,n,id[x],id[y]);
    return total+res;
}
// Returns the maximum node weight on the tree path x..y (both inclusive).
// It uses the same chain-jumping scheme as the path-sum query.
int QMAX(int x,int y){
    const int NEG=-300000;  // sentinel below any node weight this program expects
    int best=NEG;
    for(;Top[x]!=Top[y];x=fa[Top[x]]){
        if(Deep[Top[x]]<Deep[Top[y]]) swap(x,y);
        ves=NEG;
        query(1,1,n,id[Top[x]],id[x]);
        best=max(best,ves);
    }
    if(Deep[y]<Deep[x]) swap(x,y);
    ves=NEG;
    query(1,1,n,id[x],id[y]);
    return max(best,ves);
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n-1;i++){
int a,b;
scanf("%d%d",&a,&b);
Add(a,b);
Add(b,a);
}
for(int i=1;i<=n;i++) scanf("%d",&wt[i]);
scanf("%d",&m);
dfs1(1,0,1);
dfs2(1,1);
build(1,1,n);
while(m--){
string order;
int x,y;
cin>>order;
scanf("%d%d",&x,&y);
if(order=="CHANGE") change(1,1,n,y,x);
if(order=="QMAX") printf("%d\n",QMAX(x,y));
if(order=="QSUM") printf("%d\n",QSUM(x,y));
}
return 0;
}