求助 TLE
查看原帖
求助 TLE
91736
RPChe_楼主2021/7/20 08:56

为什么这份代码在SPOJ狂 T 不止...

明明在洛谷和BZOJ都不 T 啊。

#include<bits/stdc++.h>
#define rep(i,a,b) for(R int i=a;i<=b;i++)
#define R register
#define endl putchar('\n')
const int N=40005;
using namespace std;

int n,m,c[N],head[N],dep[N],f[N][17],sq[N],nex[N];
int bk,key[N],mp[N];
bitset<N> bt[205][205],bc[205],res;
struct edge { int a,b,next; } e[N<<1];
void add(int a,int b) {
	static int cnt=0;
	e[++cnt]=(edge){a,b,head[a]};
	head[a]=cnt;
}

void dfs(int x) {
	dep[x]=dep[f[x][0]]+1;
	rep(i,1,16) f[x][i]=f[f[x][i-1]][i-1];
	for(R int i=head[x];i;i=e[i].next)
		if(e[i].b!=f[x][0]) f[e[i].b][0]=x,dfs(e[i].b);
}
int ask(int x,int y) {
	if(dep[x]<dep[y]) swap(x,y);
	for(R int i=16;i>=0;i--) if(dep[f[x][i]]>=dep[y]) x=f[x][i];
	if(x==y) return x;
	for(R int i=16;i>=0;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	return f[x][0];
}
bool cmp(int x,int y) { return dep[x]>dep[y]; }
void init() {
	sort(sq+1,sq+n+1),sq[0]=unique(sq+1,sq+n+1)-sq-1;
	rep(i,1,n) c[i]=lower_bound(sq+1,sq+sq[0]+1,c[i])-sq;
	dfs(1);
	static int t[N];
	rep(i,1,n) t[i]=i;
	sort(t+1,t+n+1,cmp);
	rep(i,1,n) {
		if(!key[t[i]]) {
			int x=f[t[i]][0];
			while(dep[x]>max(dep[t[i]]-bk,1)&&!key[x]) x=f[x][0];
			if(!key[x]&&x) key[x]=1,mp[x]=++mp[0];
			nex[t[i]]=x;
		}
	}
	rep(i,1,n) {
		if(key[i]) {
			bc[mp[i]].set(c[i]); int x=f[i][0];
			while(!key[x]&&dep[x]) bc[mp[i]].set(c[x]),x=f[x][0];
			bc[mp[i]].set(c[x]);
			nex[i]=nex[f[i][0]];
		}
	}
	rep(i,1,n) {
		if(key[i]) {
			int x=i; bt[mp[i]][mp[i]].set(c[i]);
			while(dep[x]!=1) bt[mp[i]][mp[nex[x]]]=bt[mp[i]][mp[x]]|bc[mp[x]],x=nex[x];
		}
	}
}

bitset<N> solve(int x,int y) {
	int x1=nex[x],y1=x1;
	res.reset();
	if(dep[x1]<=dep[y]) {
		while(x!=y) res.set(c[x]),x=f[x][0]; res.set(c[x]);
	} else {
		while(x!=x1) res.set(c[x]),x=f[x][0]; res.set(c[x]);
		while(dep[nex[y1]]>=dep[y]) y1=nex[y1];
		res|=bt[mp[x1]][mp[y1]];
		while(y1!=y) res.set(c[y1]),y1=f[y1][0]; res.set(c[y1]);
	}
	return res;
}
int query(int x,int y) {
	int lca=ask(x,y);
	return (solve(x,lca)|solve(y,lca)).count();
}

int main() {
	scanf("%d%d",&n,&m),bk=200;
	rep(i,1,n) scanf("%d",&c[i]),sq[++sq[0]]=c[i];
	rep(i,2,n) {
		int a,b;
		scanf("%d%d",&a,&b);
		add(a,b),add(b,a);
	}
	init();
	rep(i,1,m) {
		int a,b;
		scanf("%d%d",&a,&b);
		printf("%d\n",query(a,b));
	}
	return 0;
}
2021/7/20 08:56
加载中...