求助
查看原帖
求助
800499
suzhikz楼主2024/9/11 21:17
#include<bits/stdc++.h>
#define ll long long
#define reg register
#define db double
#define il inline
using namespace std;
// ---- Global state shared by check(), dfs(), solve() and main() ----
const int N=1e6+5;        // array capacity; node ids 1..n are used directly as indices
int fa[N];                // fa[i]: the single node that node i points to (read in main)
int n;                    // number of nodes
long long w[N];           // w[i]: weight of node i (read in main)
long long ans;            // running sum of solve() over all components
std::vector<int> g[N];    // g[u]: every node i with fa[i]==u (reverse edges)
bool vis[N];              // set once a node has been reached by check() or dfs()
bool vis2[N];             // secondary visited flag; only dfs() may use it, nothing ever resets it
long long dp[N][2];       // dp[x][1]: best subtree total with x taken; dp[x][0]: with x not taken
int mark;                 // node solve() is currently using as the DP root
// Tree DP for a maximum-weight independent set, rooted at x.
//
// In a functional graph every node i has exactly one out-edge i -> fa[i], and
// g[] stores the reverse edges. When x lies on its component's cycle (solve()
// only calls this with such nodes), ignoring the single cycle edge fa[x] -> x
// turns the reverse edges of the component into a tree rooted at x. That edge
// is cut here by never descending into x again; x appears as a child only in
// g[fa[x]]. Afterwards, for every node u reached:
//   dp[u][1] = w[u] + sum of dp[c][0] over children c         (u taken)
//   dp[u][0] = sum of max(dp[c][0], dp[c][1]) over children c (u not taken)
// The cut edge is incident to x, so dp[x][0] is the optimum of the whole
// component under the constraint "x not taken".
//
// The cut is made explicitly rather than with a "visited" flag. The flag in the
// previous version (vis2[]) was never reset, so the second call from solve()
// saw every node as visited and produced dp[root][0] == 0.
// The traversal is iterative because a path-shaped component may be ~1e6 nodes
// deep, which can overflow the call stack under recursion.
//
// Every reached node is marked in vis[] so main() skips the component later.
void dfs(int x){
	static std::vector<int> order; // BFS order: each parent precedes its children
	order.clear();
	order.push_back(x);
	vis[x]=1;
	for(std::size_t k=0;k<order.size();++k){
		const int u=order[k];
		for(int c:g[u]){
			if(c==x)continue; // the cut cycle edge fa[x] -> x
			vis[c]=1;
			order.push_back(c);
		}
	}
	// Sweep in reverse BFS order so every child is finalised before its parent reads it.
	for(std::size_t k=order.size();k-->0;){
		const int u=order[k];
		dp[u][1]=w[u];
		dp[u][0]=0;
		for(int c:g[u]){
			if(c==x)continue;
			dp[u][1]+=dp[c][0];
			dp[u][0]+=std::max(dp[c][0],dp[c][1]);
		}
	}
}
// Follows x -> fa[x] -> fa[fa[x]] -> ..., marking each node passed in vis[].
// Returns the first successor that is already marked.
// main() only calls solve(i) (and hence this) for an unmarked i, so within a
// fresh component that returned node is the first one reached twice. In a
// functional graph that node is on the component's cycle.
int check(int x){
	int cur=x;
	for(;;){
		vis[cur]=1;
		const int nxt=fa[cur];
		if(vis[nxt])return nxt;
		cur=nxt;
	}
}
// Best total weight of an independent set in the component containing x.
//
// check(x) yields a node cyc on the component's cycle. cyc and fa[cyc] are
// adjacent, so no valid selection takes both. The answer is therefore the
// larger of "cyc not taken" and "fa[cyc] not taken". Each is read from
// dp[root][0] after dfs(root). This relies on dfs(root) leaving the
// "root not taken" optimum there. `mark` is set to the root before each call.
ll solve(int x){
	const int cyc=check(x);
	mark=cyc;
	dfs(mark);
	const ll withoutCyc=dp[cyc][0];
	mark=fa[cyc];
	dfs(mark);
	const ll withoutNext=dp[mark][0];
	return withoutCyc<withoutNext?withoutNext:withoutCyc;
}
// Reads n, then for each node its weight and the single node it points to.
// Sums solve() over every component and prints the total.
int main(){
	scanf("%d",&n);
	// Record fa[] and the reverse edge, so each node can list the nodes pointing at it.
	for(int id=1;id<=n;++id){
		int to;
		scanf("%lld%d",&w[id],&to);
		fa[id]=to;
		g[to].push_back(id);
	}
	// solve() marks the nodes it reaches in vis[], so later nodes of an
	// already-processed component are skipped here.
	for(int id=1;id<=n;++id){
		if(vis[id])continue;
		ans+=solve(id);
	}
	cout<<ans;
	return 0;
}

2024/9/11 21:17
加载中...