求助
查看原帖
求助
289056
北射天狼楼主2022/12/10 16:00

WA * 4 TLE * 2

#include <bits/stdc++.h>
using namespace std;
const int N = 100100;
int b[N];
int head[N],cnt;
struct node{
	int v,w,next;
}tree[N<<1];
struct zzp{
	int dep;
	int target;
	friend bool operator < (const zzp a,const zzp b){
		return a.dep < b.dep;
	}
}a[N];
bool ok[N],vis[N];
int f[N],size[N];
int value[N];
int n,m;
int tot;
int root;
void getroot(int u,int father,int total)
{
	size[u]=f[u]=1;
	for (int i=head[u];i;i=tree[i].next){
		int v = tree[i].v;
		if (v == father||vis[v])
		    continue;
		getroot(v,u,total);
		size[u]+=size[v];
		f[u]=max(f[u],size[v]);
	}
	f[u]=max(f[u],total - size[u]);
	if (f[u] < f[root])
	{
		root = u;
	}
}
void getdis(int u,int father,int val,int color)
{
	a[tot].target = u;
	b[u] = color;
	a[++tot].dep = val;
    for (int i=head[u];i;i=tree[i].next)
	{
		int v = tree[i].v;
		if (v == father || vis[v])
		    continue;
		getdis(v,u,val + tree[i].w,color);
	}	
}
void getsum(int u)
{
	tot = 1;
	a[1].dep = 0;
	a[1].target = u;
	b[u] = u;
	for (int i=head[u];i;i=tree[i].next){
		int v = tree[i].v;
		if (vis[v])
		    continue;
		getdis(v,u,tree[i].w,v);
	}
    sort(a+1,a+tot+1);
    for (int i=1;i<=m;i++)
    {
    	int L = 1,R = tot;
    	if (ok[i])    continue;
    	while (L < R)
    	{
    		if (a[L].dep + a[R].dep < value[i])
    		    L++;
    		else if (a[L].dep+a[R].dep > value[i]){
    			R--;
			} else if (b[a[L].target] == b[a[R].target])
			{
				if (a[R].dep == a[R-1].dep)
				    R--;
				else L++;
			}
			else {
				ok[i] = true;
				break;
			}
		}
	}
}
void solve(int u)
{
	vis[u] = 1;
	getsum(u);
	for (int i=head[u];i;i=tree[i].next){
		int v = tree[i].v;
		if (vis[v])
		    continue;
		root = 0;
		getroot(v,-1,size[u]);
		solve(v);
	}
}
void add(int u,int v,int w)
{
	tree[++cnt].next = head[u];
	tree[cnt].v = v;
	tree[cnt].w = w;
	head[u] = cnt;
}
int main()
{
	scanf("%d%d",&n,&m);
	for (int i=1;i<n;i++){
		int u,v,w;
		scanf("%d%d%d",&u,&v,&w);
		add(u,v,w);
		add(v,u,w);
	}
	//cout<<"*";
	for (int i=1;i<=m;i++){
	    scanf("%d",&value[i]);
	    if (value[i] == 0)
	        ok[i]=1;
	}
	root = 0;f[0]=0x3f3f3f3f;
	getroot(1,-1,n);
	//cout<<"*"; 
	solve(root);
	for (int i=1;i<=m;i++)
	{
		if (!ok[i])
		    puts("NAY");
		else 
	        puts("AYE");
	}
    return 0;	
}
2022/12/10 16:00
加载中...