萌新刚学树剖,10pts求助
查看原帖
萌新刚学树剖,10pts求助
217577
kemkra楼主2021/4/9 22:38
#include <cstdio>
#include <algorithm>

using namespace std;

const int N = 4e4;

char s[10];
int n, q, cnt;
int a[N], w[N];
int head[N], to[N], nxt[N];
int dep[N], fa[N], siz[N], son[N], top[N], id[N];
int sum[N], mav[N], lc[N], rc[N], root;

void add(int u, int v) {
	to[++cnt] = v;
	nxt[cnt] = head[u];
	head[u] = cnt;
}

void dfs1(int x, int f, int d) {
	fa[x] = f;
	dep[x] = d;
	siz[x] = 1;
	for (int i = head[x]; i; i = nxt[i]) {
		int y = to[i];
		if (y == f) continue;
		dfs1(y, x, d + 1);
		siz[x] += siz[y];
		if (siz[y] > siz[son[x]]) son[x] = y;
	}
}

void dfs2(int x, int u) {
	id[x] = ++cnt;
	w[cnt] = a[x];
	top[x] = u;
	if (!son[x]) return;
	dfs2(son[x], u);
	for (int i = head[x]; i; i = nxt[i]) {
		int y = to[i];
		if (y == fa[x] || y == son[x]) continue;
		dfs2(y, y);
	}
}

void pushup(int x) {
	sum[x] = sum[lc[x]] + sum[rc[x]];
	mav[x] = max(mav[lc[x]], mav[rc[x]]);
}

int build(int l, int r) {
	int u = ++cnt;
	if (l == r) sum[u] = mav[u] = w[l];
	else {
		int mid = (l + r) >> 1;
		lc[u] = build(l, mid);
		rc[u] = build(mid + 1, r);
		pushup(u);
	}
	return u;
}

void update(int u, int l, int r, int x, int d) {
	if (l == r) {
		sum[u] = mav[u] = d;
		return;
	}
	int mid = (l + r) >> 1;
	if (x <= mid) update(lc[u], l, mid, x, d);
	else update(rc[u], mid + 1, r, x, d);
	pushup(u);
}

int querymax(int u, int l, int r, int x, int y) {
	if (l >= x && r <= y) return mav[u];
	int mid = (l + r) >> 1, v = 0;
	if (x <= mid) v = max(v, querymax(lc[u], l, mid, x, y));
	if (y > mid) v = max(v, querymax(rc[u], mid + 1, r, x, y));
	return v;
}

int querysum(int u, int l, int r, int x, int y) {
	if (l >= x && r <= y) return sum[u];
	int mid = (l + r) >> 1, v = 0;
	if (x <= mid) v += querysum(lc[u], l, mid, x, y);
	if (y > mid) v += querysum(rc[u], mid + 1, r, x, y);
	return v;
}

int qmax(int x, int y) {
	int v = -0x80000000;
	while (top[x] != top[y]) {
		if (dep[top[x]] < dep[top[y]]) swap(x, y);
		v = max(v, querymax(root, 1, n, id[top[x]], id[x]));
		x = fa[top[x]];
	}
	if (dep[x] > dep[y]) swap(x, y);
	v = max(v, querymax(root, 1, n, id[x], id[y]));
	return v;
}

int qsum(int x, int y) {
	int v = 0;
	while (top[x] != top[y]) {
		if (dep[top[x]] < dep[top[y]]) swap(x, y);
		v += querysum(root, 1, n, id[top[x]], id[x]);
		x = fa[top[x]];
	}
	if (dep[x] > dep[y]) swap(x, y);
	v += querysum(root, 1, n, id[x], id[y]);
	return v;
}

int main() {
	scanf("%d", &n);
	for (int i = 1, u, v; i < n; i++) {
		scanf("%d%d", &u, &v);
		add(u, v);
		add(v, u);
	}
	for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
	dfs1(1, 0, 1);
	cnt = 0;
	dfs2(1, 1);
	cnt = 0;
	root = build(1, n);
	scanf("%d", &q);
	for (int i = 1, x, y; i <= q; i++) {
		scanf("%s%d%d", s, &x, &y);
		if (s[0] == 'C') update(root, 1, n, id[x], y);
		if (s[1] == 'M') printf("%d\n", qmax(x, y));
		if (s[1] == 'S') printf("%d\n", qsum(x, y));
	}
	return 0;
}
2021/4/9 22:38
加载中...