为什么全T啊!求助
查看原帖
为什么全T啊!求助
308464
奇米楼主2020/6/2 10:55
#pragma GCC optimize(3,"Ofast","inline")
#pragma GCC target("avx,avx2")
#include <bits/stdc++.h>
// Loop and utility shorthands used throughout the file.
// NOTE: the loop macros use the `register` storage class, which ISO C++17
// no longer permits; build with an older -std or a tolerant compiler.

// For(i,a,b): declares int i and iterates upwards over a..b inclusive.
#define For(i,a,b) for ( register int i=(a);i<=(b);++i )
// Dow(i,b,a): declares int i and iterates downwards over b..a inclusive.
#define Dow(i,b,a) for ( register int i=(b);i>=(a);--i )
// GO(i,x): walks the edge indices of vertex x through head[]/e[].nex; index 0 ends the list.
#define GO(i,x) for ( register int i=head[x];i;i=e[i].nex )
// mem / cpy: memset / memcpy over sizeof(x), so x must be a real array, not a pointer.
#define mem(x,s) memset(x,s,sizeof(x))
#define cpy(x,s) memcpy(x,s,sizeof(x))
// YES / NO / GG: print a fixed line and return 0 from the enclosing function.
#define YES return puts("YES"),0
#define NO return puts("NO"),0
#define GG return puts("-1"),0
// pb: shorthand for push_back.
#define pb push_back
using namespace std;

// Reads the next decimal integer from stdin using getchar().
//
// Skips every non-digit character; if a '-' is seen while skipping, the
// result is negated. Returns 0 when the input ends before any digit is found
// (the previous version looped forever in that case).
inline int read()
{
	int value=0,sign=1;
	// Keep the character as int: getchar() returns EOF out of band, and
	// isdigit() is only defined for unsigned-char values and EOF.
	int ch=getchar();
	while(ch!=EOF&&!isdigit(ch))
	{
		if(ch=='-') sign=-1;
		ch=getchar();
	}
	while(ch!=EOF&&isdigit(ch))
	{
		value=value*10+(ch-'0');
		ch=getchar();
	}
	return value*sign;
}

// mod and mo are not referenced by any code visible in this file.
const int mod=1e9+7;
const int mo=998244353;
// Capacity used for every per-vertex array below.
const int N=1e5+5;

// n, m: the two integers read first in main (vertex count, operation count).
// cnt: number of edges stored in e[]; Rt: current root; Siz: block size for bel[];
// tim: DFS timestamp counter; tot: number of distinct values; sxy: entries used in q[].
// lyx and ans are not referenced by any code visible in this file.
int n,m,cnt,Rt,lyx,ans,Siz,tim,tot,sxy;
// siz: subtree size; top: top vertex of the heavy chain; f[u][k]: 2^k-th ancestor (0 above the root);
// dep: depth; a: compressed vertex value; yyx: flags operations that produce output.
int siz[N],top[N],f[N][26],dep[N],a[N],yyx[N];
// head: first edge index per vertex; in/out: DFS entry time and last entry time inside the subtree;
// bel: block id per position; b: sorted distinct values; xm: accumulated answer per operation.
int head[N],in[N],out[N],bel[N],b[N],xm[N];
// lx/ly: endpoints of the intervals collected by get(); all: how many are stored;
// son: heavy child; cntl/cntr: value counts for the two moving pointers in main.
int lx[N],ly[N],all,son[N],cntl[N],cntr[N];

// One directed edge of the linked adjacency list.
struct Node
{
	int nex; // index of the next edge leaving the same vertex (0 terminates the list)
	int to;  // vertex this edge points to
};
// Edge pool; sized 2*N because main stores each undirected edge in both directions. Index 0 is unused.
Node e[N<<1];

struct Que
{
	int l,r,id,inv;
	inline bool friend operator < (const Que&a,const Que&b)
	{
		return (bel[a.l]^bel[b.l])?bel[a.l]<bel[a.l]:(!(bel[a.l]&1))?a.r<b.r:a.r>b.r;
	}
};
Que q[N<<5];

// Appends the directed edge u -> v to the front of u's adjacency list.
inline void jia(int u,int v)
{
	const int idx=++cnt;      // allocate the next slot in the edge pool
	e[idx].to=v;
	e[idx].nex=head[u];       // the old first edge becomes the successor
	head[u]=idx;
}

inline void dfs(int u,int fa)
{
	dep[u]=dep[fa]+1;
	siz[u]=1;
	in[u]=++tim;
	f[u][0]=fa;
	For(i,1,21) f[u][i]=f[f[u][i-1]][i-1];
	GO(i,u)
	{
		int v=e[i].to;
		if(v==fa) continue;
		dfs(v,u);
		siz[u]+=siz[v];
		if(siz[v]>siz[son[u]]) son[u]=v;
	}
	out[u]=tim;
}

inline void dfs2(int u,int Tp)
{
	top[u]=Tp;
	if(son[u]) dfs2(son[u],Tp);
	GO(i,u)
	{
		int v=e[i].to;
		if(v==son[u]||v==f[u][0]) continue;
		dfs2(v,v);
	}
}

// Returns the child of y that lies on the path from y down to x.
//
// x must be a strict descendant of y. The function climbs from x chain by
// chain (u is the top of the chain left last, x becomes its parent) until x
// is on y's chain. If the climb lands exactly on y, the wanted child is the
// chain top u; otherwise x is below y on the same chain, so it is son[y].
//
// u is initialised and the climb stops at vertex 0 (the parent recorded for
// the root), so a call that violates the precondition no longer reads an
// indeterminate value or loops forever; its result is still meaningless.
inline int find(int x,int y)
{
	int u=0;
	while(x!=0&&top[x]!=top[y]) u=top[x],x=f[u][0];
	return (x==y)?u:son[y];
}

inline int LCA(int x,int y)
{
	if(dep[x]>dep[y]) swap(x,y);
	Dow(i,20,0) if(dep[y]-(1<<i)>=dep[x]) y=f[y][i];
	if(x==y) return x;
	Dow(i,20,0) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	return f[x][0];
}

// Appends to lx/ly the DFS-position intervals that make up the subtree of x
// when the tree is rooted at Rt (positions are those of the traversal from
// vertex 1).
//   - x is the root itself: everything, [1, n].
//   - x is not an ancestor of Rt: the subtree is unchanged, [in[x], out[x]].
//   - x is a strict ancestor of Rt: everything except the subtree of the
//     child of x that leads towards Rt, i.e. up to two intervals.
inline void get(int x)
{
	if(x==Rt)
	{
		++all;
		lx[all]=1;
		ly[all]=n;
		return;
	}
	if(LCA(Rt,x)!=x)
	{
		++all;
		lx[all]=in[x];
		ly[all]=out[x];
		return;
	}
	// find() climbs from its first argument, so the descendant (Rt) must come
	// first. The previous call find(x,Rt) climbed from the ancestor, walked
	// past the root and never terminated.
	const int y=find(Rt,x);
	if(in[y]-1>=1)
	{
		++all;
		lx[all]=1;
		ly[all]=in[y]-1;
	}
	if(out[y]+1<=n)
	{
		++all;
		lx[all]=out[y]+1;
		ly[all]=n;
	}
}

inline void add(int l,int r,int ll,int rr,int id)
{
	q[++sxy]=(Que){r,rr,id,1};
	q[++sxy]=(Que){l-1,rr,id,-1};
	q[++sxy]=(Que){ll-1,r,id,-1};
	q[++sxy]=(Que){l-1,ll-1,id,1};
}

inline void build(int x,int y,int id)
{
	all=0;get(x);
	int tmp=all; get(y);
	For(i,1,tmp) 
		For(j,tmp+1,all) add(lx[i],ly[i],lx[j],ly[j],id);
}

int main()
{
	n=read(),m=read();
	Rt=1;Siz=sqrt(n);
	For(i,1,n) a[i]=b[i]=read(),bel[i]=(i-1)/Siz+1;
	sort(b+1,b+n+1);
	tot=unique(b+1,b+n+1)-b-1;
	For(i,1,n) a[i]=lower_bound(b+1,b+tot+1,a[i])-b;
	For(i,1,n-1) 
	{
		int x,y;
		x=read(),y=read();
		jia(x,y); jia(y,x);
	}
	dfs(1,0);
	dfs2(1,1);
	For(q,1,m)
	{
		int op,x,y;
		op=read();
		if(op==1) Rt=read();
		else 
		{
			x=read(),y=read();
			yyx[q]=1;
			build(x,y,q);
		}
	}
	For(i,1,sxy) if(q[i].l>q[i].r) swap(q[i].l,q[i].r);
	sort(q+1,q+sxy+1);
	int l=1,r=1,sum=0;
	For(i,1,sxy)
	{
		while(l<q[i].l) sum+=cntr[a[++l]],++cntl[a[l]];
		while(l>q[i].l) sum-=cntr[a[l]],--cntl[a[l--]];
		while(r<q[i].r) sum+=cntl[a[++r]],++cntr[a[r]];
		while(r>q[i].r) sum-=cntl[a[r]],--cntr[a[r--]];
		xm[q[i].id]+=sum*q[i].inv;
	}
	For(i,1,m) if(yyx[i]) printf("%d\n",xm[i]);
	return 0;
}


2020/6/2 10:55
加载中...