蒟蒻求助 求调代码
查看原帖
蒟蒻求助 求调代码
1206518
Nervegas楼主2024/9/10 12:44
#include<bits/stdc++.h>
using namespace std;
// Capacity of every global array: node ids, directed tree edges (2(n-1) of them)
// and path ids must all stay below it.
constexpr int maxn = 5000000 + 10;

int n;        // number of tree nodes
int m;        // number of paths
int w[maxn];  // w[j]: per-node value read from input; solve() compares it with a path's step count at node j

// Forward-star ("chained") list record, used by three independent lists:
//   e  - tree adjacency (main() inserts both directions of every edge),
//   e1 - for each node v, the ids of the paths that end at v,
//   e2 - for each node v, the ids of the paths whose LCA is v.
// `to` is the payload (neighbour node or path id); `nxt` is the index of the next
// record of the same list, with 0 terminating the list.
struct node {
	int to, nxt;
} e[maxn], e1[maxn], e2[maxn];

// For each of the lists e / e1 / e2: head[v] is the index of v's first record
// (0 = empty list) and cnt is the index of the last record handed out.
int head[maxn] = {};
int cnt = 0;
int head1[maxn] = {};
int cnt1 = 0;
int head2[maxn] = {};
int cnt2 = 0;

void add(int u, int v){
	e[++cnt].to = v;
	e[cnt].nxt = head[u];
	head[u] = cnt;
}

void add1(int u, int v){
	e1[++cnt1].to = v;
	e1[cnt1].nxt = head1[u];
	head1[u] = cnt;
}

void add2(int u, int v){
	e2[++cnt2].to = v;
	e2[cnt2].nxt = head2[u];
	head2[u] = cnt2;
}


int f[maxn][30];  // binary lifting: f[u][k] = 2^k-th ancestor of u, 0 past the root (dfs() fills k = 0..27)
int de[maxn];     // depth of each node (root = 1); de[0] is expected to stay 0 as the root's sentinel parent
int s[maxn];      // s[i]: start node of path i
int t[maxn];      // t[i]: end node of path i

// First DFS: visits u with parent fa (0 for the root) and records its depth.
// It fills the binary-lifting row f[u][0..27] from the rows of u's ancestors,
// which are already complete, then recurses into every neighbour except fa.
// Recursive: the stack depth equals the tree height.
void dfs(int u, int fa){
	f[u][0] = fa;
	de[u] = de[fa] + 1;
	for (int k = 1; k <= 27; ++k) {
		const int half = f[u][k - 1];  // 2^(k-1)-th ancestor
		f[u][k] = f[half][k - 1];
	}
	for (int i = head[u]; i != 0; i = e[i].nxt) {
		const int child = e[i].to;
		if (child == fa) continue;
		dfs(child, u);
	}
}

// Lowest common ancestor of u and v, using the binary-lifting table f built by dfs().
// Step 1 lifts the deeper node to the other node's depth.
// Step 2 lifts both nodes together while their ancestors still differ.
// The answer is then the parent of the pair.
int lca(int u, int v){
	if (de[v] > de[u]) std::swap(u, v);  // make u the deeper endpoint
	for (int k = 27; k >= 0; --k) {
		const int up = f[u][k];
		if (de[up] >= de[v]) u = up;
	}
	if (u == v) return v;  // one endpoint is an ancestor of the other
	for (int k = 27; k >= 0; --k) {
		const int pu = f[u][k];
		const int pv = f[v][k];
		if (pu != pv) {
			u = pu;
			v = pv;
		}
	}
	return f[u][0];
}

// Buckets for the counting pass in solve():
//   b1 is keyed by the start depth de[s] of upward legs and read at de[x] + w[x].
//   b2 is keyed by dist - de[t] + maxn for downward legs and read at w[x] - de[x] + maxn.
// Because b2 is offset by +maxn, every non-negative key lands at an index >= maxn,
// and b1's read index exceeds maxn for large w. Both buckets therefore need 2 * maxn
// slots; with only maxn slots, b2 wrote past its end into neighbouring arrays.
// (Assumes 0 <= w[x] <= n, the problem's bound; this is not checked here.)
// js[v]: number of paths that start at node v.
// dist[i]: number of edges on path i.
// ans[v]: result for node v.
int b1[maxn << 1], b2[maxn << 1], js[maxn], dist[maxn], ans[maxn];
 

void solve(int x){
	int t1 = b1[de[x] + w[x]];
	int t2 = b2[w[x] - de[x] + maxn];
	for(int i = head[x] ; i ; i = e[i].nxt){
		int v = e[i].to;
		if(v == f[x][0]){
			continue;
		} 
		solve(v);
	}
	b1[de[x]] += js[x];//上行产生贡献;
	for(int i = head1[x] ; i ; i = e1[i].nxt){
		int v = e1[i].to;
		b2[dist[v] - de[t[v]] + maxn]++;
	} 
	ans[x] += b1[w[x] + de[x]] - t1 + b2[w[x] - de[x] + maxn] - t2;
	for(int i = head2[x] ; i ; i = e2[i].nxt){
		int v = e2[i].to;
		b1[de[s[v]]]--;
		b2[dist[v] - de[t[v]] + maxn]--;
	}
}

// Reads the tree (n nodes, n-1 edges), the per-node values w and m paths (s, t).
// Prepares per-path data, runs solve() from the root 1 and prints ans[1..n],
// space-separated.
int main(){
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);

	std::cin >> n >> m;
	for (int k = 1; k < n; ++k) {
		int a, b;
		std::cin >> a >> b;
		add(a, b);
		add(b, a);
	}
	for (int j = 1; j <= n; ++j) std::cin >> w[j];

	dfs(1, 0);  // root the tree at node 1

	for (int p = 1; p <= m; ++p) {
		std::cin >> s[p] >> t[p];
		const int top = lca(s[p], t[p]);
		dist[p] = de[s[p]] + de[t[p]] - 2 * de[top];
		++js[s[p]];
		add1(t[p], p);
		add2(top, p);
		// The LCA lies on both legs, and both bucket tests reduce to this same condition
		// there, so solve() counts it twice; pre-subtract one.
		if (de[top] + w[top] == de[s[p]]) --ans[top];
	}

	solve(1);

	for (int j = 1; j <= n; ++j) std::cout << ans[j] << " ";
	std::cout << std::endl;
	return 0;
}
2024/9/10 12:44
加载中...