#include <bits/stdc++.h>
using namespace std;
// Shared solver state (single translation unit, competitive-programming style).
using ll = long long;
constexpr int N = 500005;        // array capacity; indices up to n + 1 are used
ll dp[N];                        // dp[i]: penalized minimum cost for the first i sorted values (filled by chk)
ll pr[N];                        // pr[i]: prefix sums of the sorted values a[1..i]
ll l = 0, r = 500000000000LL;    // bounds of the binary search on the per-group penalty
int a[N];                        // input values, 1-indexed, sorted in main
int s[N];                        // s[i]: number of groups on the path chosen for dp[i]
int n, m;                        // number of values / required number of groups
int hd, tl;                      // head and tail indices into the candidate queue q
// Entry of the candidate queue used by chk()/f().
//   x: a candidate split point (dp[x] plus the segment (x, pos] is a transition).
//   k: first position at which the next queue entry becomes at least as good as x;
//      the last entry in the queue carries n + 1 ("never overtaken").
struct D {
    int x;
    int k;
    D(int split = 0, int until = 0) : x(split), k(until) {}
} q[N];
// Cost of the segment a[l+1..r]: the sum of |a[t] - a[med]| over the segment, where
// a[med], med = (l + r + 1) / 2, is its median. Relies on a[] being sorted ascending
// and pr[] holding its prefix sums (both prepared in main). Returns 0 when l == r.
ll cmp(int l, int r){
    const int med = (l + r + 1) >> 1;
    const ll pivot = a[med];
    // Elements after the median are >= pivot; elements up to it are <= pivot.
    const ll above = (pr[r] - pr[med]) - pivot * (r - med);
    const ll below = pivot * (med - l) - (pr[med] - pr[l]);
    return above + below;
}
// Binary search for the first position pos in (j, n] at which the later candidate j
// is at least as good as the earlier candidate i (i < j), i.e.
// dp[j] + cmp(j, pos) <= dp[i] + cmp(i, pos); returns n + 1 if there is none.
// Assumes that predicate is monotone in pos (false ... false true ... true).
int f(int i, int j){
    int lo = j;
    int hi = n + 1;
    while (hi - lo > 1){
        const int pos = (lo + hi) >> 1;
        const bool earlierWins = dp[i] + cmp(i, pos) < dp[j] + cmp(j, pos);
        if (earlierWins) lo = pos;
        else hi = pos;
    }
    return hi;
}
bool chk(ll mid){
hd = 1, tl = 0;
q[++tl] = D(0, n + 1);
for (int i = 1;i <= n; i++){
while (hd < tl && q[hd].k <= i) hd++;
int x = q[hd].x;
dp[i] = dp[x] + cmp(x, i) + mid;
s[i] = s[x] + 1;
while (hd < tl && f(q[tl].x, i) <= q[tl - 1].k) tl--;
q[tl].k = f(q[tl].x, i);
q[++tl] = D(i, n + 1);
}
return s[n] >= m;
}
int main(){
scanf("%d%d", &n, &m);
for (int i = 1;i <= n; i++) scanf("%d", a + i);
sort (a + 1, a + n + 1);
for (int i = 1;i <= n; i++) pr[i] = pr[i - 1] + a[i];
while (l < r - 1) {
ll m = (l + r)>>1;
if (chk(m)) l = m;
else r = m;
}
chk(l);
printf("%lld", dp[n] - 1LL * m * l);
return 0;
}
// Note: the last two test cases time out (exceed the time limit).