蜜汁dfs
  • 板块学术版
  • 楼主天命之路
  • 当前回复0
  • 已保存回复0
  • 发布时间2021/3/25 23:16
  • 上次更新2023/11/5 01:36:55
查看原帖
蜜汁dfs
226435
天命之路楼主2021/3/25 23:16

扩完栈后,本地还是运行错误,求调。

数据:n=1e5 的链

#include<bits/stdc++.h>
#pragma comment(linker, "/STACK:1024000000,1024000000") 
using namespace std;
const int N=1e6+5,M=2e6+5;
int fir[N],nxt[M],to[M],w[M],ect=0;
inline void addedge(int u1,int v1,int w1)
{
	nxt[++ect]=fir[u1];fir[u1]=ect;to[ect]=v1;w[ect]=w1;
}
int blak[N],whit[N],n;
int sze[N],pos,maxx,sgn[N],bk[N],dfstime;
typedef long long ll;
ll ans=0;
#define Edge(x) for(int e=fir[x],y;y=to[e],e;e=nxt[e])
inline void getroot(int x,int S,int fa,int dep)
{
//	printf("get:%d %d\n",++dfstime,dep); 
//	assert(!bk[x]);bk[x]=1;
	sze[x]=1;
	int cnt=0;
	Edge(x)
	{
		if(y==fa||sgn[y]) continue;
		getroot(y,S,x,dep+1);
		sze[x]+=sze[y];
		cnt=max(cnt,sze[y]);
	}
	cnt=max(cnt,S-sze[x]);
	if(cnt<maxx) maxx=cnt,pos=x;
}
ll h[N];
int total=0;
inline void dfs(int x,int fa,int rt)
{
//	printf("dfs:%d %d %d\n",rt,x,fa);
	h[++total]=x;
	Edge(x)
	{
		if(y==fa||sgn[y]) continue;
		blak[y]=blak[x]+(w[e]==0);
		whit[y]=whit[x]+(w[e]==1);
		
		
		dfs(y,x,rt);
	}
}

inline ll calc(int x,int val)
{
	blak[x]=val==0;
	whit[x]=val==1;
	total=0;
	dfs(x,0,x);
//	sort(h+1,h+total+1);
//	if(val==0&&x==2)
//	{
//		printf("h:\n");
//		for(int i=1;i<=total;i++)
//		printf("%lld ",h[i]);
//		printf("\n");
//	}
//	int cnt=0;ll res=0;
//	for(int i=1;i<=n;i++)
//	{
//		if(h[i]!=h[i-1])
//		{
//			ans+=1ll*cnt*(cnt-1)/2;
//			cnt=0;
//		}
//		else cnt++;
//	}
//	if(val==0&&x==2) printf("%d %d %d %d\n",blak[3],whit[3],blak[7],whit[7]);
	int cnt=0;
	for(int i=1;i<=total;i++)
	if(blak[h[i]]==whit[h[i]]&&h[i]!=x) cnt++;
//	printf("=========================\n");
//	printf("cnt:%d\n",cnt);
	return 1ll*cnt*(cnt-1)/2;
}
inline void Divide(int root)
{
	printf("root:%d\n",root); 
	ans=ans+calc(root,-1);
	sgn[root]=1;
	Edge(root)
	{
		if(sgn[y]) continue;
		pos=0;maxx=INT_MAX;
		ans-=calc(y,w[e]);
		getroot(y,sze[y],0,0);
		Divide(pos);
	}
}
int main()
{
	freopen("P3085_2.in","r",stdin);
	scanf("%d",&n);
	for(int i=1,u,v,w;i<n;i++)
	{
		scanf("%d%d%d",&u,&v,&w);
		addedge(u,v,w);addedge(v,u,w);
	}
	
	pos=0;maxx=INT_MAX;
	getroot(1,n,0,0);//exit(0); 
	Divide(pos);
	printf("%lld\n",ans);
}
2021/3/25 23:16
加载中...