倍增:95pt
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+10;
int n,m;
int h[N],nx[N<<1],b[N<<1],v[N<<1],tt;
int f[N][21],w[N],d[N],c[N],dis[N],l=0,r=3e8+1,mid;
struct jgt
{
int x,y,jl,fu;
}q[N];
void zj(int x,int y,int z)
{
b[++tt]=y;
v[tt]=z;
nx[tt]=h[x];
h[x]=tt;
}
void csh(int x)
{
for(int i=h[x];i;i=nx[i])
{
int y=b[i],z=v[i];
if(f[x][0]==y)
continue;
f[y][0]=x;
w[y]=z;
d[y]=d[x]+1;
dis[y]=dis[x]+z;
csh(y);
}
}
int lca(int x,int y)
{
if(d[x]<d[y])
swap(x,y);
for(int i=20;i>=0;i--)
if(d[f[x][i]]>=d[y])
x=f[x][i];
if(x==y)
return x;
for(int i=20;i>=0;i--)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
void dfs(int x)
{
for(int i=h[x];i;i=nx[i])
{
int y=b[i];
if(f[x][0]==y)
continue;
dfs(y);
c[x]+=c[y];
}
}
bool pd()
{
int shu=0,maxjl=0;
bool pds=false;
memset(c,0,sizeof(c));
for(int i=1;i<=m;i++)
{
if(q[i].jl<=mid)
continue;
maxjl=max(maxjl,q[i].jl);
++shu;
int x=q[i].x,y=q[i].y,fa=q[i].fu;
++c[x];
++c[y];
c[fa]-=2;
}
if(shu==0)
return true;
dfs(1);
int bian=0;
for(int i=1;i<=n;i++)
if(c[i]==shu)
{
pds=true;
if(w[bian]<w[i])
bian=i;
}
if(!pds||maxjl-w[bian]>mid)
return false;
return true;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
zj(x,y,z);zj(y,x,z);
}
csh(1);
for(int j=1;j<=20;j++)
for(int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1];
for(int i=1;i<=m;i++)
{
int a,b;
scanf("%d%d",&a,&b);
int fa=lca(a,b);
q[i].x=a;
q[i].y=b;
q[i].fu=fa;
q[i].jl=dis[a]+dis[b]-2*dis[fa];
}
while(l<r)
{
mid=(l+r)>>1;
if(pd())
r=mid;
else
l=mid+1;
}
printf("%d",l);
return 0;
}
tarjan:80pt
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+10;
int n,m;
int h[N],nx[N<<1],b[N<<1],v[N<<1],tt;
int f[N],w[N],d[N],c[N],dis[N],l=0,r=3e8+1,mid;
bool pd[N];
vector<pair<int,int> > wt[N];
struct jgt
{
int x,y,jl,fu;
}q[N];
void zj(int x,int y,int z)
{
b[++tt]=y;
v[tt]=z;
nx[tt]=h[x];
h[x]=tt;
}
void csh(int x,int fa)
{
for(int i=h[x];i;i=nx[i])
{
int y=b[i],z=v[i];
if(y==fa)
continue;
w[y]=z;
dis[y]=dis[x]+z;
csh(y,x);
}
}
int cx(int x)
{
if(x==f[x])
return x;
return cx(f[x]);
}
void tarjan(int x)
{
pd[x]=1;f[x]=x;
for(int i=h[x];i;i=nx[i])
{
int y=b[i];
if(pd[y])
continue;
tarjan(y);
f[y]=x;
}
int len=wt[x].size();
for(int i=0;i<len;i++)
{
int y=wt[x][i].first,id=wt[x][i].second;
if(!pd[y]||q[id].jl)
continue;
q[id].fu=cx(y);
q[id].x=x;
q[id].y=y;
q[id].jl=dis[x]+dis[y]-2*dis[q[id].fu];
}
}
void dfs(int x,int fa)
{
for(int i=h[x];i;i=nx[i])
{
int y=b[i];
if(y==fa)
continue;
dfs(y,x);
c[x]+=c[y];
}
}
bool pdch()
{
int shu=0,maxjl=0;
bool pds=false;
memset(c,0,sizeof(c));
for(int i=1;i<=m;i++)
{
if(q[i].jl<=mid)
continue;
maxjl=max(maxjl,q[i].jl);
++shu;
int x=q[i].x,y=q[i].y,fa=q[i].fu;
++c[x];
++c[y];
c[fa]-=2;
}
if(shu==0)
return true;
dfs(1,0);
int bian=0;
for(int i=1;i<=n;i++)
if(c[i]==shu)
{
pds=true;
if(w[bian]<w[i])
bian=i;
}
if(!pds||maxjl-w[bian]>mid)
return false;
return true;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
zj(x,y,z);zj(y,x,z);
}
csh(1,0);
for(int i=1;i<=m;i++)
{
int a,b;
scanf("%d%d",&a,&b);
wt[a].push_back({b,i});
wt[b].push_back({a,i});
}
tarjan(1);
while(l<r)
{
mid=(l+r)>>1;
if(pdch())
r=mid;
else
l=mid+1;
}
printf("%d",l);
return 0;
}