粘了一篇题解下来对拍,拍了十组 n=4e4 和 m=1e5 的都过了,不知道问题在哪儿
#include<stdio.h>
#include<algorithm>
#include<math.h>
using namespace std;
inline int read(){
int x=0,flag=1; char c=getchar();
while(c<'0'||c>'9'){if(c=='-')flag=0;c=getchar();}
while(c>='0'&&c<='9'){x=x*10+c-48;c=getchar();}
return flag? x:-x;
}
const int N=4e4+7;
const int M=1e5+7;
bool vis[N];
int a[N],b[N],pos[M],Ans[M],dep[N],c[N];
struct Que{
int l,r,p,tag;
bool operator <(const Que &X) const{
if(pos[r]!=pos[X.r]) return pos[r]<pos[X.r];
return pos[r]&1? l<X.l:l>X.l;
}
}q[M];
struct E{
int next,to;
}e[N<<1];
int head[N],cnt=0,in[N],out[N],fa[N][16],sta[M];
inline void add(int id,int to){
e[++cnt]=(E){head[id],to};
head[id]=cnt;
e[++cnt]=(E){head[to],id};
head[to]=cnt;
}
void dfs(int u){
static int tot=0;
sta[in[u]=++tot]=u;
for(int i=1;i<=15;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=head[u];i;i=e[i].next){
int v=e[i].to;
if(v==fa[u][0]) continue;
dep[v]=dep[u]+1,fa[v][0]=u,dfs(v);
}
sta[out[u]=++tot]=u;
}
inline void swap(int &x,int &y){x^=y,y^=x,x^=y;}
int Lca(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
for(int i=15;~i;i--)
if(dep[fa[u][i]]>=dep[v]) u=fa[u][i];
if(u==v) return u;
for(int i=15;~i;i--)
if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
int ans=0;
inline void Add(int id){
ans-=c[b[id=sta[id]]]>0;
if(vis[id]^=1) c[b[id]]++,ans+=c[b[id]]>0;
else c[b[id]]--,ans+=c[b[id]]>0;
}
int main(){
freopen("data.in","r",stdin);
freopen("mine.out","w",stdout);
int n=read(),m=read();
for(int i=1;i<=n;i++) b[i]=a[i]=read();
sort(a+1,a+1+n); int sz=unique(a+1,a+1+n)-(a+1);
for(int i=1;i<=n;i++)
b[i]=lower_bound(a+1,a+1+n,b[i])-a;
for(int i=1;i<n;i++) add(read(),read());
dep[1]=1,dfs(1);
for(int i=1;i<=m;i++){
int u=read(),v=read(),l,r,tag;
int lca=Lca(u,v);
if(lca==u) l=in[u],r=in[v];
else if(lca==v) l=in[v],r=in[u];
else if(out[u]<in[v]) l=out[u],r=in[v],tag=1;
else l=out[v],r=in[u],tag=1;
q[i]=(Que){l,r,i,tag? lca:0};
}
sz=(int)(sqrt(2*n)+0.5);
for(int i=1;i<=2*n;i++) pos[i]=(i-1)/sz+1;
sort(q+1,q+1+m); int l=1,r=0;
for(int i=1;i<=m;i++){
int L=q[i].l,R=q[i].r;
while(l<L) Add(l++);
while(l>L) Add(--l);
while(r>R) Add(r--);
while(r<R) Add(++r);
Ans[q[i].p]=ans;
if(q[i].tag&&!c[b[q[i].tag]]) Ans[q[i].p]++;
}
for(int i=1;i<=m;i++) printf("%d\n",Ans[i]);
}