#include<bits/stdc++.h>
using namespace std;
const int maxn=200100;
int n,m;
long long col[maxn];
int u[maxn];
int v[maxn];
int fir[maxn];
int next[maxn];
int deep[maxn];
struct node
{
int qu,qv,w;
}q[maxn];
int ans[maxn];
int ol[2*maxn];
int cnt=0;
int in[maxn];
int out[maxn];
int len;
int inl,inr;
int book[maxn];
int book2[maxn];
int alcol=0;
int fa[maxn][20];
void read()
{
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>col[i];
for(int i=1;i<=n-1;i++)
{
cin>>u[i]>>v[i];
fa[v[i]][0]=u[i];
next[i]=fir[u[i]];
fir[u[i]]=i;
}
for(int i=1;i<=m;i++)
cin>>q[i].qu>>q[i].qv,q[i].w=i;
len=sqrt(2*n);
}
long long log2(int x)
{
long long sum=1;
long long hg=0;
while(sum<=x)
sum*=2,hg++;
return hg-1;
}
void dfs(int x)
{
for(int i=1;;i++)
if(fa[x][i-1])fa[x][i]=fa[fa[x][i-1]][i-1];
else break;
deep[x]=deep[fa[x][0]]+1;
ol[++cnt]=x;
in[x]=cnt;
int c=fir[x];
while(c!=0)
{
dfs(v[c]);
c=next[c];
}
ol[++cnt]=x;
out[x]=cnt;
}
bool cmp(node a,node b)
{
return a.qu/len==b.qu/len ? a.qv<b.qv : a.qu/len < b.qu/len ;
}
void px()
{
sort(q+1,q+1+m,cmp);
}
void add(int x)
{
if(++book2[ol[x]]!=2)book[col[ol[x]]]++;
else book[col[ol[x]]]--,book2[ol[x]]=0;
if(book[col[ol[x]]]==0)alcol--;
if(book[col[ol[x]]]==1)alcol++;
}
void del(int x)
{
if(--book2[ol[x]]!=-1)book[col[ol[x]]]--;
else book[col[ol[x]]]++,book2[ol[x]]=1;
if(book[col[ol[x]]]==0)alcol--;
if(book[col[ol[x]]]==1)alcol++;
}
void moder()
{
for(int i=1;i<=m;i++)
{
int graf;
int l,r;
int x1=q[i].qu,x2=q[i].qv;
if(deep[x1]>deep[x2])
{
int cha=deep[x1]-deep[x2];
long long ttttt=log2(cha);
for(int i=0;i<=ttttt;i++)
if((1<<i)&cha)x1=fa[x1][i];
}
if(deep[x1]<deep[x2])
{
int cha=deep[x2]-deep[x1];
long long ttt=log2(cha);
for(int i=0;i<=ttt;i++)
if((1<<i)&cha)x2=fa[x2][i];
}
if(x1!=x2)
{
for(int i=20;i>=0;i--)
if(fa[x1][i]!=fa[x2][i])
x1=fa[x1][i],x2=fa[x2][i];
graf=fa[x1][0];
l=out[q[i].qv];
r=in[q[i].qu];
}else
{
graf=x1;
l=in[q[i].qu];
r=in[q[i].qv];
}
if(i==1)inl=l,inr=l-1;
while(inl<l)del(inl++);
while(inl>l)add(--inl);
while(inr<r)add(++inr);
while(inr>r)del(inr--);
if(book[col[graf]]==0)
ans[q[i].w]=alcol+1;
else ans[q[i].w]=alcol;
}
}
void o()
{
for(int i=1;i<=m;i++)cout<<ans[i]<<'\n';
}
int main()
{
read();
dfs(1);
px();
moder();
o();
}