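// 样例不过，全 WA，找不到哪里错了，求助大佬。
// (The sample tests fail and every test case gets Wrong Answer; I can't find the bug — asking for help.)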
#include<bits/stdc++.h>
using namespace std;
// Capacity shared by every per-node, per-edge and per-path array in this file.
// NOTE(review): with this value the static arrays add up to on the order of
// 1 GB (the lifting table f alone is maxn * 30 ints); 6e6 + 10 may be a typo
// for 6e5 + 10 — confirm against the problem's memory/size limits.
constexpr int maxn = 6000010;
int n;        // number of tree nodes, read in main
int m;        // number of paths, read in main
int w[maxn];  // w[x]: per-node value read in main (1-indexed)
// One entry of a singly linked list stored in an array: `to` is the payload
// (a neighbor id or a path id) and `nxt` is the index of the next entry in the
// same list; index 0 terminates a list because entries are numbered from 1.
struct node{
    int to, nxt;
};
// e : adjacency lists of the tree.
// e1: for each node, the ids of the paths that end at that node.
// e2: for each node, the ids of the paths whose LCA is that node.
node e[maxn], e1[maxn], e2[maxn];
int head[maxn], cnt;
int head1[maxn], cnt1;
int head2[maxn], cnt2;
void add(int u, int v){
e[++cnt].to = v;
e[cnt].nxt = head[u];
head[u] = cnt;
}
void add1(int u, int v){
e1[++cnt1].to = v;
e1[cnt1].nxt = head1[u];
head1[u] = cnt;
}
void add2(int u, int v){
e2[++cnt2].to = v;
e2[cnt2].nxt = head2[u];
head2[u] = cnt2;
}
// Binary-lifting table: f[u][k] is the 2^k-th ancestor of u (0 above the root).
// Only levels 0..20 are ever written or read (dfs and get_lca loop k = 0..20),
// so 21 columns suffice; any wider second dimension just wastes memory.
int f[maxn][21];
int de[maxn];          // de[u]: depth of u, the root (node 1) has depth 1
int s[maxn], t[maxn];  // s[i], t[i]: start and end node of path i
// Roots the subtree at u, whose parent is fa. It fills in de[u] = de[fa] + 1,
// f[u][0] = fa, and the lifting levels f[u][k] for k = 1..20.
// Entries above the root resolve to 0 because f[0][*] is never written.
void dfs(int u, int fa){
    f[u][0] = fa;
    de[u] = de[fa] + 1;
    for(int k = 1; k <= 20; ++k){
        const int midway = f[u][k - 1];
        f[u][k] = f[midway][k - 1];
    }
    for(int id = head[u]; id != 0; id = e[id].nxt){
        const int child = e[id].to;
        if(child == fa){
            continue;  // don't walk back up to the parent
        }
        dfs(child, u);
    }
}
int get_lca(int u, int v){
if(de[u] < de[v]){
swap(u, v);
}
for(int i = 20 ; i >= 0 ; i--){
if(de[f[u][i]] >= de[v]){
u = f[u][i];
}
}
if(u == v){
return v;
}
for(int i = 20 ; i >= 0 ; i--){
if(f[u][i] != f[v][i]){
u = f[u][i];
v = f[v][i];
}
}
return f[u][0];
}
// b1 / b2: counting buckets for the upward and downward halves of paths (used by solve).
// js[x]  : number of paths that start at node x.
// dist[i]: length of path i, in edges.
// ans[x] : per-node answer, accumulated by solve and main and printed by main.
int b1[maxn], b2[maxn];
int js[maxn], dist[maxn];
int ans[maxn];
// Returns bucket[key], treating keys outside [0, maxn) as empty. Every write in
// solve lands inside the array, so an out-of-range key was never written and an
// empty read is exact. This keeps a large w[x] from indexing past a bucket.
static int bucket_get(const int* bucket, long long key){
    return (key >= 0 && key < maxn) ? bucket[key] : 0;
}

// Offline counting pass over the subtree of x (children are neighbors other
// than the parent f[x][0]). For each node x it adds to ans[x] the paths that
// pass through x under these conditions:
//  * upward half (s -> lca):   de[s] == de[x] + w[x]; b1 is keyed by de[s].
//  * downward half (lca -> t): dist - de[t] == w[x] - de[x]; b2 is keyed by
//    dist - de[t], shifted by B2_SHIFT.
// "Bucket value after the subtree" minus "value before" isolates contributions
// from x's subtree. Paths are removed at their LCA so they don't reach nodes
// above it.
void solve(int x){
    // dist - de[t] and w[x] - de[x] can be negative, so b2 keys are shifted.
    // The shift must leave room on both sides inside b2[maxn]. Shifting by
    // maxn itself would push every non-negative key past the end of b2.
    // Key range is [B2_SHIFT - n, B2_SHIFT + n), which fits because the
    // edge array e already requires 2 * (n - 1) < maxn.
    const long long B2_SHIFT = maxn / 2;
    const long long up_key = static_cast<long long>(de[x]) + w[x];
    const long long down_key = static_cast<long long>(w[x]) - de[x] + B2_SHIFT;
    const int t1 = bucket_get(b1, up_key);
    const int t2 = bucket_get(b2, down_key);
    for(int i = head[x]; i; i = e[i].nxt){
        int v = e[i].to;
        if(v == f[x][0]){
            continue;
        }
        solve(v);
    }
    b1[de[x]] += js[x];  // paths starting at x enter the upward bucket
    for(int i = head1[x]; i; i = e1[i].nxt){
        int v = e1[i].to;  // id of a path ending at x
        b2[dist[v] - de[t[v]] + B2_SHIFT]++;
    }
    ans[x] += bucket_get(b1, up_key) - t1 + bucket_get(b2, down_key) - t2;
    // Paths whose LCA is x must not contribute to any ancestor of x.
    for(int i = head2[x]; i; i = e2[i].nxt){
        int v = e2[i].to;
        b1[de[s[v]]]--;
        b2[dist[v] - de[t[v]] + B2_SHIFT]--;
    }
}
// Reads the input in this order: n and m; the n - 1 tree edges; the per-node
// values w[1..n]; then m paths (s, t). Roots the tree at 1, registers every
// path with its end node and its LCA, runs solve(1), and prints ans[1..n].
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for(int edges_left = n - 1; edges_left > 0; --edges_left){
        int a, b;
        cin >> a >> b;
        add(a, b);
        add(b, a);
    }
    dfs(1, 0);
    for(int x = 1; x <= n; ++x){
        cin >> w[x];
    }
    for(int q = 1; q <= m; ++q){
        cin >> s[q] >> t[q];
        const int top = get_lca(s[q], t[q]);
        dist[q] = de[s[q]] + de[t[q]] - 2 * de[top];
        ++js[s[q]];
        add1(t[q], q);
        add2(top, q);
        // When de[s] - de[lca] == w[lca], the LCA matches both the upward
        // condition and the downward condition, so it would be counted twice.
        // Cancel one of the two counts.
        if(de[s[q]] == de[top] + w[top]){
            --ans[top];
        }
    }
    solve(1);
    for(int x = 1; x <= n; ++x){
        cout << ans[x] << " ";
    }
    cout << endl;
    return 0;
}