线段树矩阵乘法 MLE 求条,玄关qwq
查看原帖
线段树矩阵乘法 MLE 求条,玄关qwq
946909
Chase12345楼主2025/6/29 21:35

本地样例过了,然而 MLE 飞了,求掉qwq

#include <bits/stdc++.h>
using namespace std;

using i64 = long long;
using vi = vector <i64>;
using M = vector <vi>;
const int MOD = 998244353, N = 2.5e5 + 5;
M A(4, vi(4, 0)), B(4, vi(1, 0));

M multi(M a, M b) {
	int n = a.size(), t = a[0].size(), m = b[0].size();
	M ret(n, vi(m, 0));
	for (int i = 0; i < n; i++)
		for (int j = 0; j < m; j++)
			for (int k = 0; k < t; k++)
				ret[i][j] = (ret[i][j] + 1LL * a[i][k] * b[k][j] % MOD) % MOD;
	return ret;
}

M add(M a, M b) {
	int n = a.size(), m = a[0].size();
	M ret(n, vi(m, 0));
	for (int i = 0; i < n; i++)
		for (int j = 0; j < m; j++)
			ret[i][j] = (ret[i][j] + a[i][j] + b[i][j]) % MOD;
	return ret;
}

struct Tree {
	int l, r;
	M sum = M(4, vi(1, 0)), mul = M(4, vi(4, 0));
} tree[N << 2];
int a[N], b[N], c[N];

void pushdown(int p) {
	for (int i = 0; i < 4; i++)
		tree[p].sum[i][0] = (tree[p << 1].sum[i][0] + tree[p << 1 | 1].sum[i][0]) % MOD;
}

void spread(int p) {
	if (tree[p].mul != A) {
		tree[p << 1].sum = multi(tree[p].mul, tree[p << 1].sum);
		tree[p << 1 | 1].sum = multi(tree[p].mul, tree[p << 1 | 1].sum);
		tree[p << 1].mul = multi(tree[p].mul, tree[p << 1].mul);
		tree[p << 1 | 1].mul = multi(tree[p].mul, tree[p << 1 | 1].mul);
		return;
	}
}

void build(int p, int l, int r) {
	tree[p].l = l;
	tree[p].r = r;
	tree[p].mul = A;
	if (l == r) {
		tree[p].sum[0][0] = a[l];
		tree[p].sum[1][0] = b[l];
		tree[p].sum[2][0] = c[l];
		tree[p].sum[3][0] = 1;
		return;
	}
	int mid = l + r >> 1;
	build(p << 1, l, mid);
	build(p << 1 | 1, mid + 1, r);
	pushdown(p);
}

void mul(int p, int l, int r, M k) {
	if (tree[p].r < l || tree[p].l > r)
		return;
	if (l <= tree[p].l && tree[p].r <= r) {
		tree[p].mul = multi(k, tree[p].mul);
		tree[p].sum = multi(k, tree[p].sum);
		return;
	}
	spread(p);
	mul(p << 1, l, r, k);
	mul(p << 1 | 1, l, r, k);
	pushdown(p);
}

M query(int p, int l, int r) {
	if (tree[p].r < l || tree[p].l > r)
		return B;
	if (l <= tree[p].l && tree[p].r <= r)
		return tree[p].sum;
	spread(p);
	return add(query(p << 1, l, r), query(p << 1 | 1, l, r));
}

int main() {
	for (int i = 0; i < 4; i++)
		A[i][i] = 1;
	int n, m;
	cin >> n;
	for (int i = 1; i <= n; i++)
		cin >> a[i] >> b[i] >> c[i];
	build(1, 1, n);
	cin >> m;
	while (m--) {
		int op, l, r, v;
		cin >> op >> l >> r;
		if (op == 1) {
			A[0][1]++;
			mul(1, l, r, A);
			A[0][1]--;
		} else if (op == 2) {
			A[1][2]++;
			mul(1, l, r, A);
			A[1][2]--;
		} else if (op == 3) {
			A[2][0]++;
			mul(1, l, r, A);
			A[2][0]--;
		} else if (op == 4) {
			cin >> v;
			A[0][3] += v;
			mul(1, l, r, A);
			A[0][3] -= v;
		} else if (op == 5) {
			cin >> v;
			A[1][1] += (v - 1);
			mul(1, l, r, A);
			A[1][1] -= (v - 1);
		} else if (op == 6) {
			cin >> v;
			A[2][2]--;
			A[2][3] += v;
			mul(1, l, r, A);
			A[2][2]++;
			A[2][3] -= v;
		} else if (op == 7) {
			M res = query(1, l, r);
			cout << res[0][0] << ' ' << res[1][0] << ' ' << res[2][0] << '\n';
		}
	}
	return 0;
}
2025/6/29 21:35
加载中...