如题：用 LCA 写的代码有两个测试点超时（TLE），已经尽力卡常，然而并没有什么用，请求支援。
#include<bits/stdc++.h>
using namespace std;
// ---- Global state (zero-initialised as globals; nodes are used 1..n) ----
int n;                 // number of tree nodes
int m;                 // number of plans (node pairs)
int s;                 // NOTE(review): not referenced in the visible code
int dp[300005][20];    // dp[v][k]: 2^k-th ancestor of v (0 once past the root)
int d[300005];         // depth; before() sets d[cur] = d[fa] + 1
int lg[300005];        // lg[i] = floor(log2(i)), with lg[0] = -1
int dis[300005];       // weighted distance from the root passed to before()
int uuu[300005];       // first endpoint of each plan (1..m)
int vvv[300005];       // second endpoint of each plan (1..m)
int diff[300005];      // edge-difference marks / per-edge coverage counts
int maxnn;             // heaviest edge covered by every over-limit plan
int maxn;              // length of the longest plan

// Adjacency-list entry: neighbour `v` reached over an edge of weight `w`.
struct node{
	int v;
	int w;
};

// nbr[u]: all edges incident to u (each undirected edge is stored twice).
vector<node> nbr[300005];
int max(int a,int b){
return (a>b)?a:b;
}
// Read the next (optionally '-'-prefixed) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. Returns 0 if EOF is reached before any digit.
inline int read()
{
	int x=0,f=1;
	// int, not char: getchar() returns EOF (-1), which must stay distinguishable
	// from every real character.
	int ch=getchar();
	// Skip separators, but stop at EOF; the old char-based loop never terminated there.
	while(ch!=EOF&&(ch<'0'||ch>'9')){
		if(ch=='-')f=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=x*10+(ch-'0');
		ch=getchar();
	}
	return x*f;
}
void before(int cur,int fa){
d[cur]=d[fa]+1;
dp[cur][0]=fa;
for(register int i=1;i<=lg[n];i++){
dp[cur][i]=dp[dp[cur][i-1]][i-1];
}
for(register int i=0;i<nbr[cur].size();i++){
int son=nbr[cur][i].v;
if(son==fa)continue;
dis[son]=dis[cur]+nbr[cur][i].w;
before(son,cur);
}
return;
}
// Lowest common ancestor of x and y by binary lifting.
// Requires before() to have filled d[] and dp[][] for the tree holding both.
// Either argument order works: the deeper node is swapped into y.
inline int LCA(int x,int y/*under*/){
	if(d[x]>d[y])swap(x,y);
	// Lift y to x's depth, greedily taking the largest powers of two first.
	for(int i=lg[n];i>=0;i--){
		if(d[y]-d[x]>=(1<<i))y=dp[y][i];
	}
	if(x==y)return x;   // x was already an ancestor of y
	// Lift both while their 2^i-th ancestors differ; they stop just below the LCA.
	for(int i=lg[n];i>=0;i--){
		if(dp[x][i]!=dp[y][i]){
			x=dp[x][i];
			y=dp[y][i];
		}
	}
	return dp[x][0];
}
inline void dfs(int cur,int fa){
for(register int i=0;i<nbr[cur].size();i++){
int son=nbr[cur][i].v;
if(son==fa)continue;
dfs(son,cur);
diff[cur]+=diff[son];
}
}
// Raise maxnn to the largest weight among edges (parent, son) in the subtree
// of cur (entered from fa) whose coverage count diff[son] equals cnt.
// With cnt = the number of over-limit plans, those are the edges shared by all
// of them. Expects dfs() to have already turned diff[] into coverage counts.
// Iterative (explicit stack) so a long chain cannot overflow the call stack.
inline void Get(int cur,int fa,int cnt){
	static vector<pair<int,int> > pending;   // (node, parent) still to scan
	pending.clear();
	pending.push_back(make_pair(cur,fa));
	while(!pending.empty()){
		const int u=pending.back().first;
		const int p=pending.back().second;
		pending.pop_back();
		for(size_t i=0;i<nbr[u].size();i++){
			const int son=nbr[u][i].v;
			if(son==p)continue;
			if(diff[son]==cnt){
				maxnn=max(maxnn,nbr[u][i].w);
			}
			pending.push_back(make_pair(son,u));
		}
	}
}
// Reset the difference marks diff[1..n] before a new check().
// (Replaces a `register` loop; `register` is ill-formed since C++17.)
inline void clean(){
	fill(diff+1,diff+n+1,0);
}
// Binary-search predicate: can every plan finish within time x if the weight
// of one edge is reduced to 0? Plans longer than x must all share that edge,
// and removing it must bring the longest plan (maxn) down to at most x.
inline bool check(int x){
	// Every plan already fits; no edge needs to be removed.
	if(maxn<=x)return 1;
	// Each plan's LCA and length never change between binary-search steps, so
	// compute them once instead of running m LCA queries on every call. The
	// per-call recomputation was the main time cost.
	// NOTE(review): assumes before() has run and uuu/vvv/m are final by the
	// first call, as in main().
	static vector<int> planLca,planLen;
	if((int)planLen.size()!=m+1){
		planLca.assign(m+1,0);
		planLen.assign(m+1,0);
		for(int i=1;i<=m;i++){
			const int l=LCA(uuu[i],vvv[i]);
			planLca[i]=l;
			planLen[i]=dis[uuu[i]]+dis[vvv[i]]-(dis[l]<<1);
		}
	}
	clean();
	int cnt=0;   // number of plans longer than x
	for(int i=1;i<=m;i++){
		if(planLen[i]>x){
			cnt++;
			// Edge difference: +1 at both endpoints, -2 at the LCA.
			diff[uuu[i]]++;
			diff[vvv[i]]++;
			diff[planLca[i]]-=2;
		}
	}
	dfs(1,0);    // diff[v] -> number of long plans using edge (v, parent(v))
	maxnn=0;
	Get(1,0,cnt);
	return maxn-maxnn<=x;
}
int main(){
n=read();m=read();
lg[0]=-1;
for(register int i=1;i<n;i++){
int u=read(),v=read(),w=read();
node bag1={v,w},bag2={u,w};
nbr[u].push_back(bag1);
nbr[v].push_back(bag2);
lg[i]=lg[i>>1]+1;
}
lg[n]=lg[n>>1]+1;
before(1,0);
for(register int i=1;i<=m;i++){
uuu[i]=read();
vvv[i]=read();
}
for(int i=1;i<=m;i++){
int l=LCA(uuu[i],vvv[i]);
int len=dis[uuu[i]]+dis[vvv[i]]-(dis[l]<<1);
maxn=max(maxn,len);
}
int lt=-1,rt=maxn+1;
while(lt+1<rt){
int mid=lt+rt>>1;
if(check(mid))rt=mid;
else lt=mid;
}
cout<<rt;
return 0;
}