我用的思路是虚树+点分治。
这是一份 TLE 80pts 的代码:
#include<cstdio>
#include<algorithm>
inline void read(int &x){
x=0;
char c=getchar();
while(c<'0'||c>'9')
c=getchar();
while(c>='0'&&c<='9')
x=(x<<3)+(x<<1)+(c&15),c=getchar();
}
const int MAXN=1e6+7;
struct Edge{int v,nxt;}e[MAXN<<1];
int n,Q,dfn[MAXN],tim,dep[MAXN],fa[MAXN][21],pc,p[MAXN],h[MAXN],idx=1,stk[MAXN],tp,gvt,sizgvt=1e9,blk,siz[MAXN],tsiz,tdl,tdr,usd[MAXN],tu,ansl,ansr;
long long tdis,ans;
bool isq[MAXN],vis[MAXN];
inline void add(int u,int v){
e[++idx]={v,h[u]},h[u]=idx;
}
void DFS_lca(int u){
dep[u]=dep[fa[u][0]]+1;
for(int i=1;i<21;++i)
fa[u][i]=fa[fa[u][i-1]][i-1];
dfn[u]=++tim;
for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
if(!dfn[v])
fa[v][0]=u,DFS_lca(v);
}
inline int LCA(int u,int v){
if(dep[u]<dep[v])
std::swap(u,v);
for(int i=20;~i;--i)
if(dep[fa[u][i]]>=dep[v])
u=fa[u][i];
if(u==v)
return v;
for(int i=20;~i;--i)
if(fa[u][i]!=fa[v][i])
u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
void DFS_gvt(int u,int ff){
siz[u]=1;
int mx=0;
for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
if(v!=ff&&!vis[v])
DFS_gvt(v,u),siz[u]+=siz[v],mx=std::max(mx,siz[v]);
mx=std::max(mx,blk-siz[u]);
if(mx<sizgvt)
gvt=u,sizgvt=mx;
}
void DFS_dis(int u,int ff,int ds){
if(isq[u])
++tsiz,tdis+=ds,tdl=std::min(tdl,ds),tdr=std::max(tdr,ds);
for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
if(v!=ff&&!vis[v])
DFS_dis(v,u,ds+abs(dep[u]-dep[v]));
}
void DFS(int u){
vis[u]=true,usd[++tu]=u;
int gsiz=isq[u],tblk=blk,gdl=isq[u]?0:1e9,gdr=0;
long long gdis=0;
for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
if(!vis[v])
tsiz=0,tdl=1e9,tdr=0,tdis=0,
DFS_dis(v,u,abs(dep[u]-dep[v])),
ans+=tsiz*gdis+gsiz*tdis,ansl=std::min(ansl,gdl+tdl),ansr=std::max(ansr,gdr+tdr),
gsiz+=tsiz,gdis+=tdis,gdl=std::min(gdl,tdl),gdr=std::max(gdr,tdr);
for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
if(!vis[v])
gvt=0,sizgvt=1e9,blk=siz[v]>siz[u]?tblk-siz[u]:siz[v],DFS_gvt(v,u),DFS(gvt);
}
int main(){
read(n);
for(int i=1,u,v;i<n;++i)
read(u),read(v),add(u,v),add(v,u);
DFS_lca(1),idx=1;
for(int i=1;i<=n;++i)
h[i]=0;
read(Q);
while(Q--){
read(pc),ans=0,ansl=1e9,ansr=-1;
for(int i=1;i<=pc;++i)
read(p[i]),isq[p[i]]=true;
std::sort(p+1,p+pc+1,[&](const int x,const int y){return dfn[x]<dfn[y];}),stk[tp=1]=p[1];
for(int i=2,lc;i<=pc;++i){
lc=LCA(p[i],stk[tp]);
while(dep[lc]<dep[stk[tp-1]])
add(stk[tp-1],stk[tp]),add(stk[tp],stk[tp-1]),--tp;
if(lc!=stk[tp]){
add(lc,stk[tp]),add(stk[tp],lc);
if(lc!=stk[tp-1])
stk[tp]=lc;
else
--tp;
}
stk[++tp]=p[i];
}
for(int i=1;i<tp;++i)
add(stk[i],stk[i+1]),add(stk[i+1],stk[i]);
gvt=0,sizgvt=1e9,blk=pc,DFS_gvt(stk[1],0),DFS(gvt);
for(int i=1;i<=tu;++i)
isq[usd[i]]=false,vis[usd[i]]=false,h[usd[i]]=0;
idx=1,tu=0;
printf("%lld %d %d\n",ans,ansl,ansr);
}
return 0;
}
这是一份 AC 100pts 的代码:
#include<cstdio>
#include<algorithm>
inline void read(int &x){
x=0;
char c=getchar();
while(c<'0'||c>'9')
c=getchar();
while(c>='0'&&c<='9')
x=(x<<3)+(x<<1)+(c&15),c=getchar();
}
const int MAXN=1e6+7;
struct Edge{int v,nxt;}e[MAXN<<1];
int n,Q,dfn[MAXN],tim,dep[MAXN],fa[MAXN][21],pc,p[MAXN],h[MAXN],idx=1,stk[MAXN],tp,gvt,sizgvt=1e9,blk,siz[MAXN],tsiz,tdl,tdr,usd[MAXN],tu,ansl,ansr;
long long tdis,ans;
bool isq[MAXN],vis[MAXN];
inline void add(int u,int v){
e[++idx]={v,h[u]},h[u]=idx;
}
void DFS_lca(int u){
dep[u]=dep[fa[u][0]]+1;
for(int i=1;i<21;++i)
fa[u][i]=fa[fa[u][i-1]][i-1];
dfn[u]=++tim;
for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
if(!dfn[v])
fa[v][0]=u,DFS_lca(v);
}
inline int LCA(int u,int v){
if(dep[u]<dep[v])
std::swap(u,v);
for(int i=20;~i;--i)
if(dep[fa[u][i]]>=dep[v])
u=fa[u][i];
if(u==v)
return v;
for(int i=20;~i;--i)
if(fa[u][i]!=fa[v][i])
u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
void DFS_gvt(int u,int ff){
siz[u]=1;
int mx=0;
for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
if(v!=ff&&!vis[v])
DFS_gvt(v,u),siz[u]+=siz[v],mx=std::max(mx,siz[v]);
mx=std::max(mx,blk-siz[u]);
if(mx<sizgvt)
gvt=u,sizgvt=mx;
}
void DFS_dis(int u,int ff,int ds){
if(isq[u])
++tsiz,tdis+=ds,tdl=std::min(tdl,ds),tdr=std::max(tdr,ds);
for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
if(v!=ff&&!vis[v])
DFS_dis(v,u,ds+abs(dep[u]-dep[v]));
}
void DFS(int u){
vis[u]=true,usd[++tu]=u;
int gsiz=isq[u],tblk=blk,gdl=isq[u]?0:1e9,gdr=0;
long long gdis=0;
for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
if(!vis[v])
tsiz=0,tdl=1e9,tdr=0,tdis=0,
DFS_dis(v,u,abs(dep[u]-dep[v])),
ans+=tsiz*gdis+gsiz*tdis,ansl=std::min(ansl,gdl+tdl),ansr=std::max(ansr,gdr+tdr),
gsiz+=tsiz,gdis+=tdis,gdl=std::min(gdl,tdl),gdr=std::max(gdr,tdr);
for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
if(!vis[v])
gvt=0,sizgvt=1e9,blk=siz[v],DFS_gvt(v,u),DFS(gvt);
}
int main(){
read(n);
for(int i=1,u,v;i<n;++i)
read(u),read(v),add(u,v),add(v,u);
DFS_lca(1),idx=1;
for(int i=1;i<=n;++i)
h[i]=0;
read(Q);
while(Q--){
read(pc),ans=0,ansl=1e9,ansr=-1;
for(int i=1;i<=pc;++i)
read(p[i]),isq[p[i]]=true;
std::sort(p+1,p+pc+1,[&](const int x,const int y){return dfn[x]<dfn[y];}),stk[tp=1]=p[1];
for(int i=2,lc;i<=pc;++i){
lc=LCA(p[i],stk[tp]);
while(dep[lc]<dep[stk[tp-1]])
add(stk[tp-1],stk[tp]),add(stk[tp],stk[tp-1]),--tp;
if(lc!=stk[tp]){
add(lc,stk[tp]),add(stk[tp],lc);
if(lc!=stk[tp-1])
stk[tp]=lc;
else
--tp;
}
stk[++tp]=p[i];
}
for(int i=1;i<tp;++i)
add(stk[i],stk[i+1]),add(stk[i+1],stk[i]);
gvt=0,sizgvt=1e9,blk=pc,DFS_gvt(stk[1],0),DFS(gvt);
for(int i=1;i<=tu;++i)
isq[usd[i]]=false,vis[usd[i]]=false,h[usd[i]]=0;
idx=1,tu=0;
printf("%lld %d %d\n",ans,ansl,ansr);
}
return 0;
}
容易发现二者的唯一区别在于第 70 行,DFS
函数末尾,对 blk
大小的处理。
我认为第一份代码的 blk
大小处理是正确的,原因在于:DFS()
中传入的是连通块的重心,但是 DFS_gvt()
中传入的是连通块中的任意点,所以 siz[v]>siz[u]
这种情况有可能发生,这个时候就要特判处理一下。
所以是为什么呢,有没有大佬能够解释一下?