这样写就对了:
// Working version: build the BST for chain positions frt[l..r] and return its
// root (0 for an empty range). The children are built and linked in separate
// statements, so no value is ever read back after a recursive call has run.
int build(int l, int r) {
  if (l > r) return 0;
  int total = 0;
  for (int i = l; i <= r; ++i) total += lsz[frt[i]];
  // Root = first position whose lsz prefix sum reaches half of the total.
  // It defaults to r, so the function can never fall off its end.
  int mid = l, prefix = lsz[frt[l]];
  while (mid < r && 2 * prefix < total) prefix += lsz[frt[++mid]];
  const int t = frt[mid];
  ls[t] = build(l, mid - 1);
  rs[t] = build(mid + 1, r);
  fa[ls[t]] = t;
  fa[rs[t]] = t;
  fa[0] = 0;  // 0 means "no child"; undo the write above when a side is empty
  val[t].init(t);
  update(t);
  return t;
}
但是这样写就错了:
// Counter-example (kept as posted): same BST construction as the working
// version, but both children are linked in one chained assignment.
int build(int l, int r) {
if (l > r) return 0;
int sum1 = 0, sum2 = 0;
for (int i = l; i <= r; ++i) sum1 += lsz[frt[i]];
for (int i = l, t; i <= r; ++i)
if (((sum2 += lsz[t = frt[i]]) << 1) >= sum1) {
// Why the next line misbehaves: the value stored into fa[ls[t]] is obtained
// by reading back fa[rs[t]] (the result of the inner assignment). That read
// is not ordered relative to the call build(l, i - 1) in the left operand
// (certainly not before C++17), so a compiler may effectively do:
//   1. rs[t] = build(i + 1, r); fa[rs[t]] = t;
//   2. ls[t] = build(l, i - 1);   // a non-empty recursive build ends with fa[0] = 0
//   3. fa[ls[t]] = <re-read of fa[rs[t]]>;
// When t is the last element (i == r), rs[t] == 0. Step 1 writes fa[0] = t,
// step 2 resets fa[0] to 0, and step 3 stores 0 into fa[ls[t]] instead of t.
// That matches the observed "ls[t] != 0 but fa[ls[t]] != t". Later standards
// tightened the ordering of '=', but relying on it is fragile: build and link
// the children in separate statements instead.
fa[ls[t] = build(l, i - 1)] = fa[rs[t] = build(i + 1, r)] = t;
//***
fa[0] = 0, val[t].init(t), update(t);
return t;
}
}
第二种写法会导致 `//***` 这个位置出现 `fa[ls[t]] != t && ls[t]`, 也就是说: 在 `ls[t]` 不为 0 的情况下, `fa[ls[t]]` 不等于 `t`.
???好玄学啊! 有大佬能带带我吗?
完整代码:
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
using namespace std;
int read();

int n, m;  // vertex count and number of queries (read in main)

// Per-vertex arrays, 1-indexed. Index 0 is never a vertex and stays 0, so it
// doubles as the "no heavy child" sentinel (e.g. sz[0], f[0][*]).
int w[1000006];    // vertex weight (main clamps negatives to 0)
int fa[1000006];   // tree parent, filled by dfs1 (0 for the root)
int sz[1000006];   // subtree size
int lsz[1000006];  // sz[u] - sz[sn[u]]: the weight used to balance the chain BSTs
int sn[1000006];   // heavy child (0 for a leaf)
int f[1000006][2];  // f[u][0]/f[u][1]: best subtree value with u not taken / taken
int g[1000006][2];  // same, but counting only u itself and its light children
vector<int> e[1000006];  // adjacency lists of the tree

// Insert the undirected edge u-v into both adjacency lists.
void add(int u, int v) {
  e[u].push_back(v);
  e[v].push_back(u);
}
// Bottom-up pass over the tree rooted at u. It fills fa, sz, sn and lsz.
// sn is the heavy child: the first child, in adjacency order, with maximum
// subtree size. lsz = sz - sz[heavy child]. The DP arrays are:
//   g[x][0] = sum over light children v of max(f[v][0], f[v][1])
//   g[x][1] = w[x] + sum over light children v of f[v][0]
//   f[x][0] = max(f[sn][0], f[sn][1]) + g[x][0]   (x not taken)
//   f[x][1] = f[sn][0] + g[x][1]                  (x taken)
// sn == 0 for a leaf relies on sz[0] == f[0][*] == 0.
// Iterative on purpose: a path of ~10^6 vertices would otherwise recurse
// 10^6 frames deep and can overflow the default call stack.
void dfs1(int u) {
  // Pass 1: pre-order with an explicit stack, recording parents.
  vector<int> order, stk(1, u);
  while (!stk.empty()) {
    int x = stk.back();
    stk.pop_back();
    order.push_back(x);
    for (int v : e[x])
      if (v != fa[x]) fa[v] = x, stk.push_back(v);
  }
  // Pass 2: reverse pre-order visits every child before its parent.
  for (auto it = order.rbegin(); it != order.rend(); ++it) {
    int x = *it;
    sz[x] = 1;
    for (int v : e[x])
      if (v != fa[x]) {
        sz[x] += sz[v];
        if (sz[sn[x]] < sz[v]) sn[x] = v;  // strict <: ties keep the earlier child
      }
    lsz[x] = sz[x] - sz[sn[x]];
    for (int v : e[x])
      if (v != fa[x] && v != sn[x])
        g[x][0] += max(f[v][0], f[v][1]), g[x][1] += f[v][0];
    f[x][0] = max(f[sn[x]][0], f[sn[x]][1]) + g[x][0];
    g[x][1] += w[x], f[x][1] = f[sn[x]][0] + g[x][1];
  }
}
// 2x2 matrix whose "+" is the (max, +) product; used as a DP transfer matrix.
struct Mat {
  int a[2][2];

  // Matrix of vertex u built from its light sums g[u]:
  //   [ g[u][0]  g[u][0] ]
  //   [ g[u][1]   -inf   ]
  void init(int u) {
    a[0][0] = g[u][0];
    a[0][1] = g[u][0];
    a[1][0] = g[u][1];
    a[1][1] = -inf;
  }

  // Row access so callers can write m[i][j].
  int* operator[](int p) { return a[p]; }

  // (max, +) product: result[i][j] = max over k of (a[i][k] + b[k][j]).
  Mat operator+(Mat b) {
    Mat prod;
    for (int row = 0; row < 2; ++row) {
      for (int col = 0; col < 2; ++col) {
        const int viaZero = a[row][0] + b.a[0][col];
        const int viaOne = a[row][1] + b.a[1][col];
        prod.a[row][col] = max(viaZero, viaOne);
      }
    }
    return prod;
  }

  // True if any of the four entries differs.
  bool operator!=(Mat b) {
    for (int row = 0; row < 2; ++row)
      for (int col = 0; col < 2; ++col)
        if (a[row][col] != b.a[row][col]) return true;
    return false;
  }
};
int res;  // answer to the most recent query (printed by main)

int pre[1000006];  // pre[u]: position of vertex u in dfs2's heavy-first order
int frt[1000006];  // frt[i]: vertex at position i (inverse of pre)
int top[1000006];  // top[u]: first vertex of u's heavy chain
int dfn;           // last position handed out by dfs2
// One BST per heavy chain, keyed by DFS position and split at the lsz-weighted
// median. It maintains (max,+) matrix products for point updates.
struct Bst {
  // ls/rs: children inside a chain's BST (0 = none).
  // fa: BST parent. For the root of a chain's BST it is instead the tree
  // parent of the chain's top vertex (set in dfs2; 0 for the global root).
  int fa[1000006], ls[1000006], rs[1000006];
  // val[x]: x's own matrix (from g[x]).
  // s[x]: product s[ls[x]] + val[x] + s[rs[x]] over x's BST subtree.
  Mat val[1000006], s[1000006];

  // True when u is the root of its chain's BST: fa[u] does not list u as a child.
  bool isroot(int u) { return ls[fa[u]] != u && rs[fa[u]] != u; }

  // Recompute s[x] from its children. This relies on s[0] being the (max,+)
  // identity (0 on the diagonal, -inf elsewhere), which main sets up.
  void update(int x) { s[x] = s[ls[x]] + val[x] + s[rs[x]]; }

  // Build the BST for chain positions frt[l..r] and return its root (0 if
  // the range is empty). The root is the first position whose lsz prefix sum
  // reaches half of the range total.
  int build(int l, int r) {
    if (l > r) return 0;
    int total = 0;
    for (int i = l; i <= r; ++i) total += lsz[frt[i]];
    // Defaults to r, so control can never run off the end of the function.
    int mid = l, prefix = lsz[frt[l]];
    while (mid < r && 2 * prefix < total) prefix += lsz[frt[++mid]];
    const int t = frt[mid];
    // Build and link each child in its own statement. A chained form such as
    //   fa[ls[t] = build(l, mid - 1)] = fa[rs[t] = build(mid + 1, r)] = t;
    // re-reads fa[rs[t]] for the outer store. If that read happens after the
    // left recursive call (which ends with fa[0] = 0) while rs[t] == 0,
    // fa[ls[t]] receives 0 instead of t.
    ls[t] = build(l, mid - 1);
    rs[t] = build(mid + 1, r);
    fa[ls[t]] = t;
    fa[rs[t]] = t;
    fa[0] = 0;  // 0 is the "no child" sentinel; undo the write above if a side was empty
    val[t].init(t);
    update(t);
    return t;
  }

  // Set vertex u's weight to nw and return the new answer for the whole tree.
  int solve(int u, int nw) {
    g[u][1] += nw - w[u], w[u] = nw, val[u][1][0] = g[u][1];
    while (1) {
      // Refresh products on the BST path up to the chain's BST root.
      while (!isroot(u)) update(u), u = fa[u];
      // Column 0 of the chain product before and after refreshing the root.
      // These are the chain top's (f0, f1) pair that feeds its parent's g.
      int f0 = s[u][0][0], f1 = s[u][1][0];
      update(u);
      if (fa[u]) {
        // The chain hangs off fa[u] as a light child: apply the change to
        // fa[u]'s light sums (same formulas as dfs1) and rebuild its matrix.
        g[fa[u]][0] += max(s[u][0][0], s[u][1][0]) - max(f0, f1);
        g[fa[u]][1] += s[u][0][0] - f0, val[fa[u]].init(fa[u]);
      } else
        return max(s[u][0][0], s[u][1][0]);
      u = fa[u];
    }
  }

  // Depth of x in its chain's BST (1 for the BST root). Not called by main.
  int getdep(int x) {
    if (isroot(x)) return 1;
    return getdep(fa[x]) + 1;
  }
} bst;
// Heavy-first DFS from u. The caller sets top[u] first.
// It assigns pre/frt in DFS order and top[] for every vertex. On reaching
// the bottom of a heavy chain, it builds that chain's BST over positions
// pre[top]..pre[bottom] and hangs the BST root below the tree parent of the
// chain's top vertex.
// Iterative with an explicit stack so deep chains (~10^6 vertices) cannot
// overflow the call stack. The visit order matches the recursive form:
// heavy child first, then light children in adjacency order.
void dfs2(int u) {
  vector<int> stk(1, u);
  while (!stk.empty()) {
    int x = stk.back();
    stk.pop_back();
    frt[pre[x] = ++dfn] = x;
    if (!sn[x]) {
      // Bottom of a chain: its positions are contiguous, so build it now.
      bst.fa[bst.build(pre[top[x]], pre[x])] = fa[top[x]];
      continue;
    }
    // Push light children in reverse so they pop in adjacency order. Push the
    // heavy child last so its whole subtree is numbered first.
    for (auto it = e[x].rbegin(); it != e[x].rend(); ++it)
      if (*it != fa[x] && *it != sn[x]) top[*it] = *it, stk.push_back(*it);
    top[sn[x]] = top[x];
    stk.push_back(sn[x]);
  }
}
int main() {
  n = read();
  m = read();
  // The DP maximises total weight, so a negative vertex would never be taken.
  for (int i = 1; i <= n; ++i) w[i] = max(read(), 0);
  for (int i = 1; i < n; ++i) {
    // Named locals make the read order explicit. add() is symmetric, so
    // the order was irrelevant anyway.
    int u = read();
    int v = read();
    add(u, v);
  }
  // Make s[0] the (max,+) identity: the diagonal is already 0 (static
  // storage); the off-diagonal entries become -inf.
  bst.s[0][1][0] = -inf;
  bst.s[0][0][1] = -inf;
  dfs1(1);
  top[1] = 1;
  dfs2(1);
  for (int i = 1; i <= m; ++i) {
    int x = read();
    int y = max(read(), 0);
    res = bst.solve(x, y);
    printf("%d\n", res);
  }
  return 0;
}
// Buffered fast input over stdin.
const int kInputBufferSize = 1 << 23;
char ibuf[kInputBufferSize], *iS, *iT;

// Next input byte as an unsigned value (0..255), or EOF once stdin is
// exhausted. Refills ibuf with fread whenever the buffered bytes run out.
static inline int gc() {
  if (iS == iT) {
    iS = ibuf;
    iT = ibuf + fread(ibuf, 1, kInputBufferSize, stdin);
    if (iS == iT) return EOF;
  }
  return static_cast<unsigned char>(*iS++);
}

// Read one decimal integer, possibly negative. Any characters before it are
// skipped; a '-' among them makes the result negative. Returns 0 if the input
// ends before a digit is seen, instead of spinning forever at EOF.
int read() {
  int x = 0, f = 1;
  int c = gc();
  while (c != EOF && (c < '0' || c > '9')) {
    if (c == '-') f = -1;
    c = gc();
  }
  while (c >= '0' && c <= '9') {
    x = x * 10 + (c - '0');
    c = gc();
  }
  return x * f;
}
万分感谢!