我和同学都A了这一题。。。
可是我和他测试的一组数据并不相同。。。
有大佬知道为什么吗
样例:
4 5
4 1 2
3 2 3
2 1 5
4 3 5
4 2 5
我的代码(码风很丑,因为是考试题,还请见谅)
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e6+1000,MAX_log=30;
struct node
{
int from,to,val;
}edge[N*2],ci_edge[N*2];
struct ss
{
int to,val;
};
int n,m,tot,min_ans,tot1,ans;
int fa[N],deep[N];
int f[35][N];
int b[35][N],sb[35][N];
vector<ss>g[N];
bool vis[N];
bool mycmp(node x,node y)
{
return x.val<y.val;
}
void fre()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
}
int find(int x)
{
if(x==fa[x])return x;
return fa[x]=find(fa[x]);
}
void kruskal()
{
for(int i=1;i<=n;i++)
fa[i]=i;
sort(edge+1,edge+tot+1,mycmp);
//for(int i=1;i<=m;i++)
//cout<<edge[i].from<<' '<<edge[i].to<<endl;
for(int i=1;i<=m;i++)
{
int x=find(edge[i].from);
int y=find(edge[i].to);
if(x==y)
{
ci_edge[++tot1]=edge[i];
continue;
}
else
{
// cout<<edge[i].from<<' '<<edge[i].to<<endl;
fa[x]=y;
min_ans+=edge[i].val;
//cout<<min_ans<<endl;
g[edge[i].from].push_back(ss{edge[i].to,edge[i].val});
g[edge[i].to].push_back(ss{edge[i].from,edge[i].val});
}
}
}
void bfs()
{
queue<int>q;
q.push(1);deep[1]=1;
vis[1]=1;
while(!q.empty())
{
int x=q.front();
q.pop();
for(int i=0;i<g[x].size();i++)
{
int y=g[x][i].to;
if(deep[y])continue;
if(vis[y])continue;
vis[y]=1;
deep[y]=deep[x]+1;
//cout<<deep[y]<<endl;
f[0][y]=x;
// cout<<f[0][y]<<endl;
b[0][y]=g[x][i].val;
// cout<<b[0][2]<<endl;
sb[0][y]=-(1e18);
for(int j=1;j<=MAX_log;j++)
{
f[j][y]=f[j-1][f[j-1][y]];
b[j][y]=max(b[j-1][y],b[j-1][f[j-1][y]]);
if(b[j-1][y]==b[j-1][f[j-1][y]])
sb[j][y]=max(sb[j-1][y],sb[j-1][f[j-1][y]]);
if(b[j-1][y]<b[j-1][f[j-1][y]])
sb[j][y]=max(b[j-1][y],sb[j-1][f[j-1][y]]);
if(b[j-1][y]>b[j-1][f[j-1][y]])
sb[j][y]=max(sb[j-1][y],b[j-1][f[j-1][y]]);
// cout<<y<<' '<<j<<' '<<f[j-1][y]<<' '<<f[j][y]<<' '<<b[j][y]<<' '<<sb[j][y]<<' '<<b[j-1][f[j-1][y]]<<endl;
}
q.push(y);
}
}
}
pair<int,int> lca(int x,int y,int z)
{
// cout<<x<<' '<<y<<endl;
int bigsum=-(1e18),smallsum=-(1e18);
if(deep[x]>deep[y])swap(x,y);
// cout<<deep[x]<<' '<<deep[y]<<endl;
for(int i=MAX_log;i>=0;i--)
if(deep[f[i][y]]>=deep[x])
{
if(bigsum<b[i][y])
bigsum=b[i][y];
else if(bigsum>b[i][y]&&b[i][y])smallsum=max(smallsum,b[i][y]);
if(sb[i][y])smallsum=max(smallsum,sb[i][y]);
y=f[i][y];
}
if(x==y)return make_pair(bigsum,smallsum);
for(int i=20;i>=0;i--)
if(f[i][x]!=f[i][y])
{
// cout<<i<<' '<<x<<' '<<i<<' '<<y<<endl;
int one=b[i][x],two=b[i][y];
if(one<two)swap(one,two);
if(bigsum<one)bigsum=max(bigsum,one);
else if(bigsum>one&&one)smallsum=max(smallsum,one);
if(two||sb[i][x]||sb[i][y])smallsum=max(smallsum,max(two,max(sb[i][x],sb[i][y])));
x=f[i][x],y=f[i][y];
// cout<<bigsum<<' '<<smallsum<<endl;
}
if(b[0][x]>bigsum)bigsum=b[0][x];
else if(bigsum>b[0][x])smallsum=max(b[0][x],smallsum);
return make_pair(bigsum,smallsum);
}
signed main()
{
//fre();
scanf("%lld%lld",&n,&m);
for(int i=1;i<=n;i++)
fa[i]=i;
for(int i=1;i<=m;i++)
{
int x,y,z;
scanf("%lld%lld%lld",&x,&y,&z);
if(x==y)continue;
edge[++tot]=node{x,y,z};
}
kruskal();
bfs();
// cout<<min_ans<<endl;
int MD=1e18;
// for(int i=1;i<=n;i++)
// cout<<deep[i]<<' ';
// cout<<endl;
for(int i=1;i<=tot1;i++)
{
// cout<<ci_edge[i].to<<endl;
pair<int,int> k=lca(ci_edge[i].from,ci_edge[i].to,ci_edge[i].val);
// cout<<ci_edge[i].from<<' '<<ci_edge[i].to<<endl;
// cout<<k.first<<' '<<k.second<<' '<<ci_edge[i].val<<endl;
if(ci_edge[i].val>k.first)MD=min(MD,min_ans-k.first+ci_edge[i].val);
else MD=min(MD,min_ans-k.second+ci_edge[i].val);
// cout<<MD<<endl;
}
printf("%lld\n",MD);
return 0;
}
同学的代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=3e5+10,T=25;
int n,m;
int mx1[N][T],mx2[N][T],f[N][T];
struct node{int x,y,v;}e[N];
int fa[N],d[N];
int sum=0,ans=2e18;
bool cmp(node x,node y){return x.v<y.v;}
struct ed{int to,val;};
vector<ed>edge[N];
int max1,max2;
bool vis[N];
int cnt=0;
int find(int x){return fa[x]=fa[x]==x?x:find(fa[x]);}
void dfs(int x,int fa)
{
for(int i=0;i<edge[x].size();i++)
{
int y=edge[x][i].to,v=edge[x][i].val;
if(y==fa)continue;
d[y]=d[x]+1;
f[y][0]=x;
mx1[y][0]=v;
for(int j=1;j<=23;j++)
{
f[y][j]=f[f[y][j-1]][j-1];
mx1[y][j]=max(mx1[y][j-1],mx1[f[y][j-1]][j-1]);
if(mx1[y][j-1]==mx1[f[y][j-1]][j-1])mx2[y][j]=max(mx2[y][j-1],mx2[f[y][j-1]][j-1]);
if(mx1[y][j-1]< mx1[f[y][j-1]][j-1])mx2[y][j]=max(mx1[y][j-1],mx2[f[y][j-1]][j-1]);
if(mx1[y][j-1]> mx1[f[y][j-1]][j-1])mx2[y][j]=max(mx2[y][j-1],mx1[f[y][j-1]][j-1]);
}
dfs(y,x);
}
return;
}
void lca(int x,int y)
{
if(d[x]>d[y])swap(x,y);
for(int i=23;i>=0;i--)
{
if(d[f[y][i]]>=d[x])
{
if(mx1[y][i]>max1)max1=mx1[y][i];
else if(mx1[y][i]<max1)max2=max(max2,mx1[y][i]);
max2=max(max2,mx2[y][i]);
y=f[y][i];
}
}
if(x==y)return;
for(int i=23;i>=0;i--)
{
if(f[x][i]!=f[y][i])
{
if(mx1[x][i]>max1)max1=mx1[x][i];
else if(mx1[x][i]<max1)max2=max(max2,mx1[x][i]);
max2=max(max2,mx2[x][i]);
if(mx1[y][i]>max1)max1=mx1[y][i];
else if(mx1[y][i]<max1)max2=max(max2,mx1[y][i]);
max2=max(max2,mx2[y][i]);
x=f[x][i],y=f[y][i];
}
}
if(mx1[x][0]>max1)max1=mx1[x][0];
else if(mx1[x][0]<max1)max2=max(max2,mx1[x][0]);
max2=max(max2,mx2[x][0]);
if(mx1[y][0]>max1)max1=mx1[y][0];
else if(mx1[y][0]<max1)max2=max(max2,mx1[y][0]);
max2=max(max2,mx2[y][0]);
return;
}
signed main()
{
// freopen("1.in","r",stdin);
// freopen("2.out","w",stdout);
scanf("%lld%lld",&n,&m);
memset(mx2,-0x7f,sizeof(mx2));
for(int i=1;i<=m;i++)
scanf("%lld%lld%lld",&e[i].x,&e[i].y,&e[i].v);
sort(e+1,e+m+1,cmp);
for(int i=1;i<=n;i++)fa[i]=i;
for(int i=1;i<=m;i++)
{
int tx=find(e[i].x),ty=find(e[i].y);
if(tx!=ty)
{
cout<<e[i].x<<' '<<e[i].y<<endl;
vis[i]=1;
fa[ty]=tx;
sum+=e[i].v;
edge[e[i].x].push_back({e[i].y,e[i].v});
edge[e[i].y].push_back({e[i].x,e[i].v});
if(++cnt==n-1)break;
}
}
d[1]=1;dfs(1,0);
cout<<sum<<endl;
for(int i=1;i<=m;i++)
{
int tx=e[i].x,ty=e[i].y,v=e[i].v;
if(vis[i]==0)
{
max1=-2e18;max2=-2e18;
lca(tx,ty);
// cout<<max1<<' '<<max2<<' '<<v<<endl;
if(v>max1)ans=min(ans,sum-max1+v);
else ans=min(ans,sum-max2+v);
}
}
printf("%lld\n",ans);
return 0;
}