// Author's status note: this version only passed the m=2 tests and the last test case (原注: 只过了m=2的和最后一个点).
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = 1000010;  // capacity of every global array below
int n;                  // number of input values
int m;                  // target segment count (compared against s[n] in check)
int a[N];               // input values, 1-indexed
int sum[N];             // prefix sums: sum[i] = a[1] + ... + a[i], sum[0] = 0
int dp[N];              // dp[i]: penalised optimum for prefix 1..i, filled in by check
int s[N];               // s[i]: number of segments in the solution behind dp[i]
// One run of the monotone decision queue: every position in [l, r] currently
// takes its transition from decision point k (see G). Member order matters:
// runs are built elsewhere with aggregate initialisers {l, r, k}.
struct a1{
    int l;  // first position served by this run
    int r;  // last position served by this run
    int k;  // decision index j used for every position in the run
};
a1 z[N];    // storage for the queue; the live runs are z[L..R]
int L, R;   // front and back indices of the live runs in z
// Value of serving position j from decision point i: dp[i] plus the cost of
// one segment covering a[i+1..j], where that cost is (segment sum + 1)^2.
int G(int i, int j) {
    const int shifted = sum[j] - sum[i] + 1;
    return dp[i] + shifted * shifted;
}
int Get(int i,int j){
int l=z[i].l,r=n;
while(l<=r){
int mid=(l+r)>>1;
if(G(z[i].k,mid)>G(j,mid))r=mid-1;
else l=mid+1;
}
return r;
}
// WQS feasibility test for per-segment penalty k.
// For every prefix 1..q it computes
//   dp[q] = min over j < q of G(j, q) + k
//         = min over j < q of dp[j] + (sum[q] - sum[j] + 1)^2 + k,
// and s[q], the segment count of the chosen solution. It returns whether
// the whole array (prefix n) used at most m segments.
// The minimum is found with the monotone decision-point queue z[L..R]. Each
// run {l, r, k} means positions l..r currently take decision k, and the live
// runs partition the positions that are not yet computed.
// NOTE(review): this relies on decision monotonicity (the per-segment cost
// satisfying the quadrangle inequality, e.g. non-negative a[]) — confirm
// against the problem constraints. On ties the earlier decision is kept;
// s[] is not tie-broken by segment count.
int check(int k){
    dp[0] = 0;
    s[0] = 0;
    L = 1;
    R = 1;
    z[1] = {1, n, 0};   // at first, decision 0 serves every position 1..n
    for (int q = 1; q <= n; q++) {
        // Invariant: runs z[L..R] partition positions q..n, so z[L] covers q.
        const int best = z[L].k;
        dp[q] = G(best, q) + k;
        s[q] = s[best] + 1;

        // Retire position q: drop runs ending at q; the front run now starts at q+1.
        while (L <= R && z[L].r <= q) L++;
        if (L > R) break;   // only reachable when q == n
        z[L].l = q + 1;

        // dp[q] is final only now, so q can be compared as a candidate.
        // Every comparison below happens at a position >= q+1.
        // Pop back runs that q already beats at their first position: by
        // monotonicity, q beats them on the whole run.
        while (L <= R && G(z[R].k, z[R].l) > G(q, z[R].l)) R--;
        if (L > R) {
            z[++R] = {q + 1, n, q};   // q serves every remaining position
        } else if (G(z[R].k, n) > G(q, n)) {
            // q takes over a suffix of the tail run: split it at the boundary.
            z[R].r = Get(R, q);
            ++R;
            z[R] = {z[R - 1].r + 1, n, q};
        }
    }
    return s[n] <= m;
}
signed main(){
cin>>n>>m;
for(int q=1;q<=n;q++)cin>>a[q],sum[q]=sum[q-1]+a[q];
int l=0,r=1e18;
while(l<=r){
int mid=(l+r)>>1;
if(check(mid))r=mid-1;
else l=mid+1;
}
check(l);
cout<<dp[n]-l*m;
return 0;
}