大家好,我仔细看了题解的代码,我觉得长得挺像的……但是我不明白哪里存在 bug,所以来请求大佬援助,感谢!!
#include "bits/stdc++.h"
using i64 = int64_t;
using u64 = uint64_t;
class SegmentTree {
struct Node {
size_t const lo, hi;
i64 sum;
i64 tag; // lazy tag for children's modification.
Node *left;
Node *right;
};
public:
SegmentTree(i64 *bg, i64 *ed) : nodes_(), root_(init(0, ed - bg, bg)) {}
i64 interval_sum(size_t const lo, size_t const hi) { return interval_sum(root_, lo, hi); }
void interval_add(size_t const lo, size_t const hi, i64 val) { interval_add(root_, lo, hi, val); }
private:
std::deque<Node> nodes_;
Node *root_;
i64 interval_sum(Node *const node, size_t const lo, size_t const hi) {
if (node->lo >= hi || node->hi <= lo) {
return 0;
}
if (node->lo >= lo && node->hi <= hi) {
return node->sum;
}
if (node->tag != 0) {
node->left->sum += node->tag * (node->left->hi - node->left->lo);
node->left->tag += node->tag;
node->right->sum += node->tag * (node->right->hi - node->right->lo);
node->right->tag += node->tag;
node->tag = 0;
}
return interval_sum(node->left, lo, hi) + interval_sum(node->right, lo, hi);
}
void interval_add(Node *const node, size_t const lo, size_t const hi, i64 val) {
if (node->lo >= hi || node->hi <= lo) {
return;
}
if (node->lo >= lo && node->hi <= hi) {
node->sum += val * (node->hi - node->lo);
node->tag += val;
return;
}
node->sum += val * (hi >= node->hi ? node->hi - lo : hi - node->lo);
interval_add(node->left, lo, hi, val);
interval_add(node->right, lo, hi, val);
}
Node *init(size_t const lo, size_t const hi, i64 *const arr) {
if (lo + 1 == hi) {
nodes_.push_back({.lo = lo, .hi = hi, .sum = arr[lo], .tag = 0, .left = nullptr, .right = nullptr});
return &nodes_.back();
}
auto const kMid = lo + (hi - lo) / 2;
auto left = init(lo, kMid, arr);
auto right = init(kMid, hi, arr);
nodes_.push_back({.lo = lo, .hi = hi, .sum = left->sum + right->sum, .tag = 0, .left = left, .right = right});
return &nodes_.back();
}
};
int main() {
u64 n, m;
std::cin >> n >> m;
std::vector<i64> arr(n);
for (auto &i : arr) {
std::cin >> i;
}
SegmentTree st(&arr[0], &arr.back() + 1);
for (u64 i = 0; i < m; ++i) {
u64 order, x, y;
std::cin >> order >> x >> y;
if (order == '1') {
i64 k;
std::cin >> k;
st.interval_add(x - 1, y, k);
} else {
std::cout << st.interval_sum(x - 1, y) << '\n';
}
}
}