70pts的LCA求助!
查看原帖
70pts的LCA求助!
95072
wudiss8楼主2020/9/26 17:25

按照第一篇题解的思路写的代码，挂了 #1、#9~#13 这些测试点，求大佬帮忙看看错在哪里。

#include<bits/stdc++.h>
using namespace std;
const int SIZE=300000;  // offset added to every b2 index so that possibly-negative keys stay non-negative
// Tree adjacency list (linked "head/next" form): poi[x] = first edge of x, next[e] = following edge, to[e] = target.
int tot,next[610001],poi[610001],to[610001];
// Per-node list of path ids; main fills it with add1(t[i], i), i.e. paths that END at the node.
int tot1,next1[610001],poi1[610001],to1[610001];
// Per-node list of path ids; main fills it with add2(lca, i), i.e. paths whose LCA is the node.
int tot2,next2[610001],poi2[610001],to2[610001];
// fa[x][i] = 2^i-th ancestor of x (0 above the root); dep[x] = depth, root 1 has depth 1 since dep[0] stays 0.
int fa[610001][21],dep[610001];
// b1/b2: counting buckets used by dfs2 for the upward / downward halves of paths.
// w[x]: per-node value read from input (presumably the observation time — confirm against problem statement).
// js[x]: number of paths starting at x; dis[i]: edge length of path i.
int b1[610001],b2[610001],w[610001],js[610001],dis[610001];
// s[i], t[i]: start and end node of path i.
int s[610001],t[610001];
// ans[x]: count accumulated for node x and printed at the end.
int ans[610001];
// Appends the directed tree edge x -> y to x's adjacency list (inserted at the head).
inline void addt(int x,int y){
	int id=++tot;
	to[id]=y;
	next[id]=poi[x];
	poi[x]=id;
}
// Pushes path id y onto node x's first auxiliary list (main uses it for paths ending at x).
inline void add1(int x,int y){
	int id=++tot1;
	to1[id]=y;
	next1[id]=poi1[x];
	poi1[x]=id;
}
// Pushes path id y onto node x's second auxiliary list (main uses it for paths whose LCA is x).
inline void add2(int x,int y){
	int id=++tot2;
	to2[id]=y;
	next2[id]=poi2[x];
	poi2[x]=id;
}
// Hangs x under parent fat: records its depth, fills the binary-lifting table
// fa[x][0..20] (ancestors of 0 stay 0), then recurses into every non-parent neighbour.
inline void dfs(int x,int fat){
	dep[x]=dep[fat]+1;
	fa[x][0]=fat;
	for(int k=0;k<20;++k)
		fa[x][k+1]=fa[fa[x][k]][k];
	for(int e=poi[x];e!=0;e=next[e]){
		int v=to[e];
		if(v!=fat)dfs(v,x);
	}
}
inline int glca(int x,int y){
	if(dep[y]>dep[x])swap(x,y);
	for(register int i=20;i>=0;i--){
		if(dep[fa[x][i]]>=dep[y]){
			x=fa[x][i];
		}
	}
	if(x==y)return x;
	for(register int i=20;i>=0;i--){
		if(fa[x][i]!=fa[y][i]){
			x=fa[x][i];
			y=fa[y][i];
		}
	}
	return fa[x][0];
}
// Second DFS: for every node x, counts how many paths reach x exactly when the
// condition for x holds, using two global buckets and a "value after subtree minus
// value before subtree" difference:
//  - b1[d]: active upward halves whose start node has depth d; x is counted when
//    dep[s] == dep[x] + w[x].
//  - b2[k+SIZE]: active downward halves keyed by dis - dep[t]; x is counted when
//    dis - dep[t] == w[x] - dep[x].
// A path's halves are added at its endpoints (s for b1, t for b2) and removed at its
// LCA after the LCA's own answer is taken, so nodes above the LCA never see it.
inline void dfs2(int x,int fat){
	// Snapshot the two relevant buckets; only growth inside x's subtree counts for x.
	int t1=b1[dep[x]+w[x]],t2=b2[w[x]-dep[x]+SIZE];
	for(int e=poi[x];e;e=next[e]){
		if(to[e]==fat)continue;
		dfs2(to[e],x);
	}
	// Paths starting at x contribute an upward half keyed by dep[s] == dep[x].
	b1[dep[x]]=b1[dep[x]]+js[x];
	// Paths ending at x contribute a downward half keyed by dis - dep[t].
	for(int e=poi1[x];e;e=next1[e]){
		b2[dis[to1[e]]-dep[t[to1[e]]]+SIZE]++;
	}
	ans[x]=ans[x]+b1[dep[x]+w[x]]-t1+b2[w[x]-dep[x]+SIZE]-t2;
	// Paths whose LCA is x do not continue above x: withdraw both halves.
	// The upward half was inserted under the start node's DEPTH, so it must be removed
	// from b1[dep[s]] (previously b1[s] — indexed by node id — corrupting other buckets).
	for(int e=poi2[x];e;e=next2[e]){
		b1[dep[s[to2[e]]]]--;
		b2[dis[to2[e]]-dep[t[to2[e]]]+SIZE]--;
	}
}
// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes the
// result negative. If EOF is reached before any digit, returns 0 instead of looping
// forever. The character is held in an int so EOF is distinguishable from data bytes.
inline int read(){
	int c=getchar();
	int s=0,f=1;
	while(c!=EOF&&(c<'0'||c>'9')){
		if(c=='-')f=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9'){
		s=(s<<1)+(s<<3)+c-'0';
		c=getchar();
	}
	return s*f;
}
// Reads the tree (n nodes, n-1 edges), the per-node values w, and m paths (s, t).
// For each path it records length, start count, end/LCA lists, then runs dfs2
// and prints one count per node, space-separated.
int main(){
	int n=read();
	int m=read();
	// Undirected tree: store each edge in both directions.
	for(int k=1;k<n;++k){
		int a=read();
		int b=read();
		addt(a,b);
		addt(b,a);
	}
	for(int k=1;k<=n;++k)w[k]=read();
	dfs(1,0);
	for(int id=1;id<=m;++id){
		s[id]=read();
		t[id]=read();
		const int lca=glca(s[id],t[id]);
		dis[id]=dep[s[id]]+dep[t[id]]-2*dep[lca];
		js[s[id]]++;
		add1(t[id],id);
		add2(lca,id);
		// When the LCA meets the upward condition it meets the downward one too,
		// so dfs2 would count this path twice there; pre-subtract one.
		if(dep[s[id]]==dep[lca]+w[lca])ans[lca]--;
	}
	dfs2(1,0);
	for(int k=1;k<=n;++k)printf("%d ",ans[k]);
	printf("\n");
	return 0;
}
2020/9/26 17:25
加载中...