萌新 WA on test 21 求助
查看原帖
萌新 WA on test 21 求助
98618
Provicy楼主2021/2/25 13:49

RT。有人在这个测试点上 WA 过吗？或者说一下这题的易错点也行 /kel

// Competitive-programming preamble.
#include <bits/stdc++.h>
#pragma GCC optimize(3)
// Every `int` below expands to `long long` (hence `signed main` further down);
// presumably because path values can exceed 32 bits.
#define int long long
// NOTE(review): the `register` specifier was removed in C++17 (clang rejects it
// there, gcc warns); plain locals are equivalent.
#define ri register
// Shorthand macros. None of them is referenced by the code below; note that
// `is`/`es` silently rewrite any later identifier with those names.
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
// N bounds the per-vertex arrays (the edge array holds 2*N entries);
// INF is a sentinel for the min/max searches (maxx = INF, cs = -INF).
using namespace std; const int N=300010, INF=1e9;
// Reads one integer from stdin. Leading non-digit characters are skipped; any
// '-' among them makes the result negative. At end of input the function
// returns what it has accumulated (0 if no digit was seen) instead of looping
// forever, which the previous char-based version did on truncated input.
inline int read()
{
	int s = 0, w = 1;
	int ch = getchar();  // keep getchar's int result so EOF stays distinguishable
	while (ch != EOF && (ch < '0' || ch > '9'))
	{
		if (ch == '-') w = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9')
	{
		s = s * 10 + (ch - '0');
		ch = getchar();
	}
	return s * w;
}
// n: vertex count; a[]: vertex values; siz[]: subtree sizes filled by FindRoot;
// root/maxx: best centroid candidate and its largest remaining piece;
// book[]: 1 once a vertex has been used as a centroid (blocks traversal);
// d[]: depth below the current centroid (the centroid itself has depth 1);
// Ans: best path value found so far (printed by main).
int n,a[N],siz[N],root,maxx,book[N],d[N],Ans;
// Forward-star adjacency list: head[u] is the newest edge out of u (0 = none),
// e[i].nxt chains to the previous one, e[i].to is the edge's target.
int head[N],maxE; struct Edge { int nxt,to; }e[N<<1];
// Appends the directed edge u -> v at the front of u's list.
inline void Add(int u,int v)
{
	const int id = ++maxE;
	e[id].to = v;
	e[id].nxt = head[u];
	head[u] = id;
}
// One recorded endpoint: a vertex of the current centroid's component that has
// no unvisited child (DFS only records those).
struct Node
{
	// v1: sum of prefix sums along centroid -> vertex (DFS adds the running
	//     total at every step).
	// v2: sum over the non-centroid path vertices w of a[w] * (d[w] - 1).
	int v1,v2;
	// sum: sum of a[] on the path excluding the centroid; len: depth d of the
	// vertex (centroid = 1); fa: the centroid's child whose subtree holds it.
	int sum,len,fa;
}g[N]; int cnt,md; // cnt: endpoints recorded for the current centroid; md: max depth seen
// Line y = k * x + b; f[i] belongs to g[i] (k = sum, b = v2). f[0] is never
// written, so it stays the zero line that Ask's "empty" result maps to.
struct Line
{
	int k,b;
}f[N];
// Li Chao tree nodes (heap-indexed): ans[x] is an index into f[], 0 = empty.
int ans[N<<2];
#define lc (x<<1)
#define rc (x<<1|1)
// Empties the Li Chao tree: every node of the segment tree over [l, r] rooted
// at heap index x forgets its stored line (0 = no line).
void Renew(int x,int l,int r)
{
	ans[x] = 0;
	if (l != r)
	{
		const int half = (l + r) / 2;
		Renew(x * 2, l, half);
		Renew(x * 2 + 1, half + 1, r);
	}
}
// Li Chao insertion of line f[g] (g is an index into f[], not the Node array)
// into node x covering [l, r], restricted to positions [u, v]. Each node keeps
// the line that wins at its midpoint; the loser is pushed into the only half
// where it can still win.
void UpDate(int u,int v,int l,int r,int x,int g)
{
	const int mid = (l + r) / 2;
	if (l < u || r > v)
	{
		// Node only partially covered: descend into the overlapping halves.
		if (u <= mid) UpDate(u, v, l, mid, lc, g);
		if (v > mid) UpDate(u, v, mid + 1, r, rc, g);
		return;
	}
	const int cur = ans[x];
	if (!cur)
	{
		ans[x] = g;
		return;
	}
	const int curY = f[cur].b + f[cur].k * mid;
	const int newY = f[g].b + f[g].k * mid;
	const bool newWinsMid = newY > curY;
	if (l == r)
	{
		if (newWinsMid) ans[x] = g;
		return;
	}
	if (f[g].k == f[cur].k)
	{
		// Parallel lines: the larger intercept dominates everywhere.
		if (f[g].b > f[cur].b) ans[x] = g;
		return;
	}
	const bool newSteeper = f[g].k > f[cur].k;
	if (newWinsMid)
	{
		// The old line can only still win on the side where it is higher:
		// left of mid if the new line is steeper, right of mid otherwise.
		ans[x] = g;
		if (newSteeper) UpDate(u, v, l, mid, lc, cur);
		else UpDate(u, v, mid + 1, r, rc, cur);
	}
	else
	{
		if (newSteeper) UpDate(u, v, mid + 1, r, rc, g);
		else UpDate(u, v, l, mid, lc, g);
	}
}
// Walks from node x (covering [l, r]) down to position pos and returns the
// index of the stored line with the largest value at pos (0 if no node on the
// path holds a line). Equal values resolve to the smaller index.
inline int Ask(int pos,int l,int r,int x)
{
	const int here = ans[x];
	if (l == r) return here;
	const int mid = (l + r) / 2;
	const int below = pos <= mid ? Ask(pos, l, mid, lc) : Ask(pos, mid + 1, r, rc);
	if (!here || !below) return here ? here : below;
	const int hereY = f[here].b + f[here].k * pos;
	const int belowY = f[below].b + f[below].k * pos;
	if (hereY == belowY) return min(below, here);
	return belowY > hereY ? below : here;
}
#undef lc
#undef rc
void FindRoot(int x,int fa,int S)
{
	siz[x]=1;
	int cs=-INF;
	for(ri int i=head[x];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(v==fa || book[v]) continue;
		FindRoot(v,x,S);
		siz[x]+=siz[v];
		cs=max(cs,siz[v]);
	}
	cs=max(cs,S-siz[x]);
	if(cs<maxx) maxx=cs, root=x;
}
void DFS(int x,int fa,int v1,int v2,int sum,int gg)
{
	d[x]=d[fa]+1, md=max(md,d[x]);
	int sz=0;
	for(ri int i=head[x];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(v==fa || book[v]) continue;
		DFS(v,x,v1+sum+a[v],v2+d[x]*a[v],sum+a[v],gg);
		sz++;
	}
	if(sz) return;
	g[++cnt]=(Node){v1,v2,sum-a[root],d[x],gg};
	f[cnt]=(Line){g[cnt].sum,g[cnt].v2};
}
inline void GetAns(int x)
{
	cnt=md=0, d[x]=1;
	for(ri int i=head[x];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(book[v]) continue;
		DFS(v,x,a[x]*2+a[v],a[v],a[x]+a[v],v);
	}
	if(!cnt)
	{
		Ans=max(Ans,a[x]);
		return;
	}
	g[cnt+1]=(Node){0,0,0,0,0};
	f[cnt+1]=(Line){0,0};
	for(ri int i=1;i<=cnt;i++) Ans=max(Ans,max(g[i].v1,g[i].v2+a[root]));
	Renew(1,1,md);
	for(ri int i=1;i<=cnt;)
	{
		int now=i;
		while(now<=cnt && g[now].fa==g[i].fa)
		{
			int gg=Ask(g[now].len,1,md,1);
			Ans=max(Ans,f[gg].b+f[gg].k*g[now].len+g[now].v1);
			now++;
		}
		now=i;
		while(now<=cnt && g[now].fa==g[i].fa) UpDate(1,n,1,n,1,now), now++;
		i=now;
	}
	Renew(1,1,md);
	for(ri int i=cnt;i;)
	{
		int now=i;
		while(now && g[now].fa==g[i].fa)
		{
			int gg=Ask(g[now].len,1,md,1);
			Ans=max(Ans,f[gg].b+f[gg].k*g[now].len+g[now].v1);
			now--;
		}
		now=i;
		while(now && g[now].fa==g[i].fa) UpDate(1,n,1,n,1,now), now--;
		i=now;
	}
}
void Solve(int x)
{
	book[x]=1;
	GetAns(x);
	for(ri int i=head[x];i;i=e[i].nxt)
	{
		int v=e[i].to;
		if(book[v]) continue;
		maxx=INF, root=0;
		FindRoot(v,x,siz[v]>siz[x]?n-siz[x]:siz[v]);
		Solve(root);
	}
}
// Reads the tree (n, then n-1 edges, then the n vertex values), runs the
// centroid decomposition from the centroid of the whole tree, and prints the
// best path value.
signed main()
{
	n = read();
	for (int edge = 1; edge < n; ++edge)
	{
		const int u = read();
		const int v = read();
		Add(u, v);
		Add(v, u);
	}
	for (int vertex = 1; vertex <= n; ++vertex) a[vertex] = read();
	maxx = INF;
	root = 0;
	FindRoot(1, 0, n);
	Solve(root);
	printf("%lld\n", Ans);
	return 0;
}
2021/2/25 13:49
加载中...