求助20分
查看原帖
求助20分
756336
李承轩楼主2022/12/10 10:42
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5e4+10;
int n,m,t,l,r,a[N],atot,btot,ctot;
int head[N],nxt[N<<1],to[N<<1],w[N<<1],idx;
int d[N],f[N][20],dis[N][20];
bool flag=1,b[N],tag[N];
int ans,tim[N],ned[N];
pair<int,int> h[N];
queue<int> q;
void add(int x,int y,int z)
{
	to[++idx]=y,w[idx]=z,nxt[idx]=head[x],head[x]=idx;
}
void bfs()
{
	q.push(1);
	d[1]=1;
	while(q.size())
	{
		int x=q.front();
		q.pop();
		for(int i=head[x];i;i=nxt[i])
		{
			int y=to[i];
			if(d[y])continue;
			d[y]=d[x]+1;
			f[y][0]=x,dis[y][0]=w[i];
			for(int j=1;j<=t;j++)
			{
				f[y][j]=f[f[y][j-1]][j-1];
				dis[y][j]=dis[y][j-1]+dis[f[y][j-1]][j-1];
			}
			q.push(y);
		}
	}
}
bool dfs(int x)
{
	bool leafson=0;
	if(b[x])return 1;
	for(int i=head[x];i;i=nxt[i])
	{
		int y=to[i];
		if(d[y]<d[x])continue;
		leafson=1;
		if(!dfs(y))return 0;
	}
	if(!leafson)return 0;
	return 1;
}
bool check(int mid)
{
	memset(b,0,sizeof b);
	memset(tim,0,sizeof tim);
	memset(ned,0,sizeof ned);
	memset(h,0,sizeof(h));
	memset(tag,0,sizeof(tag));
	atot=0,btot=0,ctot=0;
	for(int i=1;i<=m;i++)
	{
		int x=a[i],cnt=0;
		for(int j=t;j>=0;j--)
			if(f[x][i]>1&&cnt+dis[x][j]<=mid)
				cnt+=dis[x][j],x=f[x][j];
		if(f[x][0]==1&&cnt+dis[x][0]<=mid)
			h[++ctot]=make_pair(mid-cnt-dis[x][0],x);
		else b[x]=1;
	}
	for(int i=head[1];i;i=nxt[i])
		if(!dfs(to[i]))
			tag[to[i]]=1;
	sort(h+1,h+ctot+1);
	for(int i=1;i<=ctot;i++)
	{
		if(tag[h[i].second]&&h[i].first<dis[h[i].second][0])
			tag[h[i].second]=0;
		else 
			tim[++atot]=h[i].first;
	}
	for(int i=head[1];i;i=nxt[i])
		if(tag[to[i]])
			ned[++btot]=dis[to[i]][0];
	if(atot<btot)return 0;
	sort(tim+1,tim+atot+1),sort(ned+1,ned+btot+1);
	int i=1,j=1;
	while(i<=btot&&j<=atot)
	{
		if(tim[j]>=ned[i])i++;
		j++;
	}
	if(i>btot)return 1;
	return 0;
}
signed main()
{
	scanf("%lld",&n);
	t=log2(n)+1;
	for(int i=1,x,y,z;i<n;i++)
	{
		scanf("%lld%lld%lld",&x,&y,&z);
		add(x,y,z),add(y,x,z);
		r+=z;
	}
	bfs();
	scanf("%lld",&m);
	for(int i=1;i<=m;i++)
		scanf("%lld",&a[i]);
	while(l<=r)
	{
		int mid=(l+r)>>1;
		if(check(mid))r=mid-1,flag=0;
		else l=mid+1;
	}
	if(!flag)printf("-1");
	else cout<<l;
	return 0;
}
2022/12/10 10:42
加载中...