我交上去TLE,本地测试极限数据很慢很慢很慢,感觉已经不仅仅是常数问题了,等很久很久才跑一组极限数据出来。然后我拿俩题解拍,跑的快的一批。
对了,应该只是跑得慢,跑出来都是对的。
是LCA的锅吗(不想写树剖LCA怎么办。。。)
#include<bits/stdc++.h>
using namespace std;
const int N=40010,M=100010;
int n,m,a[N],js[N];
int belong[N],cnt[N],vis[N<<1],now,ans[M];
int e,hd[N],to[N<<1],nxt[N<<1];
int tim,dfn[N<<1],dfnval[N<<1],fst[N],lst[N];
int f[N][20],dep[N];
struct pos{
int l,r,id,tmp;
}q[M];
bool cmp(pos x,pos y){
if(belong[x.l]!=belong[y.l]) return belong[x.l]<belong[y.l];
if(belong[x.l]&1) return x.r<y.r;
else x.r>y.r;
}
void add(int u,int v){
to[++e]=v;
nxt[e]=hd[u];
hd[u]=e;
}
void dfs(int u,int fa){
dfn[++tim]=u;fst[u]=tim;dfnval[tim]=a[u];
dep[u]=dep[fa]+1;f[u][0]=fa;
for(int i=1;i<=16;i++) f[u][i]=f[f[u][i-1]][i-1];
for(int i=hd[u];i;i=nxt[i]){
int v=to[i];
if(v==fa) continue;
dfs(v,u);
}
dfn[++tim]=u;lst[u]=tim;dfnval[tim]=a[u];
return;
}
int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
if(dep[x]!=dep[y]){
for(int i=16;i>=0;i--){
if(dep[f[x][i]]>=dep[y]) x=f[x][i];
if(x==y) return x;
}
}
for(int i=16;i>=0;i--)
if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
void init(){
int sz=n*2.0/sqrt(m*2/3),num=ceil((double)n/sz);
for(int i=1;i<=num;i++){
int tmp=min(n,i*sz);
for(int j=(i-1)*sz;i<=tmp;i++)
belong[j]=i;
}
return;
}
void del(int x){
if(vis[x]) now-=!--cnt[dfnval[x]];
else now+=!cnt[dfnval[x]]++;
vis[x]^=1;
}
void add(int x){
if(vis[x]) now-=!--cnt[dfnval[x]];
else now+=!cnt[dfnval[x]]++;
vis[x]^=1;
}
int main(){
// freopen("a.in","r",stdin);
// freopen("a.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]),js[i]=a[i];
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
init();
sort(js+1,js+n+1);
int len=unique(js+1,js+n+1)-js-1;
for(int i=1;i<=n;i++) a[i]=lower_bound(js+1,js+len+1,a[i])-js;
dfs(1,0);
// for(int i=1;i<=2*n;i++)
// cout<<dfn[i]<<" ";
// cout<<endl;
for(int i=1,x,y;i<=m;i++){
scanf("%d%d",&x,&y);
q[i].id=i;
if(fst[x]>fst[y]) swap(x,y);
int tmp=lca(x,y);//cout<<x<<" "<<y<<" "<<tmp<<endl;
if(tmp==x) q[i].l=fst[x],q[i].r=fst[y],q[i].tmp=tmp;
else q[i].l=lst[x],q[i].r=fst[y],q[i].tmp=tmp;
}
sort(q+1,q+m+1,cmp);
int l=1,r=0;
for(int i=1;i<=m;i++){
int ql=q[i].l,qr=q[i].r,qlca=q[i].tmp;
while(l<ql) del(l++);
while(l>ql) add(--l);
while(r<qr) add(++r);
while(r>qr) del(r--);
// cout<<l<<" "<<r<<endl;
if(ql!=qlca) add(qlca),ans[q[i].id]=now,del(qlca);
else ans[q[i].id]=now;
}
for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
return 0;
}