// 树剖板子那题 — the heavy-light decomposition (树剖) template problem.
#include<bits/stdc++.h>
using namespace std;
#define mkp make_pair
#define pb push_back
#define ls (rt << 1)
#define rs (rt << 1 | 1)
typedef long long ll;
const int N = 30010, inf = 0x3f3f3f3f;
// Heavy-light decomposition state, indexed by vertex 1..n:
//   fat = parent (0 for the root), sz = subtree size, son = heavy child (0 for a leaf),
//   dep = depth, w = vertex weight, dfn = position in HLD order, top = head of u's heavy chain.
// id is the inverse of dfn: id[dfn[u]] == u. tim is the running dfn counter.
int n, tim, fat[N], sz[N], son[N], dep[N], w[N], dfn[N], top[N], id[N];
// Forward-star adjacency list; every undirected edge is stored as two directed edges.
int e, to[N << 1], nxt[N << 1], hd[N];
// Segment-tree aggregate over a dfn range: the maximum weight and the weight sum.
struct node{
    int mx, sum;
    // Combine two ranges: max of maxima, sum of sums.
    node operator + (const node x) const {
        return node{std::max(mx, x.mx), sum + x.sum};
    }
}tr[N << 2];  // heap-indexed tree (children 2k, 2k+1) needs up to 4n slots; 2n overflows
// Push directed edge a -> b onto the front of a's forward-star list.
void add(int a, int b){
    ++e;
    to[e] = b;
    nxt[e] = hd[a];
    hd[a] = e;
}
void dfs(int u, int fa){
fat[u] = fa; dep[u] = dep[fa] + 1;
sz[u] = 1; son[u] = 0;
for(int i = hd[u]; i; i = nxt[i]){
int v = to[i]; if(v == fa) continue;
dfs(v, u); sz[u] += sz[v];
if(sz[v] > sz[son[u]]) son[u] = v;
}
return;
}
void dfs2(int u, int topf){
dfn[u] = ++tim; id[tim] = u;
top[u] = topf;
if(!son[u]) return;
dfs2(son[u], topf);
for(int i = hd[u]; i; i = nxt[i]){
int v = to[i]; if(v == fat[u] || v == son[u]) continue;
dfs2(v, v);
}
return;
}
// Build segment-tree node rt over dfn range [l, r] from vertex weights (w via id).
void build(int rt, int l, int r){
    if(l != r){
        const int mid = (l + r) >> 1;
        build(ls, l, mid);
        build(rs, mid + 1, r);
        tr[rt] = tr[ls] + tr[rs];
        return;
    }
    const int leafWeight = w[id[l]];
    tr[rt].sum = leafWeight;
    tr[rt].mx = leafWeight;
}
// Assign value va at dfn position pos (within node rt's range [l, r]) and
// recompute the aggregates on the way back up.
void update(int rt, int l, int r, int pos, int va){
    if(l == r){
        tr[rt].sum = va;
        tr[rt].mx = va;
        return;
    }
    const int mid = (l + r) >> 1;
    if(pos > mid) update(rs, mid + 1, r, pos, va);
    else update(ls, l, mid, pos, va);
    tr[rt] = tr[ls] + tr[rs];
}
// Return the (max, sum) aggregate of [L, R] intersected with node rt's range [l, r].
// {-inf, 0} acts as the identity and is folded in first whenever the node is split.
node query(int rt, int l, int r, int L, int R){
    if(L <= l && r <= R) return tr[rt];  // fully covered
    const int mid = (l + r) >> 1;
    const node none = {-inf, 0};
    const node lhs = (L <= mid) ? query(ls, l, mid, L, R) : none;
    const node rhs = (mid < R) ? query(rs, mid + 1, r, L, R) : none;
    return none + lhs + rhs;
}
// Aggregate (max, sum) of the vertex weights on the tree path between x and y,
// jumping chain by chain until both endpoints lie on the same heavy chain.
node qrange(int x, int y){
    node res = {-inf, 0};
    while(top[x] != top[y]){
        // Always lift the endpoint whose chain head is deeper.
        if(dep[top[x]] < dep[top[y]]) std::swap(x, y);
        const int head = top[x];
        res = res + query(1, 1, n, dfn[head], dfn[x]);
        x = fat[head];
    }
    // Same chain now: dfn runs contiguously from the shallower endpoint down.
    const bool xDeeper = dep[x] > dep[y];
    const int upper = xDeeper ? y : x;
    const int lower = xDeeper ? x : y;
    return res + query(1, 1, n, dfn[upper], dfn[lower]);
}
// Set vertex x's weight to y via a point update at x's dfn position.
void updrange(int x, int y){
    const int pos = dfn[x];
    update(1, 1, n, pos, y);
}
int main(){
scanf("%d", &n);
for(int i = 1, u, v; i < n; i++){
scanf("%d%d", &u, &v);
add(u, v); add(v, u);
}
for(int i = 1; i <= n; i++)
scanf("%d", &w[i]);
dfs(1, 0); dfs2(1, 1);
build(1, 1, n);
int m; scanf("%d", &m);
for(int i = 1, x, y; i <= m; i++){
char ch[15]; scanf("%s%d%d", ch, &x, &y);
if(ch[0] == 'C') updrange(x, y);
else if(ch[1] == 'M') printf("%d\n", qrange(x, y).mx);
else printf("%d\n", qrange(x, y).sum);
}
return 0;
}