已知问题可能出现在时间更新上
含有调试
#include<bits/stdc++.h>
#define N 400400
using namespace std;
int n,m,qw,Qt;
vector<int> R[N];
int v[N],w[N],c[N];
int s[N],e[N],t[N],st[N][22],dep[N],cnt=0;
struct point {
int l,r,t,id,lca;
}q[N],p[N];int pl=0,ql=0;
bool cmp(point aa,point bb){
if(aa.l/Qt!=bb.l/Qt) return aa.l<bb.l;
if(aa.r/Qt!=bb.r/Qt) return aa.r<bb.r;
return aa.t<bb.t;
}
int hap[N];
int used[N];
long long answer=0,ans[N];
void add(int pos){
pos=t[pos];
if(used[pos]==0){
hap[c[pos]]++;
answer+=(long long)w[hap[c[pos]]]*v[c[pos]];
}
else {
answer-=(long long)w[hap[c[pos]]]*v[c[pos]];
hap[c[pos]]--;
}
used[pos]++;
}
void del(int pos){
pos=t[pos];
if(used[pos]==1){
answer-=(long long)w[hap[c[pos]]]*v[c[pos]];
hap[c[pos]]--;
}
else {
hap[c[pos]]++;
answer+=(long long)w[hap[c[pos]]]*v[c[pos]];
}
used[pos]--;
}
void upd(int pos,int tm) {
if(used[p[tm].l]==1) {
del(p[tm].l);
swap(c[p[tm].l],p[tm].r);
add(p[tm].l);
}
else swap(c[p[tm].l],p[tm].r);
}
void dfs(int u,int f){
st[u][0]=f,dep[u]=dep[f]+1;
for(int i=1;i<=20;i++){
int p=st[u][i-1];
if(!p) break;
st[u][i]=st[p][i-1];
}
s[u]=++cnt;
for(int i=0;i<R[u].size();i++){
int to=R[u][i];
if(to==f) continue;
dfs(to,u);
}
e[u]=++cnt;
t[e[u]]=t[s[u]]=u;
}
int LCA(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
for(int i=20;i>=0;i--) if(dep[u]-(1<<i)>=dep[v]) u=st[u][i];
if(u==v) return u;
for(int i=20;i>=0;i--) if(st[u][i]&&st[v][i]&&st[u][i]!=st[v][i])
u=st[u][i],v=st[v][i];
return st[u][0];
}
int main(){
cin>>n>>m>>qw;
Qt=pow(n,0.666);
for(int i=1;i<=m;i++) scanf("%d",&v[i]);
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
R[u].push_back(v),R[v].push_back(u);
}
for(int i=1;i<=n;i++) scanf("%d",&c[i]);
dfs(1,cnt=0);
for(int i=1;i<=qw;i++){
int ty,u,v;
scanf("%d%d%d",&ty,&u,&v);
if(ty==0) {
p[++pl]=(point){u,v,0,0,0};
} else {
q[++ql]=(point){u,v,pl,ql,0};
}
}
for(int i=1;i<=ql;i++){
q[i].lca=LCA(q[i].l,q[i].r);
// printf("%d %d => %d \n",q[i].l,q[i].r,q[i].lca);
int u=q[i].l,v=q[i].r;
if(s[u]>s[v]) swap(u,v);
u=e[u],v=s[v];
q[i].l=u,q[i].r=v;
}
sort(q+1,q+1+ql,cmp);
int l=1,r=0,tm=0;
for(int i=1;i<=ql;i++)
{
printf("### %d(%d) %d(%d) %d ###\n",q[i].l,t[q[i].l],q[i].r,t[q[i].r],q[i].t);
while(l>q[i].l) add(--l);printf("ans:%d \n",answer);
while(r<q[i].r) add(++r);printf("ans:%d \n",answer);
while(l<q[i].l) del(l++);printf("ans:%d \n",answer);
while(r>q[i].r) del(r--);printf("ans:%d \n",answer);
while(tm>q[i].t) upd(i,tm--);printf("ans:%d \n",answer);
while(r<q[i].t) upd(i,++tm);printf("ans:%d \n",answer);
add(s[q[i].lca]);
ans[q[i].id]=answer;
printf("ANSWER:%d \n",answer);
del(s[q[i].lca]);
puts("****");
}
for(int i=1;i<=n;i++) printf("%d\n",ans[i]);
return 0;
}