反复求助(UOJ/本地AC)
查看原帖
反复求助(UOJ/本地AC)
361308
Stinger楼主2021/6/12 20:05

为什么这份代码在 UOJ 和本地都能 AC，交到洛谷却 WA 了？/yun

#include <cstdio>
#include <algorithm>
#include <vector>
#include <map>
#include <cmath>
#include <utility>
#define int long long

// Global state for the tree Mo's algorithm (all zero-initialised; `int` is `long long`
// through the macro above).
// fst[u] / lst[u]: first / last position of node u in the Euler-tour sequence b[] (set by dfs).
// a[u]: current colour of node u.  pre[u]: latest colour of u while the updates are read.
// S: block size used when ordering the queries.
int fst[100005], lst[100005], a[100005], b[200005], pre[100005], S;
// cnt: Euler-tour position counter.  Cnt[c]: number of nodes of colour c in the window.
// Cnt2[u]: 1 if node u is currently counted in the window, 0 otherwise.
// id: not used in the visible code (Quest has its own id member).
// now: value of the current window.  ans[k]: answer of the k-th query in input order.
// fa[u][i]: 2^i-th ancestor of u (0 above the root).  dep[u]: depth, the root has depth 1.
int cnt, Cnt[100005], Cnt2[100005], id, now, ans[100005], fa[100005][19], dep[100005];
// v[c]: value of colour c.  w[k]: weight applied when a colour reaches its k-th occurrence.
// Old[j] / New[j] / p[j]: colour before / after the j-th update, and the node it changes.
int v[100005], w[100005], Old[200005], New[200005], p[100005];
// Mo's pointers: current window b[l..r] (initially empty) and number of applied updates t.
int l = 1, r, t;
// Adjacency lists as singly linked lists of edges; edge indices start at 1, 0 ends a list.
struct Edge {
	int to;   // vertex this edge points to
	int nxt;  // next edge leaving the same vertex, 0 if none
} e[200005];
int head[100005];  // head[u]: most recently added edge leaving u, 0 if none
int tot;           // number of edges stored so far
// Stores the directed edge u -> v and makes it the first edge in u's list.
inline void AddEdge(int u, int v) {
	const int slot = ++ tot;
	e[slot].to = v;
	e[slot].nxt = head[u];
	head[u] = slot;
}
// One offline path query: window b[l..r] of the Euler tour, evaluated after the first
// `t` updates; `lca` is non-zero when that node must be counted separately; `id` is the
// position of the query in the input.
struct Quest {
	int l, r, lca, id, t;
	// Orders queries by block of l, then block of r, then time (blocks of size S).
	inline bool operator < (const Quest x) const {
		const int myLeftBlock = (l - 1) / S;
		const int otherLeftBlock = (x.l - 1) / S;
		if (myLeftBlock != otherLeftBlock) return myLeftBlock < otherLeftBlock;
		const int myRightBlock = (r - 1) / S;
		const int otherRightBlock = (x.r - 1) / S;
		if (myRightBlock != otherRightBlock) return myRightBlock < otherRightBlock;
		return t < x.t;
	}
} q[100005];

// Recursively visits the subtree of u. For each node it fills:
// - the binary-lifting table fa[u][1..18]; fa[u][0] is set by the parent before the call,
//   and is 0 for the root;
// - the depth dep[u];
// - the Euler tour: u is written to b[] at fst[u] on entry and at lst[u] on exit.
void dfs(int u) {
	for (int k = 1; k <= 18; ++ k)
		fa[u][k] = fa[fa[u][k - 1]][k - 1];
	dep[u] = dep[fa[u][0]] + 1;
	fst[u] = ++ cnt;
	b[cnt] = u;
	for (int i = head[u]; i != 0; i = e[i].nxt) {
		const int to = e[i].to;
		if (to == fa[u][0]) continue;  // do not walk back to the parent
		fa[to][0] = u;
		dfs(to);
	}
	lst[u] = ++ cnt;
	b[cnt] = u;
}
// Returns the lowest common ancestor of u and v. It uses the binary-lifting table fa[][]
// and the depths dep[] filled by dfs(). Above the root, fa[][] is 0 (zero-initialised).
int query(int u, int v) {
	// Make u the deeper endpoint. The former `u ^= v ^= u ^= v` modified u twice without
	// sequencing: undefined behaviour before C++17, and in practice it can leave u == 0.
	if (dep[u] < dep[v]) std::swap(u, v);
	// Lift u by dep[u] - dep[v] levels, one set bit at a time. Lifting by 2^i clears
	// exactly bit i of the remaining difference, so re-reading dep[u] is safe.
	for (int i = 0; i <= 18; ++ i)
		if ((dep[u] - dep[v]) & (1 << i)) u = fa[u][i];
	if (u == v) return u;
	// Climb both nodes while their ancestors differ; they end just below the LCA.
	for (int i = 18; i >= 0; -- i)
		if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
	return fa[u][0];
}
// Flips whether node x is counted in the current window. In the Euler-tour sequence,
// every occurrence of a node toggles its membership.
// - Counting the k-th node of colour c adds v[c] * w[k] to `now`.
// - Uncounting it subtracts that same term again.
inline void add(int x) {
	const int col = a[x];
	Cnt2[x] ^= 1;
	if (Cnt2[x] == 0) {
		now -= v[col] * w[Cnt[col]];
		-- Cnt[col];
	} else if (++ Cnt[col] > 0) {
		now += v[col] * w[Cnt[col]];
	}
}
inline void del(int x) {
	int c = a[x];
	Cnt2[x] ^= 1;
	if (!Cnt2[x]) now -= v[c] * w[Cnt[c] --];
	if (Cnt2[x] && ++ Cnt[c] > 0) now += v[c] * w[Cnt[c]];
}
// Sets the colour of node p to x.
// - If p is currently counted in the window, it is uncounted under its old colour and
//   recounted under the new one, which keeps `now` and Cnt[] consistent.
// - Otherwise only the colour changes.
void modify(int p, int x) {
	if (!Cnt2[p]) {
		a[p] = x;
		return;
	}
	del(p);
	a[p] = x;
	add(p);
}

// Overall flow:
// 1. Read the colour values, occurrence weights, tree edges, initial colours, and the
//    mixed stream of updates (opt != 1) and path queries (opt == 1).
// 2. Answer the queries offline with Mo's algorithm with modifications over the Euler
//    tour.
// 3. Print the answers in input order.
signed main() {
	int n, m, Q, qt = 0;
	scanf("%lld%lld%lld", &n, &m, &Q);
	// Block size for the (l-block, r-block, time) ordering; at least 1 for n >= 1.
	S = pow(n, 2.0 / 3.0) * sqrt(2);
	for (int i = 1; i <= m; ++ i) scanf("%lld", v + i);
	for (int i = 1; i <= n; ++ i) scanf("%lld", w + i);
	for (int i = 1; i < n; ++ i) {
		int u, v;
		scanf("%lld%lld", &u, &v);
		AddEdge(u, v), AddEdge(v, u);
	}
	// pre[] tracks each node's latest colour while the updates are read. This lets every
	// update record the colour it overwrites, so it can be undone.
	for (int i = 1; i <= n; ++ i) scanf("%lld", a + i), pre[i] = a[i];
	dfs(1);
	for (int i = 1, j = 0; i <= Q; ++ i) {
		int opt;
		scanf("%lld", &opt);
		if (opt == 1) {
			int u, v;
			scanf("%lld%lld", &u, &v);
			// Make u the endpoint entered first. std::swap replaces `u ^= v ^= u ^= v`,
			// which is undefined behaviour before C++17.
			if (fst[u] > fst[v]) std::swap(u, v);
			const int lca = query(u, v);
			// Increment qt on its own. In `q[++ qt].id = qt`, the read of qt is unsequenced
			// with the increment before C++17 (UB). Under C++17 rules the right side is
			// evaluated first, so it stores the OLD qt and every answer lands one slot off.
			++ qt;
			q[qt].id = qt;
			q[qt].t = j;
			if (lca == u) {
				// u is an ancestor of v: use the window [fst[u], fst[v]].
				q[qt].l = fst[u], q[qt].r = fst[v];
			} else {
				// Use [lst[u], fst[v]]; the LCA does not occur there and is added when answering.
				q[qt].l = lst[u], q[qt].r = fst[v], q[qt].lca = lca;
			}
		} else {
			int x, y;
			scanf("%lld%lld", &x, &y);
			++ j;
			Old[j] = pre[x], New[j] = pre[x] = y, p[j] = x;
		}
	}
	std::sort(q + 1, q + qt + 1);
	for (int i = 1; i <= qt; ++ i) {
		// Move the window to b[q[i].l .. q[i].r]; each step toggles one Euler-tour occurrence.
		while (l > q[i].l) add(b[-- l]);
		while (r < q[i].r) add(b[++ r]);
		while (r > q[i].r) del(b[r --]);
		while (l < q[i].l) del(b[l ++]);
		// Apply or undo updates until exactly q[i].t of them are in effect.
		while (t < q[i].t) ++ t, modify(p[t], New[t]);
		while (t > q[i].t) modify(p[t], Old[t]), t --;
		// Count the LCA (when stored) as one more node of its colour, without changing the window.
		if (q[i].lca) ans[q[i].id] = v[a[q[i].lca]] * w[Cnt[a[q[i].lca]] + 1];
		ans[q[i].id] += now;
	}
	for (int i = 1; i <= qt; ++ i) printf("%lld\n", ans[i]);
	return 0;
}
2021/6/12 20:05
加载中...