一直 WA 两个测试点，结果已经和答案很接近了……但代码太长，很难调试……
#include<bits/stdc++.h>
#define ll long long
using namespace std;
// N = number of vertices, M = number of input edges.
int N, M;
// Union-find parent array used by find() / kruskal().
int a[100010];
// Adjacency list of the MST: head[u] is the index of u's first edge in e[],
// cnt is the number of directed edges stored so far.
int cnt;
int head[100010];
// Binary lifting: LN is the highest jump level, dep[u] the depth (root = 1),
// f[u][k] the 2^k-th ancestor of u (0 above the root).
int LN;
int dep[100010];
int f[100010][25];
// used[i] is set when (sorted) input edge i was taken into the MST.
bool used[300010];
// sum = total MST weight; ans = best candidate tree weight found so far.
ll sum;
ll ans = 2e15;
// One input edge as read from stdin: endpoints x and y, weight z.
struct pre_edge{
    int x;
    int y;
    ll z;
}pe[300010];
// Adjacency-list entry: target vertex, index of the next edge leaving the
// same source (0 terminates the list), and the edge weight.
struct edge{
    int to;
    int next;
    ll val;
}e[600010];
// Summary of a set of edge weights: m1 is the maximum and m2 the largest value
// strictly smaller than m1 (-1e9 when no such value exists).
// Callers write up to four candidate values into m[1..4] and call msort();
// m[0] permanently holds the -1e9 sentinel.
// Integer literals are used instead of -1e9 to avoid a double -> long long
// narrowing conversion inside brace initialisation.
struct M4{
    long long m[5] = {-1000000000LL, -1000000000LL, -1000000000LL,
                      -1000000000LL, -1000000000LL};
    long long m1 = -1000000000LL;
    long long m2 = -1000000000LL;
    // Recomputes m1 / m2 from m[1..4] (m[1..4] end up sorted ascending).
    void msort(){
        std::sort(m + 1, m + 5);
        m1 = m[4];
        // Reset so a previous result cannot survive when every candidate
        // equals m1.
        m2 = -1000000000LL;
        for(int i = 3; i >= 0; i--){
            if(m[i] < m1){
                m2 = m[i];
                return;
            }
        }
    }
}maxx[100010][25];
void add_edge(int u, int v, ll w){
e[++cnt].to = v;
e[cnt].val = w;
e[cnt].next = head[u];
head[u] = cnt;
}
// Ordering for sort(): lighter edges come first.
bool cmp(pre_edge lhs, pre_edge rhs){
    return rhs.z > lhs.z;
}
// Returns the union-find root of x, pointing each visited node at its
// grandparent along the way (path halving).
int find(int x){
    for(; a[x] != x; x = a[x]){
        a[x] = a[a[x]];
    }
    return x;
}
void kruskal(){
sort(pe + 1, pe + M + 1, cmp);
int i, E = 0;
for(i = 1; i <= M; i++){
int ax = find(pe[i].x), ay = find(pe[i].y);
if(ax == ay){
continue;
}
sum += pe[i].z;
used[i] = 1;
a[ax] = ay;
add_edge(pe[i].x, pe[i].y, pe[i].z);
add_edge(pe[i].y, pe[i].x, pe[i].z);
if(++E == N - 1){
return;
}
}
}
void init(int u, int fa, ll val){
dep[u] = dep[fa] + 1;
for(int i = 1; i <= 4; i++){
maxx[u][0].m[i] = val;
}
maxx[u][0].msort();
for(int i = 0; i < LN; i++){
f[u][i + 1] = f[f[u][i]][i];
maxx[u][i + 1].m[1] = maxx[u][i].m1;
maxx[u][i + 1].m[2] = maxx[u][i].m2;
maxx[u][i + 1].m[3] = maxx[f[u][i]][i].m1;
maxx[u][i + 1].m[4] = maxx[f[u][i]][i].m2;
maxx[u][i + 1].msort();
}
for(int i = head[u]; i; i = e[i].next){
int v = e[i].to;
if(v == fa){
continue;
}
f[v][0] = u;
init(v, u, e[i].val);
}
}
ll LCA(int x, int y, int dis){
ll xm1 = -1e9, ym1 = -1e9, xm2 = -1e9, ym2 = -1e9;
if(dep[x] < dep[y]){
swap(x, y);
}
for(int i = LN; i >= 0; i--){
if(dep[f[x][i]] > dep[y]){
xm1 = max(xm1, maxx[x][i].m1);
xm2 = max(xm2, maxx[x][i].m2);
x = f[x][i];
}
if(x == y){
return (dis > xm1 ? xm1 : xm2);
}
}
for(int i = LN; i >= 0; i--){
if(f[x][i] != f[y][i]){
xm1 = max(xm1, maxx[x][i].m1);
xm2 = max(xm2, maxx[x][i].m2);
ym1 = max(ym1, maxx[y][i].m1);
ym2 = max(ym2, maxx[y][i].m2);
x = f[x][i];
y = f[y][i];
}
}
xm1 = max(xm1, maxx[x][0].m1);
xm2 = max(xm2, maxx[x][0].m2);
ym1 = max(ym1, maxx[y][0].m1);
ym2 = max(ym2, maxx[y][0].m2);
ll t[5] = {-1e9, xm1, xm2, ym1, ym2};
sort(t + 1, t + 5);
for(int i = 4; i >= 0; i--){
if(t[i] < dis){
return t[i];
}
}
}
int main(){
freopen("tree4.in", "r", stdin);
scanf("%d %d", &N, &M);
LN = log2(N) + 1;
for(int i = 1; i <= N; a[i] = i, i++);
for(int i = 1; i <= M; i++){
scanf("%d %d %lld", &pe[i].x, &pe[i].y, &pe[i].z);
}
kruskal();
init(1, 0, 0);
for(int i = 1; i <= M; i++){
if(!used[i]){
int tx = pe[i].x, ty = pe[i].y, tz = pe[i].z;
int add = LCA(tx, ty, tz);
ans = min(ans, sum - add + tz);
}
}
cout << ans;
fclose(stdin);
return 0;
}
LOJ 上的评测结果：