萌新求助
查看原帖
萌新求助
40581
jeffqi楼主2020/7/15 12:58

一直TLE on 10,复杂度应该对的,有没有大佬可以看看有什么问题。

#include<bits/stdc++.h>
#define rep(i,a,b) for (int i = (a); i <= (b); ++i)
#define drep(i,a,b) for (int i = (a); i >= (b); --i)
#define grep(i,u) for (int i = head[u],v = e[i].v; i; v = e[i = e[i].nxt].v)
#define il inline
#define LL long long
#define ULL unsigned LL
#define pb push_back
#define mp make_pair
#define pii pair<int,int>
#define pll pair<LL,LL>
#define fi first
#define se second
using namespace std;
il int read() {
	int res = 0,f = 1; char ch = getchar(); while (!isdigit(ch)) {if (ch == '-') f = -f; ch = getchar();}
	while (isdigit(ch)) {res = res*10+ch-'0'; ch = getchar();} return res*f;
}
namespace qiqi {
	const int N = 2e5+5,P = 1e9+7; int n,a[N],p[N],v[N],ecnt,head[N],cnt,pri[N],phi[N],mu[N],f[N],dfn[N],dep[N],lg[N<<1],stk[N],top,fa[N],in[N],q[N],siz[N];
	struct ST {int p,d; il bool friend operator < (ST a,ST b) {return a.d < b.d;}} st[N<<1][19]; il bool cmp(int x,int y) {return dfn[x] < dfn[y];}
	struct Edge {int v,nxt;} e[N<<1]; il void add(int u,int v) {e[++ecnt] = (Edge){v,head[u]}; head[u] = ecnt;}
	il int pow(int b,int k) {int a = 1; while (k) {if (k&1) a = 1LL*a*b%P; b = 1LL*b*b%P; k >>= 1;} return a;}
	void dfs(int u,int fa,int d) {st[dfn[u] = ++cnt][0] = (ST){u,dep[u] = d}; grep(i,u) if (v != fa) {dfs(v,u,d+1); st[++cnt][0] = (ST){u,d};}}
	il void init(int n) {
		phi[1] = mu[1] = 1; rep(i,2,n) {
			if (!phi[i]) {pri[++cnt] = i; mu[i] = -1; phi[i] = i-1;} rep(j,1,cnt) {
				if (i*pri[j] > n) break; if (!(i%pri[j])) {phi[i*pri[j]] = phi[i]*pri[j]; break;} phi[i*pri[j]] = phi[i]*(pri[j]-1); mu[i*pri[j]] = -mu[i];
			}
		}
		rep(i,1,n) {int x = 1LL*i*pow(phi[i],P-2)%P; rep(j,1,n/i) f[i*j] = (f[i*j]+x*mu[j])%P;}
		cnt = 0; dfs(1,0,1); in[0] = lg[0] = -1; rep(i,1,cnt) lg[i] = lg[i>>1]+1; rep(j,1,18) rep(i,1,cnt-(1<<j)+1) st[i][j] = min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
	}
	il int LCA(int x,int y) {x = dfn[x]; y = dfn[y]; if (x > y) swap(x,y); int k = lg[y-x+1]; return min(st[x][k],st[y-(1<<k)+1][k]).p;}
	il void ins(int x) {
		if (top<2) {stk[++top] = x; return;} int lca = LCA(x,stk[top]);
		while (dfn[lca] < dfn[stk[top]]) {if (dfn[lca] > dfn[stk[top-1]]) {++in[fa[stk[top--]] = lca]; break;} ++in[fa[stk[top]] = stk[top-1]]; --top;}
		if (stk[top] != lca) stk[++top] = lca; if (x != lca) stk[++top] = x;
	}
	il int solve(int k) {
		int res = cnt = 0,tail = 0,s = 0; rep(i,1,n) if (!(v[p[i]]%k)) a[++cnt] = p[i]; if (a[1] != 1) ins(1);
		rep(i,1,cnt) ins(a[i]); while (top) {fa[stk[top]] = stk[top-1]; ++in[stk[--top]];} rep(i,1,cnt) {if (!in[a[i]]) q[++tail] = a[i]; siz[a[i]] = phi[v[a[i]]]; s = (s+siz[a[i]])%P;}
		rep(i,1,tail) {int v = q[i],u = fa[v]; if (!(--in[u])) q[++tail] = u; res = (res+1LL*siz[v]*(s-siz[v])*(dep[v]-dep[u])%P)%P; siz[u] = (siz[u]+siz[v])%P; siz[v] = fa[v] = 0;} return res;
	}
	void main() {
		n = read(); rep(i,1,n) v[i] = read(); rep(i,1,n-1) {int x = read(),y = read(); add(x,y); add(y,x);} init(n); rep(i,1,n) p[i] = i;
		sort(p+1,p+n+1,cmp); int ans = 0; rep(i,1,n) ans = (ans+1LL*f[i]*solve(i)%P)%P; printf("%d\n",(2LL*ans*pow(n,P-2)%P*pow(n-1,P-2)%P+P)%P);
	}
}
int main() {qiqi::main(); return 0;}
2020/7/15 12:58
加载中...