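/*
 * RT ("as in the title": the task is the one stated by the problem title).
 * Idea: first set the weights of the k longest edges to 0 and run a shortest-path search;
 * then restore the zeroed edges that do not lie on that shortest path and spend those
 * free slots on the heaviest edges of the current shortest path instead.
 *
 * Original note: 先把k条最长的路径变成0,跑完最短路后,
 * 把不在最短路上的原先变为0的路径撤回,放在现有最短路上
 */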
#include<bits/stdc++.h>
using namespace std;
// From here on every `int` in this file is a 64-bit long long (hence `signed main`).
#define int long long
// Ad-hoc debug marker; note the macro body already ends with ';'.
#define debug cout<<"!error!";
// N: capacity of the vertex/edge arrays. M: not referenced in the visible code.
// inf: memset() only replicates its low byte 0x3f, so memset(arr, inf, sizeof arr) on a
// long long array fills every element with 0x3f3f3f3f3f3f3f3f, used as "unreachable".
const int N = 1e5 + 1000,M = 210,inf = 0x3f3f3f3f;
// Reads one integer from stdin into w.
// Non-digit characters before the number are skipped; a '-' among them makes the result
// negative. Stops cleanly at EOF (w is then 0 if no digit was found) instead of looping
// forever, and keeps the character as an int so isdigit() never sees a negative char.
template <class T> void read(T &w){
    w=0;
    bool neg=false;
    int ch=getchar();
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')neg=true;
        ch=getchar();
    }
    while(ch!=EOF&&isdigit(ch)){
        w=w*10+(ch-'0');
        ch=getchar();
    }
    // Plain signed negation instead of multiplying by a wrapped-around unsigned -1.
    if(neg)w=-w;
}
// Writes the decimal representation of w to stdout (no trailing separator).
// Negative values are printed without ever negating w itself, so the minimum value of T
// (e.g. LLONG_MIN) does not overflow.
template <class T> void write(T w){
    if(w<0){
        putchar('-');
        T q=w/10;   // quotient truncates toward zero: |q| <= max value of T
        T r=w%10;   // remainder lies in (-10, 0]
        if(q)write(-q);
        putchar('0'-r);
        return;
    }
    if(w/10)write(w/10);
    putchar(w%10+'0');
}
// Adjacency storage, one slot per directed half-edge (chain-forward-star layout).
int h[N];        // h[v]: index of the most recently added edge leaving v (-1 = none)
int ne[N];       // ne[i]: next edge leaving the same vertex as edge i
int e[N];        // e[i]: vertex that edge i points to
int idx;         // next free edge slot
int w[N];        // w[i]: current weight of edge i
// Shortest-path working state.
int d[N];        // d[v]: best known distance from the source to v
int in_queue[N]; // in_queue[v]: 1 while v is waiting in the SPFA queue
int last[N];     // weights of the edges on the recovered path (filled from index 1)
// Input parameters: n, m (edge count), k (number of edges whose weight is set to 0).
int n, m, k;
int start, endd; // source and destination vertices
// Sort record for one undirected input edge.
struct node
{
    int num; // index of the edge's first half in the edge arrays; its reverse is num ^ 1
    int val; // the edge's weight as read from the input
};
node id[N];
// Appends the directed edge a -> b with weight c and makes it the new head of a's
// edge list (the previous head becomes its successor in ne[]).
void add(int a,int b,int c)
{
    w[idx] = c;
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
    ++idx;
}
// Sort predicate for edge records: heavier edges come first.
int cmp(node x, node y)
{
    return y.val < x.val;
}
int pre[N];
void spfa(int s)
{
d[s] = 0;
memset(in_queue, 0, sizeof in_queue);
queue<int> q;
q.push(s);
in_queue[s] = 1;
while(!q.empty())
{
int x = q.front();
q.pop();
in_queue[x] = 0;
for(int i = h[x]; ~i; i = ne[i])
{
int y = e[i];
if(d[y] > d[x] + w[i])
{
d[y] = d[x] + w[i];
pre[y] = x;
if(!in_queue[y])
{
in_queue[y] = 1;
q.push(y);
}
}
}
}
}
// Sort predicate for plain values: descending order.
int cmpp(int x,int y)
{
    return y < x;
}
signed main()
{
read(n), read(m), read(k);
read(start), read(endd);
memset(pre, -1, sizeof pre);
memset(h, -1, sizeof h);
memset(d, inf, sizeof d);
for(int i = 0; i < m; i ++)
{
int a,b,c;
read(a), read(b), read(c);
add(a,b,c), add(b,a,c);
id[i].num = i * 2, id[i].val = c;//^1
}
sort(id, id + m, cmp);
for(int i = 0; i < k; i ++)
{
w[id[i].num] = 0, w[id[i].num ^ 1] = 0;
}
spfa(start);
int k_cnt = 0;
int temp = endd;
while(~temp)
{
int target = pre[temp];
for(int i = h[temp]; ~i; i = ne[i])
{
int y = e[i];
if(y == target && w[i] == 0)
{
k_cnt ++;
break;
}
}
temp = pre[temp];
}
k -= k_cnt;
if(k == 0)
{
printf("%lld\n", d[endd]);
exit(0);
}
//debug
/*
5 5 1
0 3
0 1 100
1 2 100
0 4 5
2 3 50
3 4 300
*/
int l_cnt = 0;
temp = endd;
while(~temp)
{
for(int i = h[temp]; ~i; i = ne[i])
{
int y = e[i];
if(y == pre[temp])
{
last[++ l_cnt] = w[i];
break;
}
}
temp = pre[temp];
}
sort(last + 1, last + 1 + l_cnt, cmpp);
int sum = 0;
for(int i = 1; i <= k; i ++)
{
sum += last[i];
}
printf("%lld\n", d[endd] - sum);
return 0;
}