轻重链剖分 pts20
查看原帖
轻重链剖分 pts20
519384
Link_Cut_Y 楼主 2022/1/8 16:58
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>

using namespace std;

// Capacity: up to 3e4 vertices plus slack; every undirected edge is stored twice.
const int N = 30000 + 10;
const int M = 2 * N;

// Forward-star adjacency list: h = first edge slot of each vertex,
// e = edge target, ne = next slot in the same list, idx = last used slot (1-based).
int h[N], e[M], ne[M], idx;

// Vertex weights (w), weights laid out by DFS position (nw),
// DFS position of each vertex (id) and the running position counter (cnt).
int w[N], nw[N], id[N], cnt;

// Heavy-light data: parent, depth, chain head, heavy child.
int fa[N], dep[N], top[N], son[N];
// Subtree sizes.
int sz[N];

// Vertex count and number of queries.
int n, m;

// Segment-tree node. Member order matters: build() brace-initializes it
// as {l, r, add, sum, maxn}.
struct Tree
{
	int l, r;   // closed interval [l, r] of positions this node covers
	int add;    // pending range-add tag (consumed by pushdown)
	int sum;    // sum of the covered values
	int maxn;   // maximum of the covered values
} tr[N << 2];  // 4 * N nodes for a tree built over positions 1..n

// Number of positions covered by segment-tree node u, i.e. the size of [l, r].
int length(int u)
{
	int lo = tr[u].l;
	int hi = tr[u].r;
	return hi - lo + 1;
}

// Append the directed edge a -> b to the forward-star adjacency list.
// Slots start at 1, so slot 0 terminates every vertex's list.
void add(int a, int b)
{
	idx++;
	e[idx] = b;
	ne[idx] = h[a];
	h[a] = idx;
}

// First heavy-light pass.
// For every vertex it records the parent, depth and subtree size, and picks
// the heavy child, i.e. the child with the largest subtree.
void dfs1(int u, int father, int depth)
{
	fa[u] = father;
	dep[u] = depth;
	sz[u] = 1;

	for (int edge = h[u]; edge != 0; edge = ne[edge])
	{
		const int child = e[edge];
		if (child == father) continue;

		dfs1(child, u, depth + 1);
		sz[u] += sz[child];

		// son[u] starts at 0. sz[0] stays 0 as long as vertices are
		// numbered from 1, so the first child visited always qualifies.
		if (sz[child] > sz[son[u]]) son[u] = child;
	}
}

// Second heavy-light pass. It assigns DFS positions so that each heavy chain
// occupies a contiguous range. It also stores each vertex's chain head in
// top[] and copies the weights into nw[] in position order for build().
void dfs2(int u, int t)
{
	top[u] = t;
	id[u] = ++cnt;
	nw[cnt] = w[u];

	if (son[u] == 0) return;  // no heavy child means no children at all

	// The heavy child is visited first so it continues the current chain.
	dfs2(son[u], t);

	// Every light child starts a new chain headed by itself.
	for (int edge = h[u]; edge; edge = ne[edge])
	{
		const int v = e[edge];
		if (v != son[u] && v != fa[u]) dfs2(v, v);
	}
}

// Recompute node u's sum and maximum from its two children.
void pushup(int u)
{
	const int lc = u << 1;
	const int rc = lc | 1;
	tr[u].sum = tr[lc].sum + tr[rc].sum;
	tr[u].maxn = max(tr[lc].maxn, tr[rc].maxn);
}

// Push node u's pending range-add tag down to its two children, then clear it.
// Each child absorbs the tag into its own tag, into its sum (scaled by the
// number of positions it covers) and into its maximum. Every covered value
// grows by the same amount, so the maximum grows by exactly that amount.
// The original forgot the maxn update, which left maxima stale after any
// range add.
void pushdown(int u)
{
	const int delta = tr[u].add;
	if (!delta) return;

	for (int c = u << 1; c <= (u << 1 | 1); c++)
	{
		tr[c].add += delta;
		tr[c].sum += delta * length(c);
		tr[c].maxn += delta;
	}
	tr[u].add = 0;
}

// Build segment-tree node u over positions [l, r] of nw[].
// A leaf takes its single value; an internal node aggregates its children.
void build(int u, int l, int r)
{
	tr[u].l = l;
	tr[u].r = r;
	tr[u].add = 0;

	if (l == r)
	{
		tr[u].sum = tr[u].maxn = nw[r];
		return;
	}

	const int mid = (l + r) / 2;
	build(u << 1, l, mid);
	build(u << 1 | 1, mid + 1, r);
	pushup(u);
}

void modify(int u, int x, int v)
{
	if (tr[u].l == tr[u].r)
	{
		tr[u].maxn = tr[u].sum = v;
		return;
	}
	
	int mid = (tr[u].l + tr[u].r) >> 1;
	if (x <= mid) modify(u << 1, x, v);
	else modify(u << 1 | 1, x, v);
	
	pushup(u);
}

// Sum of positions [l, r] inside node u's interval (call with u = 1).
int query1(int u, int l, int r) // query sum
{
	if (l <= tr[u].l && tr[u].r <= r) return tr[u].sum;  // node fully covered

	pushdown(u);
	const int mid = (tr[u].l + tr[u].r) >> 1;

	// Evaluate the left half first, then the right half.
	const int leftPart = (l <= mid) ? query1(u << 1, l, r) : 0;
	const int rightPart = (r > mid) ? query1(u << 1 | 1, l, r) : 0;
	return leftPart + rightPart;
}

// Maximum over positions [l, r] inside node u's interval (call with u = 1).
int query2(int u, int l, int r) // query max
{
	if (l <= tr[u].l && tr[u].r <= r) return tr[u].maxn;  // node fully covered

	pushdown(u);
	const int mid = (tr[u].l + tr[u].r) >> 1;

	// Sentinel; assumes every weight is larger than it.
	int best = -0x3f3f3f3f;
	if (l <= mid) best = max(best, query2(u << 1, l, r));
	if (mid < r) best = max(best, query2(u << 1 | 1, l, r));
	return best;
}

// Sum of vertex weights on the tree path from u to v, both endpoints included.
// While the endpoints lie on different chains, the one whose chain head is
// deeper adds its chain segment and jumps to just above that head. Once both
// share a chain, the remaining contiguous stretch between them is added.
int query_sum(int u, int v)
{
	int total = 0;
	for (; top[u] != top[v]; u = fa[top[u]])
	{
		if (dep[top[u]] < dep[top[v]]) swap(u, v);
		total += query1(1, id[top[u]], id[u]);
	}

	// Same chain now: the shallower vertex has the smaller position.
	const int shallow = (dep[u] < dep[v]) ? u : v;
	const int deep = (shallow == u) ? v : u;
	total += query1(1, id[shallow], id[deep]);
	return total;
}

// Maximum vertex weight on the tree path from u to v, both endpoints included.
// It uses the same chain-climbing scheme as query_sum, combining with max
// instead of +.
int query_max(int u, int v)
{
	int best = -0x3f3f3f3f;
	for (; top[u] != top[v]; u = fa[top[u]])
	{
		if (dep[top[u]] < dep[top[v]]) swap(u, v);
		best = max(best, query2(1, id[top[u]], id[u]));
	}

	// Same chain now: the shallower vertex has the smaller position.
	const int shallow = (dep[u] < dep[v]) ? u : v;
	const int deep = (shallow == u) ? v : u;
	best = max(best, query2(1, id[shallow], id[deep]));
	return best;
}

int main()
{
	scanf("%d", &n);
	for (int i = 1; i <= n - 1; i ++ )
	{
		int a, b;
		scanf("%d%d", &a, &b);
		add(a, b), add(b, a);
	}
	
	for (int i = 1; i <= n; i ++ )
		scanf("%d", &w[i]);
	
	dfs1(1, -1, 1), dfs2(1, 1), build(1, 1, n);
	scanf("%d", &m);
	
	while (m -- )
	{
		char op[7];
		int u, v;
		scanf("%s%d%d", op, &u, &v);
		if (op[1] == 'H') modify(1, u, v);
		if (op[1] == 'M') printf("%d\n", query_max(u, v));
		if (op[1] == 'S') printf("%d\n", query_sum(u, v));
	}
	
	return 0;
}
2022/1/8 16:58
加载中...