本人只学了 OI-wiki 上很基础的强连通分量,请轻喷。
AC code:
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
int n,m,a[N];
vector<int> e[N];
int in_stk[N],pos[N],dfn[N],low[N],dfnn,cnt;
int stk[N],top;
vector<int> scc[N];
int sum[N];
set<int> e2[N];
int in[N],f[N];
void dfs(int u){
dfn[u]=low[u]=++dfnn;
stk[++top]=u;in_stk[u]=1;
for(auto x:e[u]){
if(!dfn[x]){
dfs(x);
low[u]=min(low[u],low[x]);
}else if(in_stk[x]) low[u]=min(low[u],dfn[x]);
}
if(dfn[u]==low[u]){
cnt++;
do{
in_stk[stk[top]]=0;
pos[stk[top]]=cnt;
scc[cnt].emplace_back(stk[top]);
}while(stk[top--]!=u);
}
}
void dfs2(int u,int fa){
if(f[fa]+sum[u]<=f[u]) return ;
f[u]=f[fa]+sum[u];
for(auto x:e2[u]) dfs2(x,u);
}
int main(){
// ios::sync_with_stdio(0);
// cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1,u,v;i<=m;i++){
cin>>u>>v;
e[u].push_back(v);
}
for(int i=1;i<=n;i++){
if(!dfn[i]) dfs(i);
}
for(int i=1;i<=cnt;i++){
for(auto x:scc[i]) sum[i]+=a[x];
}
// for(int i=1;i<=cnt;i++){
// cout<<i<<':';
// for(auto x:scc[i]) cout<<x<<' ';
// cout<<"|"<<sum[i]<<"\n";
// }
for(int i=1;i<=n;i++){
for(auto x:e[i]){
if(pos[i]!=pos[x]) e2[pos[i]].insert(pos[x]);
}
}
for(int i=1;i<=cnt;i++){
for(auto x:e2[i]) in[x]++;
}
for(int i=1;i<=cnt;i++){
if(!in[i]) dfs2(i,0);
}
int ans=0;
for(int i=1;i<=cnt;i++) ans=max(ans,f[i]);
cout<<ans<<"\n";
return 0;
}
TLE code:
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10;
int n,m,a[N];
vector<int> e[N];
int in_stk[N],pos[N],dfn[N],low[N],dfnn,cnt;
int stk[N],top;
vector<int> scc[N];
int sum[N];
set<int> e2[N];
int in[N],f[N];
void dfs(int u){
dfn[u]=low[u]=++dfnn;
stk[++top]=u;in_stk[u]=1;
for(auto x:e[u]){
if(!dfn[x]){
dfs(x);
low[u]=min(low[u],low[x]);
}else if(in_stk[x]) low[u]=min(low[u],dfn[x]);
}
if(dfn[u]==low[u]){
cnt++;
do{
in_stk[stk[top]]=0;
pos[stk[top]]=cnt;
scc[cnt].emplace_back(stk[top]);
}while(stk[top--]!=u);
}
}
void dfs2(int u,int fa){
f[u]=max(f[u],f[fa]+sum[u]);
for(auto x:e2[u]) dfs2(x,u);
}
int main(){
// ios::sync_with_stdio(0);
// cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1,u,v;i<=m;i++){
cin>>u>>v;
e[u].push_back(v);
}
for(int i=1;i<=n;i++){
if(!dfn[i]) dfs(i);
}
for(int i=1;i<=cnt;i++){
for(auto x:scc[i]) sum[i]+=a[x];
}
// for(int i=1;i<=cnt;i++){
// cout<<i<<':';
// for(auto x:scc[i]) cout<<x<<' ';
// cout<<"|"<<sum[i]<<"\n";
// }
for(int i=1;i<=n;i++){
for(auto x:e[i]){
if(pos[i]!=pos[x]) e2[pos[i]].insert(pos[x]);
}
}
for(int i=1;i<=cnt;i++){
for(auto x:e2[i]) in[x]++;
}
for(int i=1;i<=cnt;i++){
if(!in[i]) dfs2(i,0);
}
int ans=0;
for(int i=1;i<=cnt;i++) ans=max(ans,f[i]);
cout<<ans<<"\n";
return 0;
}
可以发现,我先把每个强连通分量求出来,然后再 DP 求最大和。没加剪枝 t 了,加了后就 a 了。
我想知道我这个算法的正确性。
如不正确,请大佬 hack,我将不胜感激。