我直接疑惑
查看原帖
我直接疑惑
51237
Kinandra楼主2020/5/12 19:01

这样写就对了:

   // Build a BST over chain positions [l, r] (frt[i] is the tree node stored at
   // position i) and return its root, or 0 for an empty range. The root is the
   // first position where the running sum of lsz reaches half of the range total.
   int build(int l, int r) {
        if (l > r) return 0;
        int sum1 = 0, sum2 = 0;
        for (int i = l; i <= r; ++i) sum1 += lsz[frt[i]];
        for (int i = l, t; i <= r; ++i)
            if (((sum2 += lsz[t = frt[i]]) << 1) >= sum1) {
                // One write per statement, so the order of the recursive calls
                // and of the parent-pointer writes is explicit.
                ls[t] = build(l, i - 1);
                rs[t] = build(i + 1, r);
                fa[ls[t]] = fa[rs[t]] = t;
                fa[0] = 0;  // undo the write above when a child is empty (0)
                val[t].init(t), update(t);
                return t;
            }
        // Not reached when sum1 >= 0 (at i == r, 2 * sum2 == 2 * sum1 >= sum1);
        // returning explicitly avoids flowing off the end of a non-void function.
        return 0;
    }

但是这样写就错了:

   // Variant that packs both recursive build() calls and every child/parent
   // write into a single chained assignment expression.
   // NOTE(review): before C++17 the evaluation order of the two operands of `=`
   // is unspecified, so the two build() calls (and the reads of ls[t]/rs[t] used
   // as subscripts) are not ordered relative to each other by this statement;
   // splitting it into separate statements makes the order explicit.
   int build(int l, int r) {
        if (l > r) return 0;
        int sum1 = 0, sum2 = 0;
        for (int i = l; i <= r; ++i) sum1 += lsz[frt[i]];
        for (int i = l, t; i <= r; ++i)
            if (((sum2 += lsz[t = frt[i]]) << 1) >= sum1) {
                fa[ls[t] = build(l, i - 1)] = fa[rs[t] = build(i + 1, r)] = t;
                //***
                fa[0] = 0, val[t].init(t), update(t);
                return t;
            }
        // NOTE(review): no return after the loop; flowing off the end here would
        // be undefined behavior (the loop always returns when sum1 >= 0).
    }

第二种写法会导致在 *** 这个位置出现 fa[ls[t]] != t && ls[t] 成立，也就是说：在 ls[t] 不为 0 的情况下，fa[ls[t]] 不等于 t。

？？？好玄学啊！有大佬能帮忙解释一下吗？

完整代码:


#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
using namespace std;
int read();

int n, m;
int w[1000006], fa[1000006], sz[1000006], lsz[1000006], sn[1000006];
int f[1000006][2], g[1000006][2];
vector<int> e[1000006];
// Add the undirected edge {u, v}: each endpoint goes into the other's adjacency list.
void add(int u, int v) {
    e[u].push_back(v);
    e[v].push_back(u);
}
// First DFS from u. Sets fa[] for u's children, the subtree size sz[u], the
// heavy son sn[u] (child with the largest subtree, 0 for a leaf),
// lsz[u] = sz[u] - sz[sn[u]], the light-children summary g[u][0/1], and the DP
// values f[u][0] (u not taken) / f[u][1] (u taken, w[u] counted via g[u][1]).
// Recursive: the call depth equals the tree depth, which can be up to n.
void dfs1(int u) {
    sz[u] = 1;
    for (int child : e[u]) {
        if (child == fa[u]) continue;
        fa[child] = u;
        dfs1(child);
        sz[u] += sz[child];
        if (sz[child] > sz[sn[u]]) sn[u] = child;
    }

    const int heavy = sn[u];
    lsz[u] = sz[u] - sz[heavy];

    // Fold every light child into g[u]; the heavy child is handled by f below.
    for (int child : e[u]) {
        if (child == fa[u] || child == heavy) continue;
        g[u][0] += max(f[child][0], f[child][1]);
        g[u][1] += f[child][0];
    }

    g[u][1] += w[u];
    f[u][0] = max(f[heavy][0], f[heavy][1]) + g[u][0];
    f[u][1] = f[heavy][0] + g[u][1];
}

// 2x2 matrix over the (max, +) semiring used for the dynamic-DP transitions.
struct Mat {
    int a[2][2];
    // Load node u's transition matrix from its light-children summary g[u]:
    // [[g[u][0], g[u][0]], [g[u][1], -inf]].
    void init(int u) {
        a[0][0] = a[0][1] = g[u][0], a[1][0] = g[u][1], a[1][1] = -inf;
    }
    int* operator[](int p) { return a[p]; }
    const int* operator[](int p) const { return a[p]; }
    // (max, +) product, not element-wise addition:
    // rt[i][j] = max over k of (a[i][k] + b[k][j]).
    Mat operator+(const Mat& b) const {
        Mat rt;
        for (int i = 0; i < 2; ++i)
            for (int j = 0; j < 2; ++j)
                rt[i][j] = max(a[i][0] + b[0][j], a[i][1] + b[1][j]);
        return rt;
    }
    // Element-wise inequality.
    bool operator!=(const Mat& b) const {
        return a[0][0] != b[0][0] || a[0][1] != b[0][1] || a[1][0] != b[1][0] ||
               a[1][1] != b[1][1];
    }
};

int res;
int pre[1000006], frt[1000006], top[1000006], dfn;
// Global balanced BST: dfs2 builds one BST per heavy chain and links the chain
// BSTs together through light edges.
//   fa[u]   BST parent of u; for the root of a chain's BST it is instead the
//           original-tree parent of that chain's top (set in dfs2).
//   ls, rs  BST children (0 = none).
//   val[u]  u's own transition matrix; s[u] = in-order (max,+) product of val
//           over u's BST subtree. s[0] must hold the identity (set in main).
struct Bst {
    int fa[1000006], ls[1000006], rs[1000006];
    Mat val[1000006], s[1000006];

    // True if u is not a BST child of fa[u], i.e. u is the root of its chain's BST.
    bool isroot(int u) { return ls[fa[u]] != u && rs[fa[u]] != u; }

    // Recompute s[x] from its BST children and its own matrix.
    void update(int x) { s[x] = s[ls[x]] + val[x] + s[rs[x]]; }

    // Build a BST over chain positions [l, r] (frt[i] is the tree node stored at
    // position i) and return its root, or 0 for an empty range. The root is the
    // first position where the running sum of lsz reaches half of the range total.
    int build(int l, int r) {
        if (l > r) return 0;
        int sum1 = 0, sum2 = 0;
        for (int i = l; i <= r; ++i) sum1 += lsz[frt[i]];
        for (int i = l, t; i <= r; ++i)
            if (((sum2 += lsz[t = frt[i]]) << 1) >= sum1) {
                // One write per statement, so the order of the recursive calls
                // and of the parent-pointer writes is explicit. Keep these
                // separate rather than folding them into one chained expression.
                ls[t] = build(l, i - 1);
                rs[t] = build(i + 1, r);
                fa[ls[t]] = fa[rs[t]] = t;
                fa[0] = 0;  // undo the write above when a child is empty (0)
                val[t].init(t), update(t);
                return t;
            }
        // Not reached when sum1 >= 0 (at i == r, 2 * sum2 == 2 * sum1 >= sum1);
        // returning explicitly avoids flowing off the end of a non-void function.
        return 0;
    }

    // Set w[u] = nw and return the new max(s[root][0][0], s[root][1][0]) of the
    // top-level chain BST. Walks upward: inside a chain BST it refreshes s[]
    // along the path to the BST root; at a chain-BST root it pushes the change
    // of the chain's values into the light parent's g[] and val[], then
    // continues from that parent.
    int solve(int u, int nw) {
        g[u][1] += nw - w[u], w[u] = nw, val[u][1][0] = g[u][1];
        while (1) {
            while (!isroot(u)) update(u), u = fa[u];
            // Values of this chain before the refresh, needed for the delta below.
            int f0 = s[u][0][0], f1 = s[u][1][0];
            update(u);
            if (fa[u]) {
                g[fa[u]][0] += max(s[u][0][0], s[u][1][0]) - max(f0, f1);
                g[fa[u]][1] += s[u][0][0] - f0, val[fa[u]].init(fa[u]);
            } else
                return max(s[u][0][0], s[u][1][0]);
            u = fa[u];
        }
    }

    // Depth of x inside its chain's BST (1 for the BST root).
    // Not called anywhere in the visible code; looks like a debugging helper.
    int getdep(int x) {
        if (isroot(x)) return 1;
        return getdep(fa[x]) + 1;
    }
} bst;

// Second DFS: number nodes heavy-son-first (pre[] and its inverse frt[]). When a
// heavy chain ends at u, build that chain's BST over positions
// pre[top[u]]..pre[u] and hang its root under the chain top's tree parent.
void dfs2(int u) {
    ++dfn;
    pre[u] = dfn;
    frt[dfn] = u;

    if (sn[u] == 0) {
        int chain_root = bst.build(pre[top[u]], pre[u]);
        bst.fa[chain_root] = fa[top[u]];
        return;
    }

    const int heavy = sn[u];
    top[heavy] = top[u];
    dfs2(heavy);

    // Every light child starts a new chain.
    for (int child : e[u]) {
        if (child == fa[u] || child == heavy) continue;
        top[child] = child;
        dfs2(child);
    }
}

// Read the tree and weights (negative weights are clamped to 0, as are the
// updated values), build the structures, then answer m point updates.
int main() {
    n = read(), m = read();
    for (int i = 1; i <= n; ++i) w[i] = max(read(), 0);
    for (int i = 1; i < n; ++i) {
        // Separate statements: the evaluation order of function arguments is
        // unspecified, so add(read(), read()) does not fix which endpoint is
        // read first.
        int u = read();
        int v = read();
        add(u, v);
    }
    // s[0] is the empty product: the (max,+) identity, with 0 on the diagonal
    // (bst is a zero-initialised global) and -inf off the diagonal.
    bst.s[0][0][1] = bst.s[0][1][0] = -inf;
    dfs1(1), top[1] = 1, dfs2(1);
    for (int i = 1; i <= m; ++i) {
        int x = read();
        int y = max(read(), 0);
        res = bst.solve(x, y);
        printf("%d\n", res);
    }
    return 0;
}

const int _SIZE = 1 << 23;
char ibuf[_SIZE], *iS, *iT, obuf[_SIZE], *oS = obuf, *oT = oS + _SIZE - 1;
// Next input byte as an int in [0, 255], or EOF once stdin is exhausted.
// Refills ibuf with fread when the buffered bytes are used up. The byte is
// widened through unsigned char so a 0xFF data byte cannot compare equal to EOF.
// (No trailing ';' in the expansion: callers supply their own.)
#define gc                                                        \
    (iS == iT ? (iT = (iS = ibuf) + fread(ibuf, 1, _SIZE, stdin), \
                 (iS == iT ? EOF : (unsigned char)*iS++))         \
              : (unsigned char)*iS++)
// Read a decimal integer from stdin. Non-digit bytes before it are skipped,
// and a '-' among them makes the result negative.
// Returns 0 if input ends before any digit (the previous version spun forever).
int read() {
    int x = 0, f = 1;
    int c = gc;  // int, not char, so EOF stays distinguishable from data bytes
    while (c < '0' || c > '9') {
        if (c == EOF) return 0;
        if (c == '-') f = -1;
        c = gc;
    }
    while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = gc;
    return x * f;
}

万分感谢!

2020/5/12 19:01
加载中...