第 3 个测试点 WA 了,一时找不出错在哪里,有做过这道题的同学可以帮忙看看代码吗?
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
#include<cmath>
using namespace std;
// Every later `int` token in this file expands to `long long`, so all counters,
// sums and array elements use long long arithmetic. This is also why the entry
// point has to be spelled `signed main` instead of `int main`.
#define int long long
// Shorthand for the `register` storage-class specifier used on loop counters.
// NOTE(review): `register` is no longer a valid specifier as of C++17 — confirm
// the target standard before relying on this macro.
#define R register
// Alias for double; not referenced anywhere in the visible source.
#define ld double
// Reads the next decimal integer from stdin.
//
// Every character before the first digit is skipped; if a '-' occurs anywhere
// in that skipped run the result is negated. The single character following
// the digit run is consumed as well.
//
// Returns 0 if the input ends before any digit is found (instead of looping
// forever on EOF).
inline int read(){
    int value=0,sign=1;
    // Keep getchar()'s result in an integer type: narrowing it to char makes
    // EOF indistinguishable from data and feeds isdigit() negative values.
    int c=getchar();
    while(c!=EOF&&!isdigit(c)){
        if(c=='-')sign=-1;
        c=getchar();
    }
    while(c!=EOF&&isdigit(c)){
        value=value*10+(c-'0');
        c=getchar();
    }
    return value*sign;
}
// Capacity of the per-index arrays s, b, rt and wh.
const int N=6e5+50;
// One node of the persistent segment tree stored in the pool t[]:
//   l, r - indices (into t[]) of the left / right child,
//   dat  - how many inserted values fall inside this node's value range.
// add() pre-increments cnt before using it, so slot 0 is never allocated and
// stays zero-initialised; it acts as the empty tree (rt[0] and missing children).
struct node{
int l,r,dat;
}t[N*24];
// Max-heap of candidates: (segment sum, index of the segment's right end).
priority_queue<pair<int,int> >q;
// n          - element count read from input (incremented once in main for the empty prefix)
// m          - number of distinct prefix sums after sort + unique
// k, ll, rr  - remaining input parameters, in reading order
// s[i]       - prefix sum, later overwritten with its 1-based rank inside b[]
// b[]        - sorted distinct prefix sums (rank -> value)
// cnt        - number of tree nodes allocated so far
// rt[i]      - root of the tree version containing s[1..i]
// wh[i]      - rank (1 = smallest) of the left prefix currently queued for right end i
// ans        - running total that main prints
int n,m,k,ll,rr,s[N],b[N],cnt,rt[N],wh[N],ans;
// Creates a new tree version that equals version `pre` plus one occurrence of
// rank k, and stores the new root index in p.
//
// Only the nodes on the root-to-leaf path for k are freshly allocated; every
// other child link is shared with the previous version.
//   p      - receives the index of the new root
//   pre    - root of the version to extend (0 = empty tree)
//   [l, r] - rank range covered by the root
//   k      - rank being inserted
void add(int &p,int pre,int l,int r,int k){
    int *slot=&p;                 // link that must point at the next fresh node
    for(;;){
        *slot=++cnt;
        const int cur=*slot;
        t[cur]=t[pre];            // inherit both children and the old count
        ++t[cur].dat;
        if(l==r)break;
        const int mid=(l+r)>>1;
        if(k<=mid){
            slot=&t[cur].l;
            pre=t[pre].l;
            r=mid;
        }else{
            slot=&t[cur].r;
            pre=t[pre].r;
            l=mid+1;
        }
    }
}
// Returns the k-th smallest value (k is 1-based) among the elements present
// in version p2 but not in version p1.
//
//   p1, p2 - roots of the older and the newer tree version
//   [l, r] - rank range covered by those roots
//   k      - wanted order statistic inside the difference of the two versions
// The result is the real value b[rank], not the rank itself.
int query(int p1,int p2,int l,int r,int k){
    while(l!=r){
        const int mid=(l+r)>>1;
        // Elements of the difference that live in the left half.
        const int inLeft=t[t[p2].l].dat-t[t[p1].l].dat;
        if(k<=inLeft){
            p1=t[p1].l;
            p2=t[p2].l;
            r=mid;
        }else{
            k-=inLeft;
            p1=t[p1].r;
            p2=t[p2].r;
            l=mid+1;
        }
    }
    return b[l];
}
// Reads n, k, ll, rr and n values, then prints the sum of the k largest sums
// taken over all contiguous segments whose length lies in [ll, rr].
//
// Method: prefix sums are shifted by one so that s[1] is the empty prefix.
// For a right end i the admissible left prefixes are the indices
// max(1, i-rr) .. i-ll, i.e. tree version rt[i-ll] minus rt[max(0, i-rr-1)].
// Each right end first contributes its best segment (smallest left prefix);
// whenever a candidate is taken from the heap, the next-smallest left prefix
// for the same right end is queued.
//
// Declared `signed` because the token `int` is redefined as `long long`.
signed main(){
    n=read();k=read();
    ll=read();rr=read();

    // Index 1 holds the empty prefix, index i holds the sum of the first i-1 values.
    s[1]=b[1]=0;n++;
    for(int i=2;i<=n;i++){
        const int a=read();
        b[i]=s[i]=s[i-1]+a;
    }

    // Coordinate-compress the prefix sums: s[i] becomes a rank, b[rank] the value.
    std::sort(b+1,b+n+1);
    m=std::unique(b+1,b+n+1)-b-1;
    for(int i=1;i<=n;i++){
        s[i]=std::lower_bound(b+1,b+m+1,s[i])-b;
        add(rt[i],rt[i-1],1,m,s[i]);
    }

    // Seed the heap with the best segment of every right end that can host
    // a segment of length at least ll.
    for(int i=1;i<=n;i++){
        if(i<=ll)continue;
        const int lo=(i-rr-1)<0?0:i-rr-1;
        const int smallest=query(rt[lo],rt[i-ll],1,m,1);
        wh[i]=1;
        q.push(std::make_pair(b[s[i]]-smallest,i));
    }

    // Stop early if the heap runs dry so q.top() is never called on an empty queue.
    while(k-->0&&!q.empty()){
        const int best=q.top().first;
        const int x=q.top().second;
        q.pop();
        ans+=best;

        // The window rt[x-ll] minus rt[lo] holds exactly (x-ll)-lo prefixes.
        // Asking query() for a larger rank would overshoot and return a value
        // that is not a left prefix of x, creating a candidate that does not exist.
        const int lo=(x-rr-1)<0?0:x-rr-1;
        const int avail=(x-ll)-lo;
        if(wh[x]<avail){
            ++wh[x];
            const int next=query(rt[lo],rt[x-ll],1,m,wh[x]);
            q.push(std::make_pair(b[s[x]]-next,x));
        }
    }
    printf("%lld\n",ans);
    return 0;
}