咕,基本就是看了题解然后打出来的
改了比较久,但是RE,WA很多
(因为找不出错误改的尽量和题解一样
(但还是错的
Code:
#include<bits/stdc++.h>
using namespace std;
long long son[1000005],len[1000005],nxt[1000005],fir[1000005],ned[1000005],army[1000005],n,m,r=0,deep[1000005],num[1000005][20],ending[1000005],dis[1000005][20],oto=0,tot=0,oot=0,spe=0,ans;
struct data{long long last,p;}fre[1000005];
bool stay[1000005],need[1000005],f=false;
bool cmp(data a,data b){return a.last<b.last;}
long long read(){
long long sum=0,e=1;char c=getchar();
while(!(c>='0'&&c<='9')){if(c=='-')e=-1;c=getchar();}
while(c>='0'&&c<='9'){sum=sum*10+c-'0';c=getchar();}
return sum*e;
}
void add(long long x,long long y,long long z){son[++r]=y;len[r]=z;nxt[r]=fir[x];fir[x]=r;}
void dfson(long long x){
for(long long i=fir[x];i;i=nxt[i]){
if(deep[son[i]])continue;
deep[son[i]]=deep[x]+1;
num[son[i]][0]=x,dis[son[i]][0]=len[i];
for(long long j=1;j<=log2(n)+1;j++){
num[son[i]][j]=num[num[son[i]][j-1]][j-1];
dis[son[i]][j]=dis[son[i]][j-1]+dis[num[son[i]][j-1]][j-1];
}
dfson(son[i]);
}
}
bool dfs(long long x){
if(stay[x])return true;
bool flag=false;
for(long long i=fir[x];i;i=nxt[i]){
if(deep[son[i]]<deep[x])continue;
flag=true;
if(!dfs(son[i]))return false;
}
if(!flag)return false;
return true;
}
bool check(long long timline){
memset(stay,false,sizeof(stay));
memset(need,false,sizeof(need));
for(long long i=1;i<=tot;i++)fre[i].last=fre[i].p=0;
memset(ned,0,oot+1);
memset(ending,0,oto+1);
tot=oto=oot=0;
for(long long i=1;i<=m;i++){
long long x=army[i],sum=0;
for(long long j=log2(n)+1;j>=0;j--){
if(num[x][j]>1&&sum+dis[x][j]<=timline)sum+=dis[x][j],x=num[x][j];
if(num[x][0]==1&&sum+dis[x][0]<=timline)fre[++tot].last=timline-sum-dis[x][0],fre[tot].p=x;
else stay[x]=true;
}
}
for(long long i=fir[1];i;i=nxt[i])if(!dfs(son[i]))need[son[i]]=true;
sort(fre+1,fre+tot+1,cmp);
for(long long i=1;i<=tot;i++)
if(need[fre[i].p]&&fre[i].last<dis[fre[i].last][0])need[fre[i].p]=false;
else ending[++oto]=fre[i].last;
for(long long i=fir[1];i;i=nxt[i])if(need[son[i]])ned[++oot]=dis[son[i]][0];
if(oot>oto)return false;
sort(ending+1,ending+oto+1);sort(ned+1,ned+oot+1);
long long qe=1,qn=1;
while(qe<=oto&&qn<=oot)
if(ending[qe]>=ned[qn])qe++,qn++;
else qe++;
if(qn>oot)return true;
return false;
}
void find(){
long long L=1,R=spe,M;
while(L<=R){
M=(L+R)/2;
if(check(M))R=M-1,ans=M,f=true;
else L=M+1;
}
}
int main(){
n=read();
for(long long i=1;i<n;i++){
long long a=read(),b=read(),c=read();
add(a,b,c);add(b,a,c);spe+=c;
}
m=read();
for(long long i=1;i<=m;i++)army[i]=read();
// bfs();
deep[1]=1;
dfson(1);find();
if(!f)cout<<-1;
else cout<<ans;
return 0;
}