萌新刚学线段树,40pts WA求助
查看原帖
萌新刚学线段树,40pts WA求助
675275
TREE_OI_offical 楼主 2022/2/6 16:05

本号是小号。 大号@ternary_tree

#include <bits/stdc++.h>
// Upper bound on the array length n; the program assumes n <= maxn.
constexpr int maxn = 1 << 20;

using namespace std;

// Segment-tree storage indexed by node number (root = 1, children 2p / 2p+1).
// With n <= maxn = 2^20 the recursion depth is at most 20, so every node
// index is < 2^21 = maxn << 1; 4 * maxn is unnecessary.
// Values are long long: range sums of squares (and the tag products used to
// update them) overflow 32-bit int.
long long sum[maxn << 1];  // sum of elements in the node's range
long long s2[maxn << 1];   // sum of squared elements in the node's range
long long tag[maxn << 1];  // pending "+tag to every element" not yet pushed to children
auto ls = [] (int p) { return p << 1; };
auto rs = [] (int p) { return (p << 1) | 1; };

// Recomputes a node's aggregates (element sum and sum of squares) from its
// two children.
void pushup(int node) {
	const int left = ls(node);
	const int right = rs(node);
	s2[node] = s2[left] + s2[right];
	sum[node] = sum[left] + sum[right];
}

void pushdown(int node, int l, int r) {
	if (!tag[node]) return;
	int mid = (l + r) >> 1;
	tag[ls(node)] += tag[node];
	tag[rs(node)] += tag[node];
	s2[ls(node)] += (tag[node] * sum[ls(node)] * 2 + tag[node] * tag[node] * (mid - l + 1)); 
	s2[rs(node)] += (tag[node] * sum[rs(node)] * 2 + tag[node] * tag[node] * (r - mid)); 
	sum[ls(node)] += (tag[node] * (mid - l + 1));
	sum[rs(node)] += (tag[node] * (r - mid));
	tag[node] = 0;
}

// Adds k to every element of [ql, qr] inside the subtree of `node`, which
// covers [l, r]. Fully covered nodes are updated in place and tagged lazily.
// k is an integer (long long): the previous double parameter computed s2 in
// floating point, losing precision and risking UB on conversion back.
void add(int node, int l, int r, int ql, int qr, long long k) {
	if (ql <= l && r <= qr) {
		const long long len = r - l + 1;
		tag[node] += k;
		// sum (x + k)^2 = sum x^2 + 2k*sum x + len*k^2; uses the OLD sum.
		s2[node] += 2 * k * sum[node] + len * k * k;
		sum[node] += k * len;
		return;
	}
	int mid = (l + r) >> 1;
	pushdown(node, l, r);
	if (ql <= mid) {
		add(ls(node), l, mid, ql, qr, k);
	}
	if (mid < qr) {
		add(rs(node), mid + 1, r, ql, qr, k);
	}
	pushup(node);
}

// Reads `num` integers from standard input and inserts each one at its
// 1-based position by a single-point add on the (initially zero) tree.
void build(int &num) {
	int value = 0;
	for (int pos = 1; pos <= num; ++pos) {
		cin >> value;
		add(1, 1, num, pos, pos, value);
	}
}

// Returns the sum of the elements in [ql, qr] that lie within [l, r], the
// range covered by `node`. Assumes the query overlaps [l, r] (callers start at
// the root with 1 <= ql <= qr <= n). Returns long long: range sums overflow int.
long long getsum(int node, int l, int r, int ql, int qr) {
	if (ql <= l && r <= qr) {
		return sum[node];
	}
	int mid = (l + r) >> 1;
	long long res = 0;
	pushdown(node, l, r);
	if (ql <= mid) {
		res += getsum(ls(node), l, mid, ql, qr);
	}
	if (mid < qr) {
		res += getsum(rs(node), mid + 1, r, ql, qr);
	}
	return res;
}

// Returns the sum of squares of the elements in [ql, qr] that lie within
// [l, r], the range covered by `node`. Assumes the query overlaps [l, r].
// Returns long long: sums of squares overflow int even for modest inputs.
long long gets2(int node, int l, int r, int ql, int qr) {
	if (ql <= l && r <= qr) {
		return s2[node];
	}
	int mid = (l + r) >> 1;
	long long res = 0;
	pushdown(node, l, r);
	if (ql <= mid) {
		res += gets2(ls(node), l, mid, ql, qr);
	}
	if (mid < qr) {
		res += gets2(rs(node), mid + 1, r, ql, qr);
	}
	return res;
}

int n, m, op, x, y;
int k;

int main()
{
    cin >> n >> m;
    build(n);
    for (int i = 1; i <= m; i++) {
    	cin >> op;
    	if (op == 1) {
			cin >> x >> y >> k;
			add(1, 1, n, x, y, k);
		} else if (op == 2) {
			cin >> x >> y;
			int nrt = getsum(1, 1, n, x, y);
			int dnt = y - x + 1;
			if (nrt == 0) {
				cout << "0/1" << endl;
				continue;
			}
			int g = __gcd(nrt, dnt);
			
			nrt /= g;
			dnt /= g;
			cout << nrt << "/" << dnt << endl;
		} else {
			cin >> x >> y;
			int a = getsum(1, 1, n, x, y);
			int b = gets2(1, 1, n, x, y);
			int n = (y - x + 1);
			int nrt = b * n - a * a;
			int dnt = n * n;
			if (nrt == 0) {
				cout << "0/1" << endl;
				continue;
			}
			int g = __gcd(nrt, dnt);
			
			nrt /= g;
			dnt /= g;
			cout << nrt << "/" << dnt << endl;
		}
		
	}
    return 0;
}
2022/2/6 16:05
加载中...