newbie 20 pts求助
查看原帖
newbie 20 pts求助
138960
Tenshi楼主2021/4/17 21:55

找不到哪里有问题 OTZ，请求聚聚帮忙看看。（保证码风不毒瘤）

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
using namespace std;

// Wide integer type used for range/path sums.
typedef long long ll;

// Identity value for max queries (query_max / query_path_max start from -INF).
const int INF=0x3f3f3f3f;
// N bounds every node-indexed array (node ids are used directly as indices);
// M = 2N edge slots because each undirected edge is stored in both directions.
const int N=3e4+5, M=N<<1;
// Edge of the linked-list ("chained forward star") adjacency structure.
struct point{
	int to,next;	// to: edge target; next: index of the next edge from the same source (-1 ends the list)
}e[M];
// h[u]: index of the first edge out of u (-1 when empty; set by memset in main); tot: next free slot in e[].
int h[N], tot;
// Push the directed edge u -> v onto the front of u's adjacency list.
void add(int u, int v){
	e[tot].to = v;
	e[tot].next = h[u];
	h[u] = tot;
	++tot;
}

// Segment-tree node over DFS positions [l, r].
struct node{
	int l, r;	// inclusive range of DFS positions covered by this node
	int v;	// maximum value in the range
	ll sum;	// sum of values in the range
}tr[N<<2];	// 4*N nodes: enough for a recursive segment tree over up to N positions

// w[u]: weight of tree node u (read 1-based in main).
int w[N];
// fa: parent (-1 for the root), dep: depth (root = 1), son: heavy child (0 if none), sz: subtree size.
int fa[N], dep[N], son[N], sz[N];
// top[u]: head of u's heavy chain; cnt: DFS position counter; id[u]: DFS position of u;
// nw[p]: weight of the node sitting at DFS position p.
int top[N], cnt, id[N], nw[N]; 

// First HLD pass: record parent, depth and subtree size for every node in u's
// subtree, and choose son[u] as the child with the largest subtree
// (son[u] stays 0 for a leaf; sz[0] is never written, so any real child beats it).
void dfs1(int u, int father, int depth){
	fa[u] = father;
	dep[u] = depth;
	sz[u] = 1;
	for(int edge = h[u]; edge != -1; edge = e[edge].next){
		const int child = e[edge].to;
		if(child != father){
			dfs1(child, u, depth + 1);
			sz[u] += sz[child];
			if(sz[son[u]] < sz[child]) son[u] = child;
		}
	}
}

// Second HLD pass: hand out DFS positions with the heavy child visited first,
// so every heavy chain occupies a contiguous range. Records each node's chain
// head in top[] and copies node weights into nw[] in DFS-position order.
void dfs2(int u, int t){
	top[u] = t;
	id[u] = ++cnt;
	nw[id[u]] = w[u];
	if(son[u] == 0) return;
	dfs2(son[u], t);
	for(int edge = h[u]; edge != -1; edge = e[edge].next){
		const int child = e[edge].to;
		if(child != fa[u] && child != son[u]) dfs2(child, child);
	}
}

// Left / right child index of segment-tree node u (1-based heap layout).
int ls(int u){ return 2 * u; }
int rs(int u){ return 2 * u + 1; }

void pushup(int u){
	tr[u].sum=tr[ls(u)].sum+tr[rs(u)].sum;
	tr[u].v=max(tr[ls(u)].v, tr[rs(u)].v);
}

// Build segment-tree node u over DFS positions [l, r].
// nw[] is already indexed by DFS position (dfs2 stores nw[id[x]] = w[x]),
// so the leaf for position l reads nw[l] directly. It must NOT be remapped
// through id[] again: id[] maps node ids to positions, not positions to
// positions, and doing so would load some other node's weight.
void build(int u, int l, int r){
	if(l == r){
		tr[u] = {l, r, nw[l], nw[l]};
		return;
	}
	tr[u] = {l, r};
	const int mid = (l + r) >> 1;
	build(ls(u), l, mid);
	build(rs(u), mid + 1, r);
	pushup(u);
}

// Point assignment: set the value at DFS position x to k, then refresh the
// aggregates on the way back up. Expects x to lie inside node u's range.
void update(int u, int x, int k){
	if(tr[u].l == tr[u].r && tr[u].l == x){
		tr[u].sum = k;
		tr[u].v = k;
		return;
	}
	const int mid = (tr[u].l + tr[u].r) >> 1;
	update(x <= mid ? ls(u) : rs(u), x, k);
	pushup(u);
}

int query_max(int u, int l, int r){
	if(l<=tr[u].l && tr[u].r<=r) return tr[u].v;
	else{
		int res=-INF;
		int mid=tr[u].l+tr[u].r>>1;
		if(l<=mid) res=max(res, query_max(ls(u), l, r));
		if(r>mid) res=max(res, query_max(rs(u), l, r));
		return res;
	}
}

// Sum of values over the DFS positions of [l, r] that fall inside node u's range.
ll query_sum(int u, int l, int r){
	const node &cur = tr[u];
	if(l <= cur.l && cur.r <= r) return cur.sum;
	const int mid = (cur.l + cur.r) >> 1;
	ll total = 0;
	if(l <= mid) total += query_sum(ls(u), l, r);
	if(mid < r) total += query_sum(rs(u), l, r);
	return total;
}

// Set the weight of tree node u to k by updating its DFS position in the segment tree.
void update_point(int u, int k){
	const int pos = id[u];
	update(1, pos, k);
}

// Sum of node weights on the tree path u..v (both endpoints included).
// While the endpoints lie on different heavy chains, take the chain segment of
// the endpoint whose chain head is deeper and jump above that head. Once both
// endpoints share a chain, add the contiguous segment between them.
ll query_path_sum(int u, int v){
	ll total = 0;
	for(; top[u] != top[v]; u = fa[top[u]]){
		if(dep[top[u]] < dep[top[v]]) swap(u, v);
		total += query_sum(1, id[top[u]], id[u]);
	}
	if(dep[u] > dep[v]) swap(u, v);
	total += query_sum(1, id[u], id[v]);
	return total;
}

// Maximum node weight on the tree path u..v (both endpoints included).
// Same chain-climbing scheme as the path sum, folding results with max.
int query_path_max(int u, int v){
	int best = -INF;
	for(; top[u] != top[v]; u = fa[top[u]]){
		if(dep[top[u]] < dep[top[v]]) swap(u, v);
		best = max(best, query_max(1, id[top[u]], id[u]));
	}
	if(dep[u] > dep[v]) swap(u, v);
	best = max(best, query_max(1, id[u], id[v]));
	return best;
}

int main(){
	memset(h, -1, sizeof h);
	int n; cin>>n;
	int ed=n-1;
	while(ed--){
		int u, v; cin>>u>>v;
		add(u, v), add(v, u);
	}
	
	for(int i=1;i<=n;i++) cin>>w[i];
	
	dfs1(1, -1, 1), dfs2(1, 1);
	build(1, 1, n);
		
	int q; cin>>q;
	while(q--){
		string op; cin>>op;
		int u; cin>>u;
		if(op=="CHANGE"){
			int k; cin>>k;
			update_point(u, k);
		}else if(op=="QMAX"){
			int v; cin>>v;
			cout<<query_path_max(u, v)<<endl;
		}else{
			int v; cin>>v;
			cout<<query_path_sum(u, v)<<endl;
		}
	}
    return 0;
}
2021/4/17 21:55
加载中...