RT，已发现问题出在 md 数组（md[u][i] 表示从节点 u 向上 $2^i$ 条边中的最小边权，即 min_dis），但就是不知道错在哪里。
#import "iostream"
#import "vector"
#import "algorithm"
// Inclusive ascending loop; the upper bound c is evaluated once and cached in a##END.
#define For(a, b, c) for (int a = (b), a##END = (c); a <= a##END; ++a)
// Inclusive descending loop; the lower bound c is evaluated once and cached in a##END.
#define Fol(a, b, c) for (int a = (b), a##END = (c); a >= a##END; --a)
// Field aliases for the (node, weight) pairs stored in the adjacency lists.
// NOTE(review): `w` is a common identifier - every later token `w` in this file is
// rewritten to `second` by the preprocessor (e.g. a local `int w` becomes `int second`).
#define id first
#define w second
using namespace std;

constexpr int N = 250007;          // array capacity, equal to 2.5e5 + 7
int n, m, k;                       // node count, query count, size of the current query (incl. root)
int h[N];                          // nodes of the current query; main stores the root 1 in h[1]
vector<pair<int, int> > g[N];      // original tree adjacency: (neighbour, edge weight)
vector<pair<int, int> > g1[N];     // virtual tree of the current query: (child, edge weight)
int tmp[N << 1];                   // candidate virtual-tree nodes (query nodes + adjacent LCAs)
int tot;                           // number of used entries in tmp
int dep[N];                        // depth in the original tree (root = 1, node 0 = 0)
int dfn[N];                        // DFS entry timestamp
int cnt;                           // DFS timestamp counter
int fa[N][20];                     // fa[u][i]: 2^i-th ancestor of u (0 once above the root)
int md[N][20];                     // md[u][i]: minimum edge weight among the 2^i edges above u
// Depth-first walk of the original tree, entered as dfs1(1, 0). For node u with parent f
// it fills dep[u] = dep[f] + 1, the entry timestamp dfn[u], the direct parent fa[u][0],
// and, for every child v, md[v][0] = weight of the edge v - u.
void dfs1(int u, int f) {
    fa[u][0] = f;
    dep[u] = dep[f] + 1;
    dfn[u] = ++cnt;
    for (const auto &edge : g[u]) {
        const int v = edge.first;
        if (v == f) continue;          // do not walk back to the parent
        md[v][0] = edge.second;        // lowest lifting level: the single edge v -> u
        dfs1(v, u);
    }
}
// dp[u]: minimum total weight of edges to cut so that no key node in u's virtual subtree
// remains connected to u. Stored as long long because it is a sum of edge weights over
// the whole virtual tree, which can exceed the range of int.
long long dp[N];
bool key[N];   // key[u] is true while u is one of the current query's nodes
void dfs2(int u) {
    dp[u] = 0;
    for (const auto &son : g1[u]) {
        const int v = son.first;
        // Weight stored on the virtual edge u -> v (intended: the cheapest original-tree
        // edge on the path u -> v).
        const long long cut = son.second;
        dfs2(v);
        // A key child must be separated on this very path; otherwise either cut this path
        // or pay for separating v's own subtree, whichever is cheaper.
        dp[u] += key[v] ? cut : min(dp[v], cut);
    }
}
void init_ST() {
For (i, 1, 19)
For (j, 1, n)
fa[j][i] = fa[fa[j][i - 1]][i - 1],
md[j][i] = min(md[j][i - 1], md[fa[j][i - 1]][i - 1]);
}
// Return {lca(x, y), minimum edge weight on the tree path between x and y}.
// If the path is empty (x == y) the weight is the sentinel 2e9.
// Requires dfs1 and init_ST to have filled dep / fa / md.
pair<int, int> lca_dis(int x, int y) {
    int ret = 2e9;
    // Make x the deeper node. std::swap is well-defined in every standard,
    // unlike `x ^= y ^= x ^= y`, which is undefined behaviour before C++17.
    if (dep[x] < dep[y]) swap(x, y);
    // Lift x to y's depth, folding in the minimum of every block of edges climbed.
    for (int i = 19; i >= 0; --i)
        if (dep[fa[x][i]] >= dep[y]) {
            ret = min(ret, md[x][i]);
            x = fa[x][i];
        }
    if (x == y) return {x, ret};
    // Lift both nodes while their 2^i-th ancestors differ.
    for (int i = 19; i >= 0; --i)
        if (fa[x][i] != fa[y][i]) {
            ret = min({ret, md[x][i], md[y][i]});
            x = fa[x][i];
            y = fa[y][i];
        }
    // x and y are now the two children of the LCA; their edges up to it are also on
    // the path (the original code omitted these two edges).
    ret = min({ret, md[x][0], md[y][0]});
    return {fa[x][0], ret};
}
main() {
__builtin_memset(md, 0x3f, sizeof md);
// ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
cin >> n;
For (i, 2, n) {
int u, v, w; cin >> u >> v >> w;
g[u].emplace_back(v, w),
g[v].emplace_back(u, w);
}
dfs1(1, 0); init_ST();
cin >> m;
For (i, 1, m) {
cin >> k, ++ k;
h[1] = 1;
For (j, 2, k) cin >> h[j], key[h[j]] = 1;
// build:
tot = 0;
sort(h + 1, h + 1 + k, [](int x, int y) { return dfn[x] < dfn[y]; });
For (i, 1, k - 1)
tmp[++ tot] = h[i], tmp[++ tot] = lca_dis(h[i], h[i + 1]).id;
tmp[++ tot] = h[k];
sort(tmp + 1, tmp + 1 + tot, [](int x, int y) { return dfn[x] < dfn[y]; });
tot = unique(tmp + 1, tmp + 1 + tot) - tmp - 1;
For (i, 1, tot - 1) {
const auto lca_ = lca_dis(tmp[i], tmp[i + 1]);
g1[lca_.id].emplace_back(tmp[i + 1], lca_.w);
cerr << "(" << lca_.id << ", " << tmp[i + 1] << ") = " << lca_.w << "\n";
}
// test:
for (int i = 1; i <= n; ++ i) <%
cerr << i << ": ";
for (auto x : g1[i])
cerr << "(" << x.id << ", " << x.w << "), ";
cerr << "\n";
%>
// solve:
dfs2(1);
cout << dp[1] << "\n";
// clear:
For (j, 1, k) key[h[j]] = 0;
For (j, 1, n) g1[j].clear();
}
}
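/*
DP over the virtual tree. For node u and virtual child v, let dis(u, v) be the minimum
edge weight on the original-tree path u -> v:
  v \in KeyNodes:    dp[u] += dis(u, v)                 (v itself must be cut off)
  v \notin KeyNodes: dp[u] += min(dp[v], dis(u, v))     (cut the path, or cut inside v's subtree)
*/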