WA on test 40……
查看原帖
WA on test 40……
341373
Autofreeze楼主2021/2/1 13:56

调自闭了……

屑代码

#include<bits/stdc++.h>
// Array capacity: room for 1-based indices up to 200000.
constexpr int N = 200001;
// Not referenced by the visible code; kept for compatibility.
constexpr int MAX = 301;
// `register` was deprecated in C++11 and removed in C++17 (ill-formed there,
// rejected outright by clang), so `re` now expands to nothing. Every existing
// `re` use site keeps compiling and behaves exactly as before.
#define re
// Not referenced by the visible code; kept for compatibility.
constexpr double inf = 1e18;
using namespace std;
using ll = long long;
using db = double;
// Prime modulus: results are printed modulo it, and inverses are taken with
// Fermat's little theorem (x^(mod-2)), which relies on its primality.
const ll mod = 1000000007;
inline void read(re ll &ret)
{
    ret=0;re char c=getchar();re bool pd=false;
    while(!isdigit(c)){pd|=c=='-';c=getchar();}
    while(isdigit(c)){ret=(ret<<1)+(ret<<3)+(c&15);c=getchar();}
    ret=pd?-ret:ret;
    return;
}
// ---- Input ----
ll n;                     // number of vertices
ll a[N];                  // a[i]: value written on vertex i
ll x, y;                  // scratch: endpoints of the edge being read
ll num[N];                // num[i] = phi[a[i]]
ll back[N];               // back[val]: the vertex i with a[i] == val
vector<long long> v[N];   // adjacency lists: the input tree, later each virtual tree

// ---- Number theory (filled by init) ----
bool prime[N];            // prime[i]: i is prime
ll p[N];                  // p[1..tot]: primes found by the sieve
ll tot;                   // prime count; main resets it and dfs reuses it as the pre-order counter
ll phi[N];                // Euler's totient
ll mu[N];                 // Mobius function

// ---- Euler tour and LCA sparse table ----
ll dfn[N];                // pre-order number of each vertex
ll dep[N];                // depth (root has depth 1)
ll dis[N];                // distance from the root, reduced mod `mod`
ll cnt;                   // length of the Euler tour
ll deep[N << 1];          // Euler-tour depths; reused per vertex as accumulators by df5
ll node[N << 1];          // Euler-tour vertices; reused per vertex as accumulators by df5
ll fir[N << 1];           // first Euler-tour position of each vertex
ll lg[N << 1];            // lg[k] = floor(log2(k))
ll st[N << 1][21];        // st[i][k]: shallowest vertex in tour positions [i, i + 2^k)
ll minn[N << 1][21];      // minn[i][k]: depth of st[i][k]

// ---- Per-divisor work ----
vector<long long> g;      // key vertices of the current virtual tree
ll s[N];                  // stack used while building the virtual tree
ll head;                  // index of the top of s
ll F[N];                  // F[d]: weighted distance sum over pairs whose values are multiples of d
ll f[N];                  // f[i]: Mobius inversion of F (pairs whose gcd is exactly i)
// Binary exponentiation: returns a^b modulo `mod` for b >= 0.
// Callers should pass 0 <= a < mod; a larger base could overflow a * a.
inline ll qpow(ll a, ll b)
{
	ll result = 1;
	for (; b != 0; b >>= 1)
	{
		if (b & 1)
			result = result * a % mod;
		a = a * a % mod;
	}
	return result;
}
// Modular inverse of x modulo the prime `mod`, by Fermat's little theorem:
// x^(mod-2) * x == 1 (mod mod) for x not divisible by mod.
// For x == 0 this yields 0, since zero has no inverse.
inline ll inv(ll x)
{
	const ll fermat_exponent = mod - 2;
	return qpow(x, fermat_exponent);
}
inline void init()
{
	for(re long long i=2;i<=n;i++)
		prime[i]=true;
	phi[1]=1;
	mu[1]=1;
	for(re long long i=2;i<=n;i++)
	{
		if(prime[i])
			p[++tot]=i,mu[i]=-1,phi[i]=i-1;
		for(re long long j=1;j<=tot&&p[j]*i<=n;j++)
		{
			prime[p[j]*i]=false;
			if(!(i%p[j]))
			{
				phi[p[j]*i]=phi[i]*p[j];
				mu[p[j]*i]=0;
				break;
			}
			mu[p[j]*i]=-mu[i];
			phi[p[j]*i]=phi[i]*phi[p[j]];
		}
	}
	return;
}
inline void dfs(re ll ver,re ll fa)
{
	dfn[ver]=++tot;
	if(fa)
		dis[ver]=dis[fa]+1;
	dis[ver]%=mod;
	dep[ver]=dep[fa]+1;
	deep[++cnt]=dep[ver];
	node[cnt]=ver;
	for(re long long i=0;i<v[ver].size();i++)
	{
		re ll to=v[ver][i];
		if(to==fa)
			continue;
		dfs(to,ver);
		deep[++cnt]=dep[ver];
		node[cnt]=ver;
	}
	return;
}
inline bool cmp(re ll x,re ll y)
{
	return dfn[x]<dfn[y];
}
// Lowest common ancestor of vertices p and q by RMQ over the Euler tour.
// The answer is the shallowest vertex between the first occurrences
// fir[p] and fir[q]. st[i][k] and minn[i][k] hold that vertex and its depth
// for the window of length 2^k starting at i; lg[] is floor(log2).
inline ll lca(ll p, ll q)
{
	// First occurrences in the tour are ordered like pre-order numbers, so
	// ordering by dfn gives fir[p] <= fir[q]. Swap the vertices themselves;
	// swapping dfn[p] and dfn[q] would corrupt the global pre-order numbers
	// and leave the query range reversed.
	if (dfn[p] > dfn[q])
		swap(p, q);
	const ll l = fir[p];
	const ll r = fir[q];
	const ll k = lg[r - l + 1];
	const ll right_start = r - (1LL << k) + 1;
	if (minn[l][k] <= minn[right_start][k])
		return st[l][k];
	return st[right_start][k];
}
bool vis[N];  // vis[u]: u is a key vertex of the current virtual tree
ll ret;       // running pair sum accumulated by df5 (mod `mod`)
inline void df5(re ll ver,re ll fa)
{
	node[ver]=deep[ver]=0;
	if(vis[ver])
		node[ver]=num[ver];
	for(re long long i=0;i<v[ver].size();i++)
	{
		re ll to=v[ver][i];
		if(to==fa)
			continue;
		df5(to,ver);
		ret+=deep[ver]*node[to]%mod+(deep[to]+(dis[to]-dis[ver])*node[to])*node[ver]%mod;
		ret%=mod;
		node[ver]+=node[to];
		deep[ver]+=deep[to]+(dis[to]-dis[ver])*node[to];
		deep[ver]%=mod;
		node[ver]%=mod;
	}
	v[ver].clear();
	vis[ver]=false;
	return;
}
// Reads the tree and vertex values, then for every divisor d builds the
// virtual tree of vertices whose value is a multiple of d. It sums
// phi-weighted pairwise distances there (F[d]), Mobius-inverts over d to
// get sums for pairs with gcd exactly i (f[i]), and combines them.
// Prints the average over the n(n-1) ordered pairs, mod `mod`.
signed main()
{
	read(n);
	init();
	// tot held the prime count; dfs reuses it as the pre-order counter.
	tot=0;
	// num[i] = phi(a[i]); back[] maps each value back to its vertex.
	for(re long long i=1;i<=n;i++)
		read(a[i]),num[i]=phi[a[i]],back[a[i]]=i;
	// n-1 undirected edges.
	for(re long long i=1;i<n;i++)
	{
		read(x);
		read(y);
		v[x].push_back(y);
		v[y].push_back(x);
	}
	dfs(1,0);
	// The real tree is no longer needed; v[] will hold each virtual tree.
	for(re long long i=1;i<=n;i++)
		v[i].clear();
	// lg[k] = floor(log2(k)) for the sparse-table queries.
	lg[1]=0;
	for(re long long i=1;(i<<1)<=cnt;i++)
		lg[i<<1]=lg[i<<1|1]=lg[i]+1;
	// Level 0 of the sparse table, plus the first tour position of each vertex.
	for(re long long i=1;i<=cnt;i++)
	{
		minn[i][0]=deep[i];
		st[i][0]=node[i];
		if(!fir[node[i]])
			fir[node[i]]=i;
	}
	// Levels 1..18 cover a tour shorter than 2^19 (enough for n up to ~2.6e5).
	// Note: 1<<j-1 parses as 1<<(j-1).
	for(re long long j=1;j<=18;j++)
		for(re long long i=1;i+(1<<j)-1<=cnt;i++)
		{
			if(minn[i][j-1]<=minn[i+(1<<j-1)][j-1])
				minn[i][j]=minn[i][j-1],st[i][j]=st[i][j-1];
			else
				minn[i][j]=minn[i+(1<<j-1)][j-1],st[i][j]=st[i+(1<<j-1)][j-1];
		}
	for(re ll d=1;d<=n;d++)
	{
		// Key vertices: those whose value is a multiple of d. The root 1 is
		// always on the stack, so it is only pushed into g once, via the stack.
		g.clear();
		for(re long long j=d;j<=n;j+=d)
		{
			if(back[j]!=1)
				g.push_back(back[j]);
			vis[back[j]]=true;
		}
		sort(g.begin(),g.end(),cmp);
		// Build the virtual tree with a stack holding the current root path,
		// visiting key vertices in pre-order.
		head=0;
		s[++head]=1;
		for(re long long j=0;j<g.size();j++)
		{
			re ll now=g[j];
			if(head==1)
				s[++head]=now;
			else
			{
				re ll tmp=lca(s[head],now);
		//		pop every stack vertex not above tmp, linking it to the one below
				while(head>1&&dfn[s[head-1]]>=dfn[tmp])
				{
					v[s[head-1]].push_back(s[head]);
					v[s[head]].push_back(s[head-1]);
					head--;
				}
				// tmp becomes a branching vertex of the virtual tree.
				if(s[head]!=tmp)
				{
					v[s[head]].push_back(tmp);
					v[tmp].push_back(s[head]);
					s[head]=tmp;
				}
				s[++head]=now;
			}
		}
		// Link whatever remains on the stack into a chain.
		while(head>1)
			v[s[head]].push_back(s[head-1]),v[s[head-1]].push_back(s[head]),head--;
		ret=0;
		df5(1,0);
		// df5 counts each unordered pair once; doubling counts ordered pairs.
		F[d]=ret;
		F[d]<<=1;
		// Appears to be a no-op: no virtual-tree edge above touches vertex 0.
		v[0].clear();
	}
	// Mobius inversion: f[i] = sum over x of mu(x) * F[i*x], i.e. the sum
	// restricted to pairs whose value gcd is exactly i. Kept in [0, mod).
	for(re long long i=1;i<=n;i++)
	{
		for(re long long j=i,x=1;j<=n;x++,j+=i)
		{
			f[i]+=mu[x]*F[j]%mod;
			f[i]=(f[i]%mod+mod)%mod;
		}
	}
	// Uses phi(a*b) = phi(a)*phi(b)*g/phi(g) with g = gcd(a, b), so the
	// pairs with gcd i are weighted by i / phi(i).
	re ll ans=0;
	for(re ll i=1;i<=n;i++)
		ans+=(long long)i*(long long)inv(phi[i])%mod*(long long)f[i]%mod,ans%=mod;
	// Average over the n(n-1) ordered pairs (assumes n >= 2).
	printf("%lld\n",ans*(long long)inv(n)%mod*(long long)inv(n-1)%mod);
    exit(0);
}
2021/2/1 13:56
加载中...