30pts 求助
查看原帖
30pts 求助
536005
owls楼主2021/12/19 16:36
#include<iostream>
#include<cstdio>
#include<stack>
#include<queue>
#include<cstring>
#include<algorithm>

#define MAX 2000005
#define INF 0x3f3f3f3f
#define ll long long
#define N 3000	

using namespace std;

// Number of vertices (knights).
int n;
// arr[i]: weight (power) of vertex i.
int arr[MAX];
// dp[v][0/1]: best total in v's hanging tree with v not chosen / chosen.
// f[k][0/1]: dp values copied for the k-th vertex listed in h[] by main.
// These are sums of up to n int-sized weights (MAX leaves room for ~1e6
// vertices with 2 edges each), which overflows int, so they must be 64-bit.
ll dp[MAX][2],f[MAX][2];
// res[k][c][b]: circular DP along a cycle; c = vertex k chosen, b = first vertex
// chosen. ans: final answer.
ll res[MAX][2][2],ans;
// cnt/cnt2: number of edges stored in edge/edge2.
int cnt,cnt2,flag;
// Heads of the undirected (head) and directed (head2) adjacency lists; 0 = empty.
int head[MAX],head2[MAX];
// One cell of a linked adjacency list stored in an array.
struct node{
	int to;   // endpoint of this edge
	int next; // index of the next edge in the same vertex's list; 0 ends the list
};
node edge[MAX];   // undirected graph, filled by add()
node edge2[MAX];  // directed graph, filled by add2()
// Prepend the edge from -> to to from's list in edge[]/head[].
// main inserts every input pair in both directions, so this list is undirected.
void add(int from,int to){
	const int id=++cnt;
	edge[id].next=head[from];
	edge[id].to=to;
	head[from]=id;
}
// Prepend the directed edge from -> to to from's list in edge2[]/head2[].
// main adds exactly one such edge per vertex (i -> a); tarjan() walks this graph.
void add2(int from,int to){
	const int id=++cnt2;
	edge2[id].next=head2[from];
	edge2[id].to=to;
	head2[from]=id;
}

// Tarjan working stack of vertices whose SCC is not yet closed.
stack<int> s;
queue<int> q;

// Running DFS timestamp used by tarjan().
int num;
// dfn[v]: discovery time; low[v]: Tarjan low-link value.
int dfn[MAX],low[MAX];
// belong[v]: root vertex of v's SCC.
int belong[MAX];
// vis[v]: 0 = unvisited, 1 = on the Tarjan stack, 2 = SCC assigned
// (main later overwrites it to mark cycle vertices).
int vis[MAX];
// sum[r]: number of vertices whose belong[] equals r (filled in main).
int sum[MAX];
// Tarjan's strongly-connected-components search over the directed graph in
// edge2/head2, starting at x. When an SCC closes, each of its vertices gets
// vis=2 and belong=<the SCC's root vertex>.
// Recursive: the depth equals the length of the current DFS path.
void tarjan(int x){
	++num;
	dfn[x]=num;
	low[x]=num;
	vis[x]=1;
	s.push(x);
	for(int e=head2[x];e!=0;e=edge2[e].next){
		const int nxt=edge2[e].to;
		if(vis[nxt]==0){
			tarjan(nxt);
			if(low[nxt]<low[x]) low[x]=low[nxt];
		}else if(vis[nxt]==1 && dfn[nxt]<low[x]){
			// back edge to a vertex still on the stack
			low[x]=dfn[nxt];
		}
	}
	if(low[x]!=dfn[x]) return;
	// x is an SCC root: pop the stack down to and including x.
	int top;
	do{
		top=s.top();
		s.pop();
		vis[top]=2;
		belong[top]=x;
	}while(top!=x);
}

// Tree DP (maximum-weight independent set) over the undirected graph, rooted at x.
// Vertices with vis[]==1 (cycle vertices, marked by main) are never entered, so
// dfs(root,0) covers only the tree hanging off root.
// On return, for every visited vertex v:
//   dp[v][1] = best subtree total with v chosen (main pre-seeds it with arr[v]),
//   dp[v][0] = best subtree total with v not chosen.
// Returns nothing. The old `int` version fell off the end without a return,
// which is undefined behavior; every caller discards the result.
void dfs(int x,int fa){
	for(int i=head[x];i;i=edge[i].next){
		int y=edge[i].to;
		if(y==fa || vis[y]==1) continue;
		dfs(y,x);
		// y's subtree is final now; fold it into x.
		dp[x][0]+=max(dp[y][0],dp[y][1]);
		dp[x][1]+=dp[y][0];
	}
}

int cnt3;
int h[MAX];
int main(){
	scanf("%d",&n);
	for(int i=1,a;i<=n;i++){
		scanf("%d%d",&arr[i],&a);
		dp[i][1]=arr[i];
		add(i,a);
		add(a,i);
		add2(i,a);
	}
//	cout<<cnt2<<" ";
	for(int i=1;i<=n;i++){
		if(!vis[i]) tarjan(i);
	}
//	for(int i=1;i<=n;i++) cout<<belong[i]<<" ";
	for(int i=1;i<=n;i++){
		sum[belong[i]]++;
	}
//	cout<<flag<<endl;
	for(int i=1;i<=n;i++) {
		if(sum[belong[i]]>1) vis[i]=1,h[++cnt3]=i;
		else vis[i]=0;
	}
	for(int i=1;i<=cnt3;i++) {
		dfs(h[i],0);
		f[i+cnt3][0]=f[i][0]=dp[h[i]][0];
		f[i+cnt3][1]=f[i][1]=dp[h[i]][1];
	}
//	for(int i=1;i<=cnt3*2;i++) 
//		cout<<i<<" :"<<f[i][0]<<" "<<f[i][1]<<endl;
//	cout<<cnt3<<"++++++++"<<endl;
		res[1][0][0]=f[1][0];
		res[1][1][1]=f[1][1];
	for(int i=2;i<=cnt3;i++){
			res[i][0][0]=max(res[i-1][1][0],res[i-1][0][0])+f[i][0];
			res[i][1][0]=res[i-1][0][0]+f[i][1];
			res[i][0][1]=max(res[i-1][1][1],res[i-1][0][1])+f[i][0];
			res[i][1][1]=res[i-1][0][1]+f[i][1];
	}
//		for(int j=1;j<=cnt3;j++) cout<<res[j][0][0]<<" "<<res[j][1][0]<<" "<<res[j][0][1]<<" "<<res[j][1][1]<<endl;
	ans=max(max(ans,res[cnt3][1][0]),	max(res[cnt3][0][1],res[cnt3][0][0]));
	printf("%lld",ans);
} 
2021/12/19 16:36
加载中...