萌新求助树剖12pts
查看原帖
萌新求助树剖12pts
519384
Link_Cut_Y楼主2022/1/26 07:51

码风略丑，凑合看吧 qwq

#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>

using namespace std;

// Array capacities and the sentinel used by max/min queries.
const int N = 200010;        // vertex capacity (with slack)
const int M = N << 1;        // every undirected edge is stored as two directed slots
const int INF = 0x3f3f3f3f;  // identity value for max (-INF) and min (+INF)

// Forward-star adjacency list: h[u] = first slot of u, e = target, ne = next slot, val = weight.
int h[N];
int e[M];
int ne[M];
int val[M];
int idx;  // last used slot; slot 0 terminates a list

// Heavy-light decomposition data.
int dep[N];  // depth of each vertex
int fa[N];   // parent of each vertex
int sz[N];   // subtree size
int top[N];  // head vertex of the heavy chain containing the vertex
int son[N];  // heavy child (0 when there is none)

// w[v] = weight of the edge from v to its parent; nw = w laid out by DFS index; id[v] = DFS index of v.
int w[N];
int nw[N];
int id[N];
int cnt;  // running DFS index counter

int n, m;            // vertex count, operation count
int from[N], to[N];  // endpoints of the i-th input edge (already shifted to 1-based)

// Append the directed edge a -> b with weight c to the front of a's adjacency list.
void add(int a, int b, int c)
{
	idx++;
	e[idx] = b;
	val[idx] = c;
	ne[idx] = h[a];
	h[a] = idx;
}

// First HLD pass over the subtree of u: record parent, depth and subtree size,
// store each child's incoming edge weight in w[child], and choose the heavy child.
void dfs1(int u, int father, int depth)
{
	fa[u] = father;
	dep[u] = depth;
	sz[u] = 1;

	for (int edge = h[u]; edge != 0; edge = ne[edge])
	{
		const int child = e[edge];
		if (child == father)
			continue;

		dfs1(child, u, depth + 1);
		w[child] = val[edge];  // the edge (u, child) is represented by its deeper endpoint
		sz[u] += sz[child];
		if (sz[child] > sz[son[u]])
			son[u] = child;
	}
}

// Second HLD pass: give u the next DFS index, tag it with chain head t, and copy its
// edge weight into nw. The heavy child is visited first so every heavy chain is contiguous.
void dfs2(int u, int t)
{
	top[u] = t;
	id[u] = ++cnt;
	nw[cnt] = w[u];

	if (son[u])
	{
		dfs2(son[u], t);  // heavy child extends the current chain
		for (int edge = h[u]; edge; edge = ne[edge])
		{
			const int child = e[edge];
			if (child != fa[u] && child != son[u])
				dfs2(child, child);  // each light child heads a new chain
		}
	}
}

// Segment-tree node covering DFS positions [l, r].
// Keep the member order: build() aggregate-initialises it as {l, r, rev, sum, maxn, minn}.
struct Tree
{
	int l;
	int r;
	int rev;   // pending negation that still has to be applied to the children
	int sum;   // sum over [l, r]
	int maxn;  // maximum over [l, r]
	int minn;  // minimum over [l, r]
} tr[N << 2];

void pushup(int u)
{
	tr[u].maxn = max(tr[u << 1].maxn, tr[u << 1 | 1].maxn);
	tr[u].minn = min(tr[u << 1].minn, tr[u << 1 | 1].minn);
	tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

// Apply node u's pending negation to both children and clear it.
// Negating a segment maps sum -> -sum, max -> -min and min -> -max; each child's
// rev flag is toggled so the negation keeps propagating further down later.
void pushdown(int u)
{
	if (!tr[u].rev) return;
	tr[u].rev = 0;

	// A leaf has no children: tr[2u] / tr[2u+1] are not part of the tree and may even
	// lie past the end of tr (this happens for empty ranges such as [id[v]+1, id[v]]).
	if (tr[u].l == tr[u].r) return;

	for (int c = u << 1; c <= (u << 1 | 1); c ++ )
	{
		tr[c].rev ^= 1;
		int old_max = tr[c].maxn;
		tr[c].maxn = -tr[c].minn;
		tr[c].minn = -old_max;
		tr[c].sum = -tr[c].sum;
	}
}

// Build node u over positions [l, r] from nw[]; internal nodes take their aggregates from pushup.
void build(int u, int l, int r)
{
	tr[u].l = l;
	tr[u].r = r;
	tr[u].rev = 0;

	if (l != r)
	{
		const int mid = l + r >> 1;
		build(u << 1, l, mid);
		build(u << 1 | 1, mid + 1, r);
		pushup(u);
		return;
	}

	tr[u].sum = nw[l];
	tr[u].maxn = nw[l];
	tr[u].minn = nw[l];
}

// Overwrite the value at DFS position x with v, then refresh the ancestors on the way back.
void modify(int u, int x, int v)
{
	if (tr[u].l == tr[u].r && tr[u].l == x)
	{
		tr[u].sum = v;
		tr[u].maxn = v;
		tr[u].minn = v;
		return;
	}

	pushdown(u);
	const int mid = tr[u].l + tr[u].r >> 1;
	modify(x <= mid ? u << 1 : u << 1 | 1, x, v);
	pushup(u);
}

void reverse_interval(int u, int l, int r)
{
	if (tr[u].l >= l && tr[u].r <= r)
	{
		tr[u].rev ^= 1;
		int maxv = tr[u].maxn, minv = tr[u].minn;
		tr[u].minn = -maxv, tr[u].maxn = -minv, tr[u].sum *= -1;
		return;
	}
	
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1;
	if (l <= mid) reverse_interval(u << 1, l, r);
	if (r > mid) reverse_interval(u << 1 | 1, l, r);
	pushup(u);
}

int query_sum(int u, int l, int r)
{
	if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
	
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1, res = 0;
	if (l <= mid) res = query_sum(u << 1, l, r);
	if (r > mid) res += query_sum(u << 1 | 1, l, r);
	
	return res;
}

int query_max(int u, int l, int r)
{
	if (tr[u].l >= l && tr[u].r <= r) return tr[u].maxn;
	
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1, res = -INF;
	if (l <= mid) res = query_max(u << 1, l, r);
	if (r > mid) res = max(res, query_max(u << 1 | 1, l, r));
	
	return res;
}

int query_min(int u, int l, int r)
{
	if (tr[u].l >= l && tr[u].r <= r) return tr[u].minn;
	
	pushdown(u);
	int mid = tr[u].l + tr[u].r >> 1, res = INF;
	if (l <= mid) res = query_min(u << 1, l, r);
	if (r > mid) res = min(res, query_min(u << 1 | 1, l, r));
	
	return res;
}

// Negate every edge weight on the tree path u - v.
// Each edge is stored at the DFS index of its deeper endpoint, so the only slot that
// must be skipped is the LCA's own one (its edge to its parent is not on the path).
void reverse_path(int u, int v)
{
	while (top[u] != top[v])
	{
		if (dep[top[u]] < dep[top[v]]) swap(u, v);
		// The whole chain top[u]..u lies on the path, INCLUDING the edge top[u]-fa[top[u]]
		// stored at id[top[u]]; using id[top[u]] + 1 here silently dropped that edge.
		reverse_interval(1, id[top[u]], id[u]);
		u = fa[top[u]];
	}

	if (dep[u] < dep[v]) swap(u, v);
	// v is now the LCA; exclude its slot. When u == v the remaining range is empty.
	if (u != v) reverse_interval(1, id[v] + 1, id[u]);
}

// Maximum edge weight on the tree path u - v (-INF when u == v, i.e. the path has no edges).
// Each edge is stored at the DFS index of its deeper endpoint; only the LCA's slot is excluded.
int query_path_max(int u, int v)
{
	int res = -INF;
	while (top[u] != top[v])
	{
		if (dep[top[u]] < dep[top[v]]) swap(u, v);
		// Include id[top[u]]: the edge top[u]-fa[top[u]] is on the path.
		res = max(res, query_max(1, id[top[u]], id[u]));
		u = fa[top[u]];
	}

	if (dep[u] < dep[v]) swap(u, v);
	// v is the LCA; skip its own slot. Nothing remains when u == v.
	if (u != v) res = max(res, query_max(1, id[v] + 1, id[u]));

	return res;
}

// Minimum edge weight on the tree path u - v (INF when u == v, i.e. the path has no edges).
// Each edge is stored at the DFS index of its deeper endpoint; only the LCA's slot is excluded.
int query_path_min(int u, int v)
{
	int res = INF;
	while (top[u] != top[v])
	{
		if (dep[top[u]] < dep[top[v]]) swap(u, v);
		// Include id[top[u]]: the edge top[u]-fa[top[u]] is on the path.
		res = min(res, query_min(1, id[top[u]], id[u]));
		u = fa[top[u]];
	}

	if (dep[u] < dep[v]) swap(u, v);
	// v is the LCA; skip its own slot. Nothing remains when u == v.
	if (u != v) res = min(res, query_min(1, id[v] + 1, id[u]));

	return res;
}

// Sum of the edge weights on the tree path u - v (0 when u == v).
// Each edge is stored at the DFS index of its deeper endpoint; only the LCA's slot is excluded.
int query_path_sum(int u, int v)
{
	int res = 0;
	while (top[u] != top[v])
	{
		if (dep[top[u]] < dep[top[v]]) swap(u, v);
		// Include id[top[u]]: the edge top[u]-fa[top[u]] is on the path.
		res += query_sum(1, id[top[u]], id[u]);
		u = fa[top[u]];
	}

	if (dep[u] < dep[v]) swap(u, v);
	// v is the LCA; skip its own slot. Nothing remains when u == v.
	if (u != v) res += query_sum(1, id[v] + 1, id[u]);

	return res;
}

// Read the tree (input vertices 0..n-1 are shifted to 1..n), build the HLD and segment
// tree rooted at vertex 1, then process m operations:
//   C i w  - set the weight of the i-th input edge to w
//   N u v  - negate every edge weight on path u - v
//   SUM u v / MAX u v / any other command (MIN) u v - print the path aggregate
int main()
{
	if (scanf("%d", &n) != 1) return 0;

	for (int i = 1; i <= n - 1; i ++ )
	{
		int W;
		scanf("%d%d%d", &from[i], &to[i], &W);
		from[i] += 1, to[i] += 1;
		add(from[i], to[i], W), add(to[i], from[i], W);
	}

	dfs1(1, 0, 1), dfs2(1, 1), build(1, 1, n);

	if (scanf("%d", &m) != 1) return 0;
	while (m -- )
	{
		char op[4];
		int a, b;
		// Field width 3 keeps the token inside op[4]; an unbounded %s could overflow it.
		if (scanf("%3s%d%d", op, &a, &b) != 3) break;

		if (op[0] == 'C')
		{
			// An edge's weight lives at the DFS index of its deeper endpoint.
			if (dep[from[a]] > dep[to[a]]) modify(1, id[from[a]], b);
			else modify(1, id[to[a]], b);
		}
		else if (op[0] == 'N')
			reverse_path(a + 1, b + 1);
		else if (op[0] == 'S')
			printf("%d\n", query_path_sum(a + 1, b + 1));
		else if (op[0] == 'M' && op[1] == 'A')
			printf("%d\n", query_path_max(a + 1, b + 1));
		else
			printf("%d\n", query_path_min(a + 1, b + 1));
	}

	return 0;
}

2022/1/26 07:51
加载中...