附上代码:
#include <bits/stdc++.h>
using namespace std;
const int N=400020,M=400040;
int n,E,a[N],pnt[M],nxt[M],head[N],x[N],y[N];
int par[N],s,mx=0;
void init(){
E=1;
memset(head,-1,sizeof(head));
}
void add(int u,int v){
pnt[E]=v;
nxt[E]=head[u];
head[u]=E;
E++;
}
int find(int x){
if(par[x]==x) return x;
par[x]=find(par[x]);
return par[x];
}
void dfs(int u,int f,int d){
if(mx<=d){
mx=d;
s=u;
}
for(int i=head[u];i!=-1;i=nxt[i]){
if(pnt[i]!=f){
dfs(pnt[i],u,d+1);
}
}
}
int main(){
init();
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<=n;i++) par[i]=i;
for(int i=1;i<n;i++){
scanf("%d%d",&x[i],&y[i]);
if(a[x[i]]==a[y[i]]) par[find(y[i])]=find(x[i]);
}
int fst;
for(int i=1;i<n;i++){
int fx=find(x[i]),fy=find(y[i]);
if(fx!=fy){
add(fx,fy);
add(fy,fx);
fst=fx;
}
}
dfs(fst,0,1);
dfs(s,0,1);
printf("%d",mx/2);
return 0;
}