在拓扑排序的代码中，这样写 DP 转移方程可以 AC：
// Kahn-order relaxation (same loop as in tp()): every condensed edge cur -> nxt
// offers dp[cur] + edge weight + nxt's own component value; nxt is enqueued once
// its last incoming edge has been relaxed.
while(!q.empty())
{
    const int cur = q.front();
    q.pop();
    for(int e = head[cur]; e != 0; e = data[e].net)
    {
        const int nxt = data[e].to;
        const int w = data[e].va;
        --du[nxt];
        const long long cand = dp[cur] + w + val[nxt];
        if(cand > dp[nxt]) dp[nxt] = cand;
        if(du[nxt] == 0) q.push(nxt);
    }
}
仅修改方程，把 val[to] 从 max 中提出来（改为在 to 的入度减为 0 时再单独加上）以后就不对了：
// Variant of the topological DP that adds a component's own value val[to] once,
// after all of its incoming edges have been relaxed, instead of inside max().
while(!q.empty())
{
    int x = q.front();q.pop();
    for(int i=head[x];i;i=data[i].net)
    {
        int to = data[i].to ;
        int va = data[i].va ;
        du[to]--;
        dp[to] = max(dp[to],dp[x]+va);
        if(du[to]==0)
        {
            // The failing version wrote `du[to] += val[to];` here: it bumped the
            // in-degree counter instead of the DP value, so val[to] was never added.
            // Every predecessor of `to` has been relaxed now, so add its value once.
            // tp() already seeds dp[str] = val[str]; if str has incoming condensed
            // edges it also reaches this point, and adding again would count it twice.
            if(to != str) dp[to] += val[to];
            q.push(to);
        }
    }
}
下面是全部代码
#include<bits/stdc++.h>
using namespace std;
const int maxm = 500020;
const int maxn = 500020;
typedef long long LL;
int n,m,tim,top,cnt,str,tot;
// Indexed only by original vertex numbers 1..n.
int dfn[maxn],sta[maxn],low[maxn],id[maxn];
// Indexed by original vertices AND by condensed SCC ids n+1..cnt. cnt starts at n
// and grows by one per SCC, so it can reach 2n: these need twice the room.
int head[maxn<<1],du[maxn<<1];
LL ans,val[maxn<<1],dp[maxn<<1];
// One pool for both graphs: m input edges plus at most one condensed edge per
// input edge, i.e. at most 2m entries (index 0 is the list terminator).
struct node
{
    int to,net,va;  // target vertex, next edge in the list, edge weight
    double p;       // per-edge multiplier used when the edge is walked repeatedly
}data[maxm<<1];
bool vis[maxn];     // true while a vertex is on Tarjan's stack
std::queue<int> q;  // work queue for the topological DP
void add(int a,int b,int c,double d)
{
data[++tot].net = head[a];data[tot].p = d;
data[tot].to = b;data[tot].va = c;
head[a] = tot;
}
void tarjan(int x)
{
low[x] = dfn[x] = ++tim;
sta[++top] = x;vis[x] = 1;
for(int i=head[x];i;i=data[i].net)
{
int to = data[i].to ;
if(!dfn[to]) tarjan(to),low[x] = min(low[x],low[to]);
else if(vis[to]) low[x] = min(low[x],dfn[to]);
}
if(dfn[x]==low[x])
{
vis[x] = 0;cnt++;
while(sta[top+1]!=x)
{
id[sta[top]] = cnt;
vis[sta[top]] = 0;
top--;
}
}
}
void tp()
{
memset(dp,0xcf,sizeof(dp));
for(int i=n+1;i<=cnt;i++)
if(du[i]==0) q.push(i);
dp[str] = val[str];
while(!q.empty())
{
int x = q.front();q.pop();
for(int i=head[x];i;i=data[i].net)
{
int to = data[i].to ;
int va = data[i].va ;
du[to]--;
dp[to] = max(dp[to],dp[x]+va+val[to]);
if(du[to]==0) q.push(to);
}
}
}
int main()
{
scanf("%d%d",&n,&m);cnt = n;
for(int i=1;i<=m;i++)
{
int a,b,c;double d;
scanf("%d%d%d%lf",&a,&b,&c,&d);
add(a,b,c,d);
}
scanf("%d",&str);
for(int i=1;i<=n;i++) if(!dfn[i]) tarjan(i);
for(int i=1;i<=n;i++)
{
for(int j=head[i];j;j=data[j].net)
{
int to = data[j].to ;
int va = data[j].va ;
if(id[to]==id[i])
{
while(va)
{
val[id[i]] += va;
va = (va * data[j].p);
}
}
else
{
add(id[i],id[to],va,0);
du[id[to]]++;
}
}
}
str = id[str];
tp();
for(int i=n+1;i<=cnt;i++)
ans = max(ans,dp[i]);
printf("%lld\n",ans);
return 0;
}