WA最后一个点求助
查看原帖
WA最后一个点求助
506081
BlueSky_ 楼主 2021/8/11 22:34
#include <bits/stdc++.h>
#define N 100010
#define M 300030
#define ll long long
#define INF 99999999999999999
using namespace std;

// One input edge: endpoints u and v, weight d.
struct Node
{
	ll u;	// one endpoint
	ll v;	// the other endpoint
	ll d;	// edge weight
};
Node a[M];	// input edges, stored 1-indexed by main()

bool cmp (Node x, Node y)
{
	return x.d < y.d;
}

// Number of binary-lifting levels. With n < N < 2^17, dfs() stores ancestors
// only for exponents <= 16, and LCA() probes exponents up to 18, so 20 levels
// suffice. The former 50 columns made f/t use ~160 MB and made the memset in
// main() clear ~120 MB.
const int MAXLOG = 20;

// n, m: vertex / edge counts; val: total weight of the MST.
// f[x][i]: the 2^i-th ancestor of x in the MST (0 = none).
// fa: union-find parent array used by fin().
// to / nxt / dis / h / cnt: forward-star adjacency list of the MST.
// t[x][i][1], t[x][i][2]: largest and strictly second-largest edge weight on
//   the 2^i edges above x (index 0 is unused).
ll n, m, f[N][MAXLOG], fa[N], val, to[M], nxt[M], dis[M], h[N], cnt, t[N][MAXLOG][3];
// t1, t2: scratch max / strict second max written by dfs() and LCA().
// dep[x]: depth of x (root has depth 1); ans: best strict second MST weight.
ll t1, t2, dep[N], ans = INF;
// p[i]: whether sorted edge i was taken into the MST.
bool p[M];

// Union-find lookup: returns the representative of x's set and re-points
// every node on the walk directly at that representative (path compression).
int fin (int x)
{
	int root = x;
	while (fa[root] != root)
		root = fa[root];
	// Second pass: compress the path x -> root.
	while (x != root)
	{
		int next = fa[x];
		fa[x] = root;
		x = next;
	}
	return root;
}

void add (int u, int v, ll d)
{
	to[++cnt] = v;
	nxt[cnt] = h[u];
	dis[cnt] = d;
	h[u] = cnt;
}

void dfs (int x)
{
	int i = 1;
	dep[x] = dep[f[x][0]] + 1;
	while (f[f[x][i-1]][i-1])
	{
		f[x][i] = f[f[x][i-1]][i-1];
		ll l[10] = {t[f[x][i-1]][i-1][1], t[x][i-1][1], t[f[x][i-1]][i-1][2], t[x][i-1][2]};
		t1 = t2 = -INF;
		for (int j = 0; j < 4; j++)
			if (l[j] > t1) t2 = t1, t1 = l[j];
			else if (l[j] > t2 && l[j] != t1) t2 = l[j];
		t[x][i][1] = t1, t[x][i][2] = t2;
		i++;
	}
	for (int j = h[x]; j; j = nxt[j])
	{
		if (to[j] == f[x][0]) continue;
		f[to[j]][0] = x;
		t[to[j]][0][1] = dis[j];
		dfs (to[j]);
	}
}

// Finds the largest (t1) and strictly second-largest (t2) edge weight on the
// MST path between u and v, using the binary-lifting tables f / t built by
// dfs(). Results are returned through the globals t1 and t2. A value of -INF
// (or the even smaller memset sentinel) means "no such edge exists".
void LCA (int u, int v)
{
	t1 = t2 = -INF;
	if (dep[v] > dep[u]) swap (u, v);
	// Lift the deeper endpoint u to v's depth, folding in each jumped segment.
	for (int i = 18; i >= 0; i--)
	{
		if (dep[f[u][i]] >= dep[v])
		{
			ll l[10] = {t[u][i][1], t[u][i][2]};
			for (int j = 0; j < 2; j++)
				if (l[j] > t1) t2 = t1, t1 = l[j];
				else if (l[j] > t2 && l[j] != t1) t2 = l[j];
			u = f[u][i];
		}
	}
	// If v was an ancestor of u, the whole path has already been processed.
	// This also covers u == v from the start, i.e. a self-loop. Falling through
	// would wrongly fold in t[u][0][1], the edge from the common ancestor to
	// ITS parent, which is not on the u-v path.
	if (u == v) return;
	// Lift both endpoints to just below their lowest common ancestor.
	for (int i = 18; i >= 0; i--)
	{
		if (f[u][i] != f[v][i])
		{
			ll l[10] = {t[u][i][1], t[v][i][1], t[u][i][2], t[v][i][2]};
			for (int j = 0; j < 4; j++)
				if (l[j] > t1) t2 = t1, t1 = l[j];
				else if (l[j] > t2 && l[j] != t1) t2 = l[j];
			u = f[u][i], v = f[v][i];
		}
	}
	// Finally include the two edges from u and v up to the common ancestor.
	ll l[10] = {t[u][0][1], t[v][0][1]};
	for (int j = 0; j < 2; j++)
		if (l[j] > t1) t2 = t1, t1 = l[j];
		else if (l[j] > t2 && l[j] != t1) t2 = l[j];
}

int main ()
{
	memset (t, 245, sizeof (t));
	cin >> n >> m;
	for (int i = 1; i <= m; i++)
	{
		cin >> a[i].u >> a[i].v >> a[i].d;
	}
	sort (a+1, a+m+1, cmp);
	for (int i = 1; i <= n; i++) fa[i] = i;
	for (int i = 1; i <= m; i++)
	{
		int u = fin(a[i].u), v = fin(a[i].v);
		if (u != v)
		{
			val += a[i].d;
			add (a[i].u, a[i].v, a[i].d);
			add (a[i].v, a[i].u, a[i].d);
			fa[u] = v;	//合并 
			p[i] = 1;
		}
	}
	dfs (1);
	for (int i = 1; i <= m; i++)
	{
		if (p[i] == 1) continue;
		LCA (a[i].u, a[i].v);
		if (t1 != a[i].d)
		{
			ans = min (ans, val + a[i].d - t1);
		} else if (t2 != t1)
		{
			ans = min (ans, val + a[i].d - t2);
		}
	}
	cout << ans << endl;
	return 0;
}

2021/8/11 22:34
加载中...