为什么这份代码在SPOJ狂 T 不止...
明明在洛谷和BZOJ都不 T 啊。
#include<bits/stdc++.h>
#define rep(i,a,b) for(R int i=a;i<=b;i++)
#define R register
#define endl putchar('\n')
const int N=40005;
using namespace std;
int n,m,c[N],head[N],dep[N],f[N][17],sq[N],nex[N];
int bk,key[N],mp[N];
bitset<N> bt[205][205],bc[205],res;
struct edge { int a,b,next; } e[N<<1];
void add(int a,int b) {
static int cnt=0;
e[++cnt]=(edge){a,b,head[a]};
head[a]=cnt;
}
void dfs(int x) {
dep[x]=dep[f[x][0]]+1;
rep(i,1,16) f[x][i]=f[f[x][i-1]][i-1];
for(R int i=head[x];i;i=e[i].next)
if(e[i].b!=f[x][0]) f[e[i].b][0]=x,dfs(e[i].b);
}
int ask(int x,int y) {
if(dep[x]<dep[y]) swap(x,y);
for(R int i=16;i>=0;i--) if(dep[f[x][i]]>=dep[y]) x=f[x][i];
if(x==y) return x;
for(R int i=16;i>=0;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
bool cmp(int x,int y) { return dep[x]>dep[y]; }
void init() {
sort(sq+1,sq+n+1),sq[0]=unique(sq+1,sq+n+1)-sq-1;
rep(i,1,n) c[i]=lower_bound(sq+1,sq+sq[0]+1,c[i])-sq;
dfs(1);
static int t[N];
rep(i,1,n) t[i]=i;
sort(t+1,t+n+1,cmp);
rep(i,1,n) {
if(!key[t[i]]) {
int x=f[t[i]][0];
while(dep[x]>max(dep[t[i]]-bk,1)&&!key[x]) x=f[x][0];
if(!key[x]&&x) key[x]=1,mp[x]=++mp[0];
nex[t[i]]=x;
}
}
rep(i,1,n) {
if(key[i]) {
bc[mp[i]].set(c[i]); int x=f[i][0];
while(!key[x]&&dep[x]) bc[mp[i]].set(c[x]),x=f[x][0];
bc[mp[i]].set(c[x]);
nex[i]=nex[f[i][0]];
}
}
rep(i,1,n) {
if(key[i]) {
int x=i; bt[mp[i]][mp[i]].set(c[i]);
while(dep[x]!=1) bt[mp[i]][mp[nex[x]]]=bt[mp[i]][mp[x]]|bc[mp[x]],x=nex[x];
}
}
}
bitset<N> solve(int x,int y) {
int x1=nex[x],y1=x1;
res.reset();
if(dep[x1]<=dep[y]) {
while(x!=y) res.set(c[x]),x=f[x][0]; res.set(c[x]);
} else {
while(x!=x1) res.set(c[x]),x=f[x][0]; res.set(c[x]);
while(dep[nex[y1]]>=dep[y]) y1=nex[y1];
res|=bt[mp[x1]][mp[y1]];
while(y1!=y) res.set(c[y1]),y1=f[y1][0]; res.set(c[y1]);
}
return res;
}
int query(int x,int y) {
int lca=ask(x,y);
return (solve(x,lca)|solve(y,lca)).count();
}
int main() {
scanf("%d%d",&n,&m),bk=200;
rep(i,1,n) scanf("%d",&c[i]),sq[++sq[0]]=c[i];
rep(i,2,n) {
int a,b;
scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
init();
rep(i,1,m) {
int a,b;
scanf("%d%d",&a,&b);
printf("%d\n",query(a,b));
}
return 0;
}