萌新初学Tarjan LCA,这个代码90分,WA了一个点
#include <bits/stdc++.h>
using namespace std;
vector<int> tree[500001],qus[500001];
int f[500001],diff[500001],vis[500001],fa[500001];
inline int find(int x)
{
if(x==f[x]) return x;
f[x]=find(f[x]);
return f[x];
}
inline void merge(int x,int y)
{
int t1=find(x);
int t2=find(y);
if(t1!=t2) f[t2]=t1;
}
inline void LCA(int u)
{
vis[u]=1;
for(int i=0;i<tree[u].size();i++)
{
int v=tree[u][i];
if(!vis[v])
{
fa[v]=u;
LCA(v);
merge(u,v);
}
}
for(int i=0;i<qus[u].size();i++)
{
if(vis[qus[u][i]]==2)
{
diff[u]++;
diff[qus[u][i]]++;
int fx=find(qus[u][i]);
diff[fx]--;
diff[fa[fx]]--;
}
}
vis[u]=2;
}
inline void dfs(int u)
{
vis[u]=1;
for(int i=0;i<tree[u].size();i++)
{
if(!vis[tree[u][i]])
{
dfs(tree[u][i]);
diff[u]+=diff[tree[u][i]];
}
}
}
int main()
{
int n,tmp1,tmp2;
scanf("%d",&n);
scanf("%d",&tmp1);
int ans=tmp1;
for(int i=2;i<=n;i++)
{
scanf("%d",&tmp2);
qus[tmp2].push_back(tmp1);
qus[tmp1].push_back(tmp2);
tmp1=tmp2;
}
for(int i=1;i<n;i++)
{
scanf("%d %d",&tmp1,&tmp2);
tree[tmp1].push_back(tmp2);
tree[tmp2].push_back(tmp1);
f[i]=i;
}
LCA(1);
memset(vis,0,sizeof(vis));
dfs(1);
for(int i=1;i<=n;i++) if(i!=ans) diff[i]--;
for(int i=1;i<=n;i++) printf("%d\n",diff[i]);
return 0;
}
求dalao们帮帮忙QAQ