WA 0
查看原帖
WA 0
91851
Evan704（楼主） 2018/12/16 16:02

RT（如题：提交结果 WA，0 分）

#include<climits>
#include<cstdio>
#include<iostream>
using namespace std;
const int N=3e4+5;  // bound for node-indexed arrays; node ids must be < N
int n;              // number of tree nodes
int m;              // number of operations to process
int cnt;            // shared counter: edge count while reading, later the DFS timestamp
int head[N];        // head[u]: index in e[] of u's most recently added edge (0 = none)
int a[N];           // a[u]: weight of node u
int sum[N<<2];      // segment tree over DFS positions: range sums
int big[N<<2];      // segment tree over DFS positions: range maxima
int d[N];           // d[u]: depth of u (root has depth 0)
int f[N];           // f[u]: parent of u (0 for the root)
int son[N];         // son[u]: heavy child of u (0 if leaf)
int size[N];        // size[u]: number of nodes in u's subtree
int id[N];          // id[u]: DFS position of u
int rk[N];          // rk[p]: node at DFS position p (inverse of id)
int top[N];         // top[u]: topmost node of u's heavy chain
struct edge{
	int v;    // endpoint of this directed edge
	int nex;  // index of the next edge in the same adjacency list (0 = end)
}e[N<<2];
void dfs(int u){
	size[u]=1;
	for(int i=head[u];i;i=e[i].nex){
		int v=e[i].v;
		if(v==f[u])continue;
		f[v]=u;
		d[v]=d[u]+1;
		dfs(v);
		size[u]+=size[v];
		if(size[son[u]]<size[v])son[u]=v;
	}
}
// Second heavy-light pass: assigns each node its DFS position (id / rk)
// and the top of its heavy chain. The heavy child is visited first, so
// every heavy chain occupies a contiguous run of positions; every other
// child starts a fresh chain headed by itself.
void dfs2(int u,int t){
	top[u]=t;
	id[u]=++cnt;
	rk[cnt]=u;
	if(son[u]!=0)dfs2(son[u],t);
	for(int j=head[u];j!=0;j=e[j].nex){
		const int w=e[j].v;
		if(w!=f[u]&&w!=son[u])dfs2(w,w);
	}
}
// Range-maximum query on the segment tree `big`.
// Node `id` covers positions [l, r]; returns the maximum leaf value over
// [nl, nr] ∩ [l, r]. Assumes the query range overlaps [l, r], so at least
// one recursive branch runs and INT_MIN is never returned for a real query.
int found1(int nl,int nr,int l,int r,int id){
	if(nl<=l&&r<=nr)return big[id];
	int mid=(l+r)>>1;
	// Seed with INT_MIN, not 0: weights are read with %d and may be negative,
	// and a 0 seed would report 0 as the maximum of an all-negative range.
	int res=INT_MIN;
	if(nl<=mid)res=max(res,found1(nl,nr,l,mid,id<<1));
	if(nr>mid)res=max(res,found1(nl,nr,mid+1,r,id<<1|1));
	return res;
}
// Range-sum query on the segment tree `sum`: node `id` covers [l, r];
// returns the sum of leaf values over [nl, nr] ∩ [l, r] (0 where they
// do not overlap a child's half).
int found2(int nl,int nr,int l,int r,int id){
	if(nl<=l&&r<=nr)return sum[id];
	const int mid=(l+r)>>1;
	const int lhs=(nl<=mid)?found2(nl,nr,l,mid,id<<1):0;
	const int rhs=(nr>mid)?found2(nl,nr,mid+1,r,id<<1|1):0;
	return lhs+rhs;
}
// Maximum node weight on the tree path x..y (both endpoints inclusive),
// via heavy-light decomposition: while x and y are on different chains,
// take the endpoint whose chain top is deeper, query its chain segment
// [id[top[x]], id[x]], and jump to the parent of that chain's top.
int ask1(int x,int y){
	// Seed with INT_MIN, not 0: every weight on the path may be negative.
	int res=INT_MIN;
	while(top[x]!=top[y]){
		if(d[top[x]]<d[top[y]])swap(x,y);
		res=max(res,found1(id[top[x]],id[x],1,n,1));
		x=f[top[x]];
	}
	// Same chain now: the shallower endpoint has the smaller DFS position.
	if(d[x]>d[y])swap(x,y);
	return max(res,found1(id[x],id[y],1,n,1));
}
// Sum of node weights on the tree path x..y (both endpoints inclusive),
// via heavy-light decomposition: repeatedly lift the endpoint whose chain
// top is deeper, adding that chain segment, until both share a chain.
int ask2(int x,int y){
	int total=0;
	for(;top[x]!=top[y];x=f[top[x]]){
		if(d[top[x]]<d[top[y]])swap(x,y);
		total+=found2(id[top[x]],id[x],1,n,1);
	}
	// Same chain: order the endpoints so the shallower one comes first.
	int lo=x,hi=y;
	if(d[lo]>d[hi])swap(lo,hi);
	return total+found2(id[lo],id[hi],1,n,1);
}
// Adds the directed edge u -> v by pushing it onto the front of u's
// linked adjacency list; edge indices start at 1 so 0 marks "no edge".
void add(int u,int v){
	++cnt;
	e[cnt].v=v;
	e[cnt].nex=head[u];
	head[u]=cnt;
}
// Builds the segment tree over DFS positions [l, r] rooted at node `id`.
// The leaf at position l holds a[rk[l]], the weight of the node at that
// position; internal nodes store the max (big) and sum (sum) of children.
void build(int l,int r,int id){
	if(l!=r){
		const int mid=(l+r)>>1;
		const int lc=id<<1;
		const int rc=id<<1|1;
		build(l,mid,lc);
		build(mid+1,r,rc);
		big[id]=max(big[lc],big[rc]);
		sum[id]=sum[lc]+sum[rc];
		return;
	}
	big[id]=sum[id]=a[rk[l]];
}
// Point assignment on the segment tree: sets the leaf at DFS position x to k
// (node `id` covers [l, r]) and refreshes max/sum on the way back up.
void add(int x,int k,int l,int r,int id){
	if(l==r){
		big[id]=k;
		sum[id]=k;
		return;
	}
	const int mid=(l+r)>>1;
	const int lc=id<<1;
	const int rc=id<<1|1;
	if(x>mid)add(x,k,mid+1,r,rc);
	else add(x,k,l,mid,lc);
	big[id]=max(big[lc],big[rc]);
	sum[id]=sum[lc]+sum[rc];
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<n;i++){
		int u,v;
		scanf("%d%d",&u,&v);
		add(u,v);
		add(v,u);
	}
	for(int i=1;i<=n;i++)scanf("%d",&a[i]);
	cnt=0;dfs(1);dfs2(1,1);
	build(1,n,1);
	scanf("%d",&m);
	char ch[10];
	int x,y;
	while(m--){
		scanf("%s%d%d",ch,&x,&y);
		if(ch[1]=='M')printf("%d\n",ask1(x,y));
		else if(ch[1]=='S')printf("%d\n",ask2(x,y));
		else add(x,y,1,n,1);
	}
}
2018/12/16 16:02
加载中...