通过了,但是极度不理解(点分治),求助
查看原帖
通过了,但是极度不理解(点分治),求助
642173
KarmaticEnding楼主2025/6/24 13:31

我用的思路是虚树+点分治。

这是一份 TLE 80pts\text{TLE }80\text{pts} 的代码:

#include<cstdio>
#include<algorithm>
inline void read(int &x){
	x=0;
	char c=getchar();
	while(c<'0'||c>'9')
		c=getchar();
	while(c>='0'&&c<='9')
		x=(x<<3)+(x<<1)+(c&15),c=getchar();
}
const int MAXN=1e6+7;
struct Edge{int v,nxt;}e[MAXN<<1];
int n,Q,dfn[MAXN],tim,dep[MAXN],fa[MAXN][21],pc,p[MAXN],h[MAXN],idx=1,stk[MAXN],tp,gvt,sizgvt=1e9,blk,siz[MAXN],tsiz,tdl,tdr,usd[MAXN],tu,ansl,ansr;
long long tdis,ans;
bool isq[MAXN],vis[MAXN];
inline void add(int u,int v){
	e[++idx]={v,h[u]},h[u]=idx;
}
void DFS_lca(int u){
	dep[u]=dep[fa[u][0]]+1;
	for(int i=1;i<21;++i)
		fa[u][i]=fa[fa[u][i-1]][i-1];
	dfn[u]=++tim;
	for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
		if(!dfn[v])
			fa[v][0]=u,DFS_lca(v);
}
inline int LCA(int u,int v){
	if(dep[u]<dep[v])
		std::swap(u,v);
	for(int i=20;~i;--i)
		if(dep[fa[u][i]]>=dep[v])
			u=fa[u][i];
	if(u==v)
		return v;
	for(int i=20;~i;--i)
		if(fa[u][i]!=fa[v][i])
			u=fa[u][i],v=fa[v][i];
	return fa[u][0];
}
void DFS_gvt(int u,int ff){
	siz[u]=1;
	int mx=0;
	for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
		if(v!=ff&&!vis[v])
			DFS_gvt(v,u),siz[u]+=siz[v],mx=std::max(mx,siz[v]);
	mx=std::max(mx,blk-siz[u]);
	if(mx<sizgvt)
		gvt=u,sizgvt=mx;
}
void DFS_dis(int u,int ff,int ds){
	if(isq[u])
		++tsiz,tdis+=ds,tdl=std::min(tdl,ds),tdr=std::max(tdr,ds);
	for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
		if(v!=ff&&!vis[v])
			DFS_dis(v,u,ds+abs(dep[u]-dep[v]));
}
void DFS(int u){
	vis[u]=true,usd[++tu]=u;
	int gsiz=isq[u],tblk=blk,gdl=isq[u]?0:1e9,gdr=0;
	long long gdis=0;
	for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
		if(!vis[v])
			tsiz=0,tdl=1e9,tdr=0,tdis=0,
			DFS_dis(v,u,abs(dep[u]-dep[v])),
			ans+=tsiz*gdis+gsiz*tdis,ansl=std::min(ansl,gdl+tdl),ansr=std::max(ansr,gdr+tdr),
			gsiz+=tsiz,gdis+=tdis,gdl=std::min(gdl,tdl),gdr=std::max(gdr,tdr);
	for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
		if(!vis[v])
			gvt=0,sizgvt=1e9,blk=siz[v]>siz[u]?tblk-siz[u]:siz[v],DFS_gvt(v,u),DFS(gvt);
}
int main(){
	read(n);
	for(int i=1,u,v;i<n;++i)
		read(u),read(v),add(u,v),add(v,u);
	DFS_lca(1),idx=1;
	for(int i=1;i<=n;++i)
		h[i]=0;
	read(Q);
	while(Q--){
		read(pc),ans=0,ansl=1e9,ansr=-1;
		for(int i=1;i<=pc;++i)
			read(p[i]),isq[p[i]]=true;
		std::sort(p+1,p+pc+1,[&](const int x,const int y){return dfn[x]<dfn[y];}),stk[tp=1]=p[1];
		for(int i=2,lc;i<=pc;++i){
			lc=LCA(p[i],stk[tp]);
			while(dep[lc]<dep[stk[tp-1]])
				add(stk[tp-1],stk[tp]),add(stk[tp],stk[tp-1]),--tp;
			if(lc!=stk[tp]){
				add(lc,stk[tp]),add(stk[tp],lc);
				if(lc!=stk[tp-1])
					stk[tp]=lc;
				else
					--tp;
			}
			stk[++tp]=p[i];
		}
		for(int i=1;i<tp;++i)
			add(stk[i],stk[i+1]),add(stk[i+1],stk[i]);
		gvt=0,sizgvt=1e9,blk=pc,DFS_gvt(stk[1],0),DFS(gvt);
		for(int i=1;i<=tu;++i)
			isq[usd[i]]=false,vis[usd[i]]=false,h[usd[i]]=0;
		idx=1,tu=0;
		printf("%lld %d %d\n",ans,ansl,ansr);
	}
	return 0;
}

这是一份 AC 100pts\text{AC } 100\text{pts} 的代码:

#include<cstdio>
#include<algorithm>
inline void read(int &x){
	x=0;
	char c=getchar();
	while(c<'0'||c>'9')
		c=getchar();
	while(c>='0'&&c<='9')
		x=(x<<3)+(x<<1)+(c&15),c=getchar();
}
const int MAXN=1e6+7;
struct Edge{int v,nxt;}e[MAXN<<1];
int n,Q,dfn[MAXN],tim,dep[MAXN],fa[MAXN][21],pc,p[MAXN],h[MAXN],idx=1,stk[MAXN],tp,gvt,sizgvt=1e9,blk,siz[MAXN],tsiz,tdl,tdr,usd[MAXN],tu,ansl,ansr;
long long tdis,ans;
bool isq[MAXN],vis[MAXN];
inline void add(int u,int v){
	e[++idx]={v,h[u]},h[u]=idx;
}
void DFS_lca(int u){
	dep[u]=dep[fa[u][0]]+1;
	for(int i=1;i<21;++i)
		fa[u][i]=fa[fa[u][i-1]][i-1];
	dfn[u]=++tim;
	for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
		if(!dfn[v])
			fa[v][0]=u,DFS_lca(v);
}
inline int LCA(int u,int v){
	if(dep[u]<dep[v])
		std::swap(u,v);
	for(int i=20;~i;--i)
		if(dep[fa[u][i]]>=dep[v])
			u=fa[u][i];
	if(u==v)
		return v;
	for(int i=20;~i;--i)
		if(fa[u][i]!=fa[v][i])
			u=fa[u][i],v=fa[v][i];
	return fa[u][0];
}
void DFS_gvt(int u,int ff){
	siz[u]=1;
	int mx=0;
	for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
		if(v!=ff&&!vis[v])
			DFS_gvt(v,u),siz[u]+=siz[v],mx=std::max(mx,siz[v]);
	mx=std::max(mx,blk-siz[u]);
	if(mx<sizgvt)
		gvt=u,sizgvt=mx;
}
void DFS_dis(int u,int ff,int ds){
	if(isq[u])
		++tsiz,tdis+=ds,tdl=std::min(tdl,ds),tdr=std::max(tdr,ds);
	for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
		if(v!=ff&&!vis[v])
			DFS_dis(v,u,ds+abs(dep[u]-dep[v]));
}
void DFS(int u){
	vis[u]=true,usd[++tu]=u;
	int gsiz=isq[u],tblk=blk,gdl=isq[u]?0:1e9,gdr=0;
	long long gdis=0;
	for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
		if(!vis[v])
			tsiz=0,tdl=1e9,tdr=0,tdis=0,
			DFS_dis(v,u,abs(dep[u]-dep[v])),
			ans+=tsiz*gdis+gsiz*tdis,ansl=std::min(ansl,gdl+tdl),ansr=std::max(ansr,gdr+tdr),
			gsiz+=tsiz,gdis+=tdis,gdl=std::min(gdl,tdl),gdr=std::max(gdr,tdr);
	for(int i=h[u],v=e[i].v;i;v=e[i=e[i].nxt].v)
		if(!vis[v])
			gvt=0,sizgvt=1e9,blk=siz[v],DFS_gvt(v,u),DFS(gvt);
}
int main(){
	read(n);
	for(int i=1,u,v;i<n;++i)
		read(u),read(v),add(u,v),add(v,u);
	DFS_lca(1),idx=1;
	for(int i=1;i<=n;++i)
		h[i]=0;
	read(Q);
	while(Q--){
		read(pc),ans=0,ansl=1e9,ansr=-1;
		for(int i=1;i<=pc;++i)
			read(p[i]),isq[p[i]]=true;
		std::sort(p+1,p+pc+1,[&](const int x,const int y){return dfn[x]<dfn[y];}),stk[tp=1]=p[1];
		for(int i=2,lc;i<=pc;++i){
			lc=LCA(p[i],stk[tp]);
			while(dep[lc]<dep[stk[tp-1]])
				add(stk[tp-1],stk[tp]),add(stk[tp],stk[tp-1]),--tp;
			if(lc!=stk[tp]){
				add(lc,stk[tp]),add(stk[tp],lc);
				if(lc!=stk[tp-1])
					stk[tp]=lc;
				else
					--tp;
			}
			stk[++tp]=p[i];
		}
		for(int i=1;i<tp;++i)
			add(stk[i],stk[i+1]),add(stk[i+1],stk[i]);
		gvt=0,sizgvt=1e9,blk=pc,DFS_gvt(stk[1],0),DFS(gvt);
		for(int i=1;i<=tu;++i)
			isq[usd[i]]=false,vis[usd[i]]=false,h[usd[i]]=0;
		idx=1,tu=0;
		printf("%lld %d %d\n",ans,ansl,ansr);
	}
	return 0;
}

容易发现二者的唯一区别在于第 7070 行,DFS 函数末尾,对 blk 大小的处理。

我认为第一份代码的 blk 大小处理是正确的,原因在于:DFS() 中传入的是连通块的重心,但是 DFS_gvt() 中传入的是连通块中的任意点,所以 siz[v]>siz[u] 这种情况有可能发生,这个时候就要特判处理一下。

所以是为什么呢,有没有大佬能够解释一下?

2025/6/24 13:31
加载中...