普普通通线段树也会越界吗?
查看原帖
普普通通线段树也会越界吗?
261262
WaltVBAlston楼主2021/10/3 19:44

RT,学校OJ测这道题9分,洛谷0分,教练说是访问越界了,我不理解,为什么?

#include <iostream>
#include <iomanip>
#include <cstdio>
#define MAXN 1000005
using namespace std;
struct node {
    int l, r;
    double xy, x, y, x2, tag_x, tag_y;
} tree[4 * MAXN];
int n, m;
double x[MAXN], y[MAXN];
inline void pushdown(int i) {
    if (tree[i].tag_x || tree[i].tag_y) {
        tree[i * 2].x2 += 2 * tree[i].tag_x * tree[i * 2].x +
                          (tree[i * 2].r - tree[i * 2].r + 1) * tree[i].tag_x * tree[i].tag_x;
        tree[i * 2].xy += tree[i].tag_x * tree[i * 2].y + tree[i].tag_y * tree[i * 2].x +
                          (tree[i * 2].r - tree[i * 2].r + 1) * tree[i].tag_x * tree[i].tag_y;
        tree[i * 2].x += (tree[i * 2].r - tree[i * 2].r + 1) * tree[i].tag_x;
        tree[i * 2].y += (tree[i * 2].r - tree[i * 2].r + 1) * tree[i].tag_y;
        tree[i * 2 + 1].x2 += 2 * tree[i].tag_x * tree[i * 2 + 1].x +
                              (tree[i * 2 + 1].r - tree[i * 2 + 1].r + 1) * tree[i].tag_x * tree[i].tag_x;
        tree[i * 2 + 1].xy += tree[i].tag_x * tree[i * 2 + 1].y + tree[i].tag_y * tree[i * 2 + 1].x +
                              (tree[i * 2 + 1].r - tree[i * 2 + 1].r + 1) * tree[i].tag_x * tree[i].tag_y;
        tree[i * 2 + 1].x += (tree[i * 2 + 1].r - tree[i * 2 + 1].r + 1) * tree[i].tag_x;
        tree[i * 2 + 1].y += (tree[i * 2 + 1].r - tree[i * 2 + 1].r + 1) * tree[i].tag_y;
        tree[i * 2].tag_x += tree[i].tag_x, tree[i * 2].tag_y += tree[i].tag_y;
        tree[i * 2 + 1].tag_x += tree[i].tag_x, tree[i * 2 + 1].tag_y += tree[i].tag_y;
        tree[i].tag_x = 0, tree[i].tag_y = 0;
    }
    return;
}
inline void pushup(int i) {
    tree[i].x = tree[i * 2].x + tree[i * 2 + 1].x;
    tree[i].xy = tree[i * 2].xy + tree[i * 2 + 1].xy;
    tree[i].x2 = tree[i * 2].x2 + tree[i * 2 + 1].x2;
    tree[i].y = tree[i * 2].y + tree[i * 2 + 1].y;
    return;
}
inline void build(int i, int l, int r) {
    tree[i].l = l, tree[i].r = r, tree[i].tag_x = 0, tree[i].tag_y = 0;
    if (l == r) {
        tree[i].xy = x[l] * y[l], tree[i].x = x[l], tree[i].y = y[l], tree[i].x2 = x[l] * x[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(i * 2, l, mid), build(i * 2 + 1, mid + 1, r);
    pushup(i);
    return;
}
inline double get_a(int i, int l, int r) {
    if (tree[i].l >= l && tree[i].r <= r) {
        double a = tree[i].xy, b = tree[i].r - tree[i].l + 1.0, c = tree[i].x2, d = tree[i].x, e = tree[i].y;
        return (a - e * d / b) / (c - d * d / b);
    }
    int mid = (tree[i].l + tree[i].r) >> 1;
    pushdown(i);
    double res = 0;
    if (mid >= l)
        res += get_a(i * 2, l, r);
    if (mid + 1 <= r)
        res += get_a(i * 2 + 1, l, r);
    return res;
}
inline void update(int i, int l, int r, int x, int y) {
    if (tree[i].l >= l && tree[i].r <= r) {
        tree[i].x2 += 2 * x * tree[i].x + (tree[i].r - tree[i].l + 1) * x * x;
        tree[i].xy += x * tree[i].y + y * tree[i].x + (tree[i].r - tree[i].l + 1) * x * y;
        tree[i].x += (tree[i].r - tree[i].l + 1) * x;
        tree[i].y += (tree[i].r - tree[i].l + 1) * y;
        tree[i].tag_x += x, tree[i].tag_y += y;
        return;
    }
    int mid = (tree[i].l + tree[i].r) >> 1;
    pushdown(i);
    if (mid >= l)
        update(i * 2, l, r, x, y);
    if (mid + 1 <= r)
        update(i * 2 + 1, l, r, x, y);
    pushup(i);
    return;
}
inline void clean(int i, int l, int r) {
    if (tree[i].l >= l && tree[i].r <= r) {
        tree[i].xy = tree[i].x2 = tree[i].r * (tree[i].r + 1) * (2 * tree[i].r + 1) / 6.0 -
                                  tree[i].l * (tree[i].l + 1) * (2 * tree[i].l + 1) / 6.0;
        tree[i].x = tree[i].r = (tree[i].r - tree[i].l + 1) * (tree[i].l + tree[i].r) / 2.0;
        tree[i].tag_x = tree[i].tag_y = 0;
        return;
    }
    pushdown(i);
    int mid = (tree[i].l + tree[i].r) >> 1;
    if (mid >= l)
        clean(i * 2, l, r);
    if (mid + 1 <= r)
        clean(i * 2 + 1, l, r);
    pushup(i);
    return;
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%lf", &x[i]);
    for (int i = 1; i <= n; i++) scanf("%lf", &y[i]);
    build(1, 1, 2*n);
    while (m--) {
        int op, l, r;
        scanf("%d%d%d", &op, &l, &r);
        if (op == 1)
            cout << fixed << setprecision(10) << get_a(1, l, r) << endl;
        else {
            double s, t;
            scanf("%lf%lf", &s, &t);
            if (op == 2)
                update(1, l, r, s, t);
            else if (op == 3) {
                clean(1, l, r);
                update(1, l, r, s, t);
            }
        }
    }
    return 0;
}
2021/10/3 19:44
加载中...