65分求助
查看原帖
65分求助
235696
muvum楼主2022/1/27 15:10
#include <vector>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

const int MAXN = 3e5;

struct Edge {
    int to, nxt;
};

int n, m, tot;
Edge e[MAXN<<1];
// dis[i] = dep[i]-2*dep[LCA(s[i],t[i])]
int s[MAXN], w[MAXN], dis[MAXN], ans[MAXN];
// end[i]表示以i号节点为终点的玩家编号,lca[i]表示以节点i为起点和终点的LCA的玩家编号
std::vector<int> end[MAXN], lca[MAXN];
// cnts[i]是以节点i为起点的玩家数量, b1、b2分别是记录起点和终点贡献的桶
int head[MAXN], cnts[MAXN], b1[MAXN<<1], b2[MAXN<<1], dep[MAXN], fa[MAXN][20];

inline void addedge(int u, int v) {
    e[++tot].nxt = head[u];
    e[tot].to = v;
    head[u] = tot;
}

// 建树
void dfs1(int u) {
    for (int i=1; (1<<i)<=dep[u]; ++i)
        fa[u][i] = fa[fa[u][i-1]][i-1];

    for (int i=head[u]; i; i=e[i].nxt) {
        int v = e[i].to;
        if (v == fa[u][0]) continue;
        dep[v] = dep[u] + 1;
        fa[v][0] = u; dfs1(v);
    }
}

inline int LCA(int x, int y) {
    if (dep[x] < dep[y]) std::swap(x, y);
    int k = 0, dif = dep[x] - dep[y];
    while (dif > 0) {
        if (dif & 1) x = fa[x][k];
        k++; dif >>= 1;
    }
    if (x == y) return x;

    for (int i=18; i>=0; --i)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

void dfs2(int u) {
    int t1 = b1[w[u]+dep[u]], t2 = b2[w[u]-dep[u]+MAXN];
    for (int i=head[u]; i; i=e[i].nxt) {
        int v = e[i].to;
        if (v == fa[u][0]) continue;
        dfs2(v);
    }

    // 将以u为起点的贡献装入桶中
    b1[dep[u]] += cnts[u];
    // 将以u为终点的贡献装入桶中
    for (auto i : end[u])
        b2[dis[i]+MAXN]++;

    if (w[u] == 0) ans[u] = cnts[u];
    else ans[u] = b1[w[u]+dep[u]] - t1 + b2[w[u]-dep[u]+MAXN] - t2;

    // 若当前节点是某玩家路径的“拐点”,则再往上不受此玩家影响,删掉贡献
    for (auto i : lca[u])
        b1[dep[s[i]]]--, b2[dis[i]+MAXN]--;
}

int main(void) {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr); std::cout.tie(nullptr);

    std::cin >> n >> m;
    for (int i=1; i<n; ++i) {
        int u, v; std::cin >> u >> v;
        addedge(u, v); addedge(v, u);
    }
    for (int i=1; i<=n; ++i) std::cin >> w[i];
    dfs1(1);
    for (int i=1; i<=m; ++i) {
        int t; std::cin >> s[i] >> t;
        cnts[s[i]]++;
        int c = LCA(s[i], t);
        end[t].push_back(i);
        lca[c].push_back(i);
        dis[i] = dep[s[i]] - 2*dep[c];
    }

    dfs2(1);

    for (int i=1; i<=n; ++i)
        std::cout << ans[i] << ' ';

    return 0;
}
2022/1/27 15:10
加载中...