求助卡常
查看原帖
求助卡常
55201
clamee楼主2020/9/11 20:10

复杂度 $O(n\sqrt{n})$,最后三个点卡不过去。

#include<bits/stdc++.h>
using namespace std;

#define il inline
#define rg register

/*
 * Reads the next decimal integer from stdin.
 *
 * Skips every character that is not a digit; if a '-' is seen while skipping,
 * the result is negated. Returns 0 when end of input is reached before any
 * digit is found.
 */
inline int read()
{
	int re=0,k=1;
	// getchar() returns int: keeping it as int lets EOF be told apart from
	// every real character, so the skip loop below cannot spin forever.
	int ch=getchar();
	while(ch!=EOF&&(ch>'9'||ch<'0'))
	{
		if(ch=='-')k=-1;
		ch=getchar();
	}
	// EOF is negative, so it also terminates the digit loop.
	while(ch<='9'&&ch>='0')
	{
		re=re*10+ch-48;
		ch=getchar();
	}
	return re*k;
}

/*
 * Prints x to stdout in decimal, with a leading '-' for negative values and
 * no trailing separator.
 */
inline void write(int x)
{
	// Work on the unsigned magnitude: negating INT_MIN as an int overflows,
	// whereas unsigned arithmetic is modular and yields the right magnitude.
	unsigned int mag=static_cast<unsigned int>(x);
	if(x<0)
	{
		putchar('-');
		mag=0u-mag;
	}
	char digits[12];	// enough for the 10 digits of a 32-bit magnitude
	int len=0;
	do
	{
		digits[len++]=static_cast<char>('0'+mag%10);
		mag/=10;
	}while(mag!=0);
	// Digits were produced least-significant first; emit them in reverse.
	while(len>0)putchar(digits[--len]);
}

// Number of "key" vertices: main marks the first BB entries of the shuffled
// vertex order and runs dp once from each of them.
const int BB=351;

// h is indexed by a vertex value a[x]; ddfs sets h[a[x]]=x while it is inside
// x's subtree and restores the previous entry afterwards.
int h[100005];
// l[x] is the content of h[a[x]] at the moment ddfs enters x.
int l[100005];
// Adjacency list: head[u] is the index of u's first edge in e[] (0 = none),
// tot is the number of edges stored so far.
int head[100005];
int tot;
// Per-vertex data filled by dfs: depth (dfs is started with depth 1),
// parent, subtree size and the child with the largest subtree.
int d[100005];
int fa[100005];
// top[x] is the chain head assigned by ddfs; lca() jumps along it.
int top[100005];
// Number of vertices, read in main.
int n;
// a[x] is the value of vertex x; main replaces it by its rank among the
// distinct input values.
int a[100005];
int sz[100005];
int son[100005];
// f[i] is the result row written by the dp run that starts at key vertex p[i].
int f[405][100005];
// p[1..BB] lists the key vertices.
int p[405];
// Not referenced by the code in this file: dp's parameter of the same name
// shadows it.
int ls;
// t[v] counts occurrences of value v among the entries dp has committed.
int t[100005];
// Vertex order that main shuffles to pick the key vertices.
int c[100005];
// g[x] is set by dfs: x itself when x is a key vertex, otherwise the g of its
// parent.
int g[100005];

// Value stack handed to dp.
int fs[2000005];
// Number of values v with t[v]>0, maintained by dp.
int tt;

// vis[x] is true for key vertices.
bool vis[100005];
// Scratch arrays indexed by value; dp uses vis2, the query code in main uses
// both.
int vis2[100005];
int vis1[100005];
struct ss{
    int node,nxt;
}e[100005];
// Appends the directed edge u -> v to the front of u's adjacency list.
void add(int u,int v)
{
    const int id=++tot;	// edge indices start at 1; 0 is the list terminator
    e[id].node=v;
    e[id].nxt=head[u];
    head[u]=id;
}
void dfs(int x,int ffa,int dp)
{
	fa[x]=ffa;
	d[x]=dp;
	sz[x]=1;
	
	int maxsz=-1;
	for(int i=head[x];i;i=e[i].nxt)
	{
		int k=e[i].node;
		if(k==ffa)continue;
		if(!vis[k])
			g[k]=g[x];
		else g[k]=k;
		dfs(k,x,dp+1);
		sz[x]+=sz[k];
		if(maxsz<sz[k])maxsz=sz[k],son[x]=k;
	}
}
void ddfs(int x,int tt)
{
	//id[x]=++cnt;
	//ans[cnt]=w[x];
	top[x]=tt;
//	ddfs(son[x],tt);
	l[x]=h[a[x]];
	//int p=h[a[x]];
	if(!son[x])return;
	h[a[x]]=x;
	ddfs(son[x],tt);
	for(int i=head[x];i;i=e[i].nxt)
	{
		int k=e[i].node;
		if(fa[x]==k||son[x]==k)continue;
		ddfs(k,k);
	}
	h[a[x]]=l[x];
}
// Returns the lowest common ancestor of x and y, using the chain heads in
// top[], the parents in fa[] and the depths in d[].
int lca(int x,int y)
{
	// Lift whichever vertex has the deeper chain head past that head until
	// both vertices sit on the same chain.
	for(;top[x]!=top[y];x=fa[top[x]])
		if(d[top[x]]<d[top[y]])std::swap(x,y);
	// On a common chain the shallower vertex is the ancestor.
	return d[x]>d[y]?y:x;
}

void dp(int x,int fa,int *ff,int *s,int ls,int ans)
{
	if(!x)return;
	s[++ls]=a[x];
	if(t[a[x]]&&!vis2[a[x]])ans++;
	vis2[s[ls]]++;
	if(vis[x])
	{
		for(int i=1;i<=ls;i++)
		{
			if(!t[s[i]])tt++;
			t[s[i]]++,vis2[s[i]]--;
		}
		ff[x]=tt;
	}
	else ff[x]=ans;
	for(rg int i=head[x];i;i=e[i].nxt)
	{
		if(e[i].node==fa)continue;
		if(vis[x])
		{	
			dp(e[i].node,x,ff,s+ls,0,0);
		}
		else
		{
			dp(e[i].node,x,ff,s,ls,ans);
		}
	}
	if(vis[x])
	{
		for(int i=1;i<=ls;i++)
		{
			t[s[i]]--;vis2[s[i]]++;
			if(!t[s[i]])tt--;
		}
	}
	vis2[s[ls]]--;
}
// Counts the distinct values a[] on the tree path between u and v by walking
// both endpoints up to their lowest common ancestor.
//   typ - when non-zero, vertex v itself is left out of the count
// vis1[] is used as scratch space and the entries set here are cleared again.
int sol(int u,int v,int typ)
{
	const int LCA=lca(u,v);
	int distinct=0;
	for(int x=u;x!=LCA;x=fa[x])
	{
		if(!vis1[a[x]])distinct++;
		vis1[a[x]]=1;
	}
	for(int y=v;y!=LCA;y=fa[y])
	{
		if(typ&&y==v)continue;	// excluded endpoint
		if(!vis1[a[y]])distinct++;
		vis1[a[y]]=1;
	}
	if(!vis1[a[LCA]]&&(!typ||LCA!=v))distinct++;
	for(;u!=LCA;u=fa[u])vis1[a[u]]=0;
	for(;v!=LCA;v=fa[v])vis1[a[v]]=0;
	return distinct;
}

int b[100005],cnt,q;

int main()
{
	freopen("2.in","r",stdin);
	freopen("2.out","w",stdout);
    n=read();q=read();
   	for(rg int i=1;i<=n;i++)
   	{
   		b[++cnt]=a[i]=read();
   		c[i]=i;
   	}
   	for(rg int i=1,u,v;i<n;i++)
   	{
   		u=read();v=read();
   		add(u,v);add(v,u);
   	}
   	srand(time(0));
   	sort(b+1,b+cnt+1);
   	cnt=unique(b+1,b+cnt+1)-b-1;
   	random_shuffle(c+2,c+n+1);
   	for(rg int i=1;i<=n;i++)
   	{
   		if(i<=BB)p[i]=c[i],vis[c[i]]=1;
   		a[i]=lower_bound(b+1,b+cnt+1,a[i])-b;
   	}
	g[1]=1;
   	dfs(1,0,1);
   	ddfs(1,1);
   	for(rg int i=1;i<=BB;i++)
   	{
   		dp(p[i],0,f[i],fs,0,0);
   	}
   	int lans=0;
	memset(vis1,0,sizeof(vis1));
	memset(vis2,0,sizeof(vis2));
   	for(rg int i=1,u,v;i<=q;i++)
   	{
   		u=read()^lans;v=read();
   		if(g[u]==g[v])
   		{
   			lans=0;
   			int x=u,y=v,LCA=lca(u,v);
   			while(x!=LCA){if(!vis1[a[x]])lans++;vis1[a[x]]=1;x=fa[x];}
   			while(y!=LCA){if(!vis1[a[y]])lans++;vis1[a[y]]=1;y=fa[y];}
			if(!vis1[a[LCA]])lans++;
   			while(u!=LCA){vis1[a[u]]=0;u=fa[u];}
   			while(v!=LCA){vis1[a[v]]=0;v=fa[v];}
   			write(lans);puts("");continue;
   		}
		int LCA=lca(u,v);
		memset(vis1,0,sizeof(vis1));
   		if(d[g[u]]<d[LCA]||d[g[v]]<d[LCA])
   		{
   			if(d[g[u]]<d[LCA])swap(u,v);
   			int x=u,y=v,z=u,A=0,B=0,C=0,AB=0,BC=0,AC=0,ABC=0,xx=u,yy=v,ty;
   			while(!vis[x]&&x!=LCA){
				if(!vis1[a[x]])
					A++;
				vis1[a[x]]=x;
				x=fa[x];
			}
   			while(!vis[y]&&y!=LCA)
   			{
   				if(!vis2[a[y]])
   				{
   					B++;
   					if(vis1[a[y]])
   					{
   						AB++;
   					}
   				}
   				vis2[a[y]]=y;y=fa[y];
   			}
   			while(d[g[fa[g[z]]]]>d[LCA])z=fa[g[z]];
   			ty=z=g[z];
   			z=fa[z];
   			int zz=z;
   			while(!vis[z]&&z!=LCA)
   			{
   				if(!vis2[a[z]])
   				{
   					B++;
   					if(vis1[a[z]])
   					{
   						AB++;
   					}
   				}
   				vis2[a[z]]=z;z=fa[z];
   			}
   			if(!vis2[a[LCA]])
   			{
   				B++;
   				if(vis1[a[LCA]])AB++;
   			}
			vis2[a[LCA]]=LCA;
 	  		while(!vis[yy]&&yy!=LCA)
 	  		{
 	  			if(vis2[a[yy]]&&vis1[a[yy]])
 	  			{
 	  				if((d[l[vis1[a[yy]]]]>=d[ty]&&d[l[vis1[a[yy]]]]<=d[x]))
 	  					ABC++;
 	  			}
 	  			vis2[a[yy]]=0;yy=fa[yy];
 	  		}
 	  		while(!vis[zz]&&zz!=LCA)
 	  		{
 	  			if(vis2[a[zz]]&&vis1[a[zz]])
 	  			{
 	  				if((d[l[vis1[a[zz]]]]>=d[ty]&&d[l[vis1[a[zz]]]]<=d[x]))
 	  					ABC++;
 	  			}
 	  			vis2[a[zz]]=0;zz=fa[zz];
 	  		}
			if(vis2[a[LCA]])
   			{
   				if(vis1[a[LCA]])
					if((d[l[vis1[a[LCA]]]]>=d[ty]&&d[l[vis1[a[LCA]]]]<=d[x]))
 	  					ABC++;
   			}
			vis2[a[LCA]]=0;
 	  		while(!vis[xx]&&xx!=LCA){vis1[a[xx]]=0;xx=fa[xx];}
 	  		for(rg int j=1;j<=BB;j++)
 	  			if(p[j]==x){x=j;break;}
 	  		for(rg int j=1;j<=BB;j++)
 	  			if(p[j]==ty){ty=j;break;}
   			C=f[x][p[ty]];
   			if(!vis[v])BC=f[x][v];
   			if(!vis[u])AC=f[ty][u];
   			lans=A+B+C-AB-BC-AC+ABC;
   			write(lans);puts("");continue;
   		}
   		int x=u,y=v,A=0,B=0,C=0,AB=0,BC=0,AC=0,ABC=0,xx=u,yy=v;
   		while(!vis[x]&&x!=LCA)
   		{
   			if(!vis1[a[x]])
   				A++;
   			vis1[a[x]]=x;
   			x=fa[x];
   		}
   		while(!vis[y]&&y!=LCA)
   		{
   			if(!vis2[a[y]])
   			{
   				B++;
   				if(vis1[a[y]])
   				{
   					AB++;
   				}
   			}
   			vis2[a[y]]=y;y=fa[y];
   		}
   		while(!vis[yy]&&yy!=LCA)
   		{
   			if(vis2[a[yy]]&&vis1[a[yy]])
   			{
   				if(((d[l[vis1[a[yy]]]]>=d[LCA])&&(d[l[vis1[a[yy]]]]<=d[x]))||((d[l[vis2[a[yy]]]]>=d[LCA])&&(d[l[vis2[a[yy]]]]<=d[y])))
   					ABC++;
   			}
   			vis2[a[yy]]=0;yy=fa[yy];
   		}
   		while(!vis[xx]&&xx!=LCA){vis1[a[xx]]=0;xx=fa[xx];}
   		for(rg int j=1;j<=BB;j++)
   			if(p[j]==x){x=j;break;}
   		for(rg int j=1;j<=BB;j++)
   			if(p[j]==y){y=j;break;}
   		C=f[x][p[y]];
   		if(!vis[v])BC=f[x][v];
   		if(!vis[u])AC=f[y][u];
   		lans=A+B+C-AB-BC-AC+ABC;
   		write(lans);puts("");
   	}
}
2020/9/11 20:10
加载中...