求助
查看原帖
求助
503792
Svemit楼主2022/12/9 13:53

T飞了30pts

#include<bits/stdc++.h>
using namespace std;
const int N=1e4+5;
int n,m,cnt,tot,s,root,k,ans;
int head[N],size[N],f[N],d[N],dep[N];
bool vis[N];
struct edge
{
	int nex,to,w;
}e[N<<1];

inline void add_edge(int u,int v,int w)
{
	e[++cnt].to=v;
	e[cnt].w=w;
	e[cnt].nex=head[u];
	head[u]=cnt;
}

inline void get_root(int u,int fa)
{
	size[u]=1;
	f[u]=0;
	for(int i=head[u];i;i=e[i].nex)
	{
		int v=e[i].to;
		if(v==fa||vis[v]) continue;
		get_root(v,u);
		size[u]+=size[v];
		f[u]=max(f[u],size[v]);
	}
	f[u]=max(f[u],s-f[u]);
	if(f[u]<f[root])
	  root=u;
}

inline void get_dep(int u,int fa)
{
	dep[++tot]=d[u];
	for(int i=head[u];i;i=e[i].nex)
	{
		int v=e[i].to,w=e[i].w;
		if(v==fa||vis[v]||d[u]+w>k) continue;
		d[v]=d[u]+w;
		get_dep(v,u);
	}
}

inline int get_sum(int u,int dis)
{
	d[u]=dis;
	tot=0;
	int sum=0;
	get_dep(u,0);
	sort(dep+1,dep+1+tot);
	int l=1,r=tot;
	while(l<r)
	{
		if(dep[l]+dep[r]<k) l++;
		else
		  if(dep[l]+dep[r]>k) r--;
		  else
		  {
		    if(dep[l]==dep[r])
		    {
			    sum+=(r-l+1)*(r-l)/2;
			    break;
		    }
		    int st=l,ed=r;
			while(dep[st]==dep[l])
			  st++;
			while(dep[ed]==dep[r])
			  ed--;
			 sum+=(st-l)*(r-ed);
			l=st;
			r=ed; 
		  }
	}
	return sum;
}

inline void solve(int u)
{
	vis[u]=true;
	ans+=get_sum(u,0);
	for(int i=head[u];i;i=e[i].nex)
	{
		int v=e[i].to,w=e[i].w;
		if(vis[v]) continue;
		ans-=get_sum(v,w);
		root=0;
		s=size[v];
		get_root(v,u);
		solve(v);
	}
}

int main()
{
	std::ios::sync_with_stdio(false);
	std::cin.tie(NULL);
	std::cout.tie(NULL);
	f[0]=0x3f3f3f3f;
	cin>>n>>m;
	for(int i=1;i<n;i++)
	{
		int u,v,w;
		cin>>u>>v>>w;
		add_edge(u,v,w);
		add_edge(v,u,w);
	}
	while(m--)
	{
		cin>>k;
		memset(vis,false,sizeof(vis));
		root=0;
		s=n;
		ans=0;
		get_root(1,0);
		solve(root);
		if(ans)
		  cout<<"AYE\n";
		else
		  cout<<"NAY\n";
	}
	return 0;
}

2022/12/9 13:53
加载中...