关于DP的问题
  • 板块P2656 采蘑菇
  • 楼主Akoasm_X
  • 当前回复3
  • 已保存回复3
  • 发布时间2021/4/18 21:40
  • 上次更新2023/11/5 00:21:49
查看原帖
关于DP的问题
250174
Akoasm_X楼主2021/4/18 21:40

拓扑代码中,这么写DP方程是AC的

	while(!q.empty())
	{
		int x = q.front();q.pop();
		for(int i=head[x];i;i=data[i].net)
		{
			int to = data[i].to ;
			int va = data[i].va ;
			du[to]--;
			dp[to] = max(dp[to],dp[x]+va+val[to]);
			if(du[to]==0) q.push(to);
		}
	}

仅修改方程把val [to] 提出来以后就不对了

while(!q.empty())
	{
		int x = q.front();q.pop();
		for(int i=head[x];i;i=data[i].net)
		{
			int to = data[i].to ;
			int va = data[i].va ;
			du[to]--;
			dp[to] = max(dp[to],dp[x]+va);
			if(du[to]==0)
			{
				du[to] += val[to];
				q.push(to);	
			}
		}
	}

下面是全部代码

#include<bits/stdc++.h>
using namespace std;
const int maxm = 500020;
const int maxn = 500020;
typedef long long LL;
int n,m,tim,top,cnt,str,tot;
int head[maxn],dfn[maxn],sta[maxn],low[maxn],id[maxn],du[maxn];
LL ans,val[maxn],dp[maxn];
struct node
{
	int to,net,va;
	double p;
}data[maxm<<2];
bool vis[maxn];
queue<int> q;

void add(int a,int b,int c,double d)
{
	data[++tot].net = head[a];data[tot].p = d;
	data[tot].to = b;data[tot].va = c;
	head[a] = tot;
}

void tarjan(int x)
{
	low[x] = dfn[x] = ++tim;
	sta[++top] = x;vis[x] = 1;
	for(int i=head[x];i;i=data[i].net)
	{
		int to = data[i].to ;
		if(!dfn[to]) tarjan(to),low[x] = min(low[x],low[to]);
		else if(vis[to])	low[x] = min(low[x],dfn[to]);
	}
	if(dfn[x]==low[x])
	{
		vis[x] = 0;cnt++;
		while(sta[top+1]!=x)
		{
			id[sta[top]] = cnt;
			vis[sta[top]] = 0;
			top--;
		}
	}
}

void tp()
{
	memset(dp,0xcf,sizeof(dp));
	for(int i=n+1;i<=cnt;i++)
		if(du[i]==0) q.push(i);
	dp[str] = val[str];
	while(!q.empty())
	{
		int x = q.front();q.pop();
		for(int i=head[x];i;i=data[i].net)
		{
			int to = data[i].to ;
			int va = data[i].va ;
			du[to]--;
			dp[to] = max(dp[to],dp[x]+va+val[to]); 
			if(du[to]==0) q.push(to);
		}
	}
}

int main()
{
	scanf("%d%d",&n,&m);cnt = n;
	for(int i=1;i<=m;i++)
	{
		int a,b,c;double d;
		scanf("%d%d%d%lf",&a,&b,&c,&d);
		add(a,b,c,d);
	}
	scanf("%d",&str);
	for(int i=1;i<=n;i++)	if(!dfn[i]) tarjan(i);
	for(int i=1;i<=n;i++)
	{
		for(int j=head[i];j;j=data[j].net)
		{
			int to = data[j].to ;
			int va = data[j].va ;
			if(id[to]==id[i])
			{
				while(va)
				{
					val[id[i]] += va;
					va = (va * data[j].p);
				}
			}
			else
			{
				add(id[i],id[to],va,0);
				du[id[to]]++;
			}
		}
	}
	str = id[str];
	tp();
	for(int i=n+1;i<=cnt;i++)
		ans = max(ans,dp[i]);
	printf("%lld\n",ans);
	return 0;
}
2021/4/18 21:40
加载中...