80分TLE#8#9
如果用O2优化会快很多,但是#8会MLE...
估计是常数问题了,请问下面的代码(可以只关注几个修改和查询操作,这些地方个人认为是最有可能优化的)有什么地方可以优化常数的?
#include<bits/stdc++.h>
#define chd(o) (tr[tr[o].fa].ch[1]==o?1:0)
using namespace std;
const int INF=0x3f3f3f3f;
struct Node
{
int v,sz,ch[2],lsum,rsum,sub,sum,fa,rev,setc,maxv;
}tr[4500005];
int n,m,a[500005],tot,rt;
void read(int& x)
{
char c=getchar();
int f=1;x=0;
while(!isdigit(c))
{
if(c=='-')f=-1;
c=getchar();
}
while(isdigit(c))
{
x=(x<<3)+(x<<1)+(c^48);
c=getchar();
}
x*=f;
}
inline Node newnode(int v)
{
Node nd;
nd.sum=nd.v=nd.maxv=v;
nd.sub=nd.lsum=nd.rsum=max(0,v);
nd.ch[0]=nd.ch[1]=nd.rev=nd.fa=0;
nd.setc=-INF;
nd.sz=1;
return nd;
}
inline void update(int o)
{
tr[o].sum=tr[tr[o].ch[0]].sum+tr[tr[o].ch[1]].sum+tr[o].v;
tr[o].lsum=max(max(tr[tr[o].ch[0]].lsum,tr[tr[o].ch[0]].sum+tr[o].v),
tr[tr[o].ch[0]].sum+tr[o].v+tr[tr[o].ch[1]].lsum);
tr[o].rsum=max(max(tr[tr[o].ch[1]].rsum,tr[tr[o].ch[1]].sum+tr[o].v),
tr[tr[o].ch[1]].sum+tr[o].v+tr[tr[o].ch[0]].rsum);
tr[o].sub=max(max(tr[tr[o].ch[0]].sub,tr[tr[o].ch[1]].sub),
tr[tr[o].ch[0]].rsum+tr[o].v+tr[tr[o].ch[1]].lsum);
tr[o].sz=tr[tr[o].ch[0]].sz+tr[tr[o].ch[1]].sz+1;
tr[o].maxv=max(max(tr[tr[o].ch[0]].maxv,tr[tr[o].ch[1]].maxv),tr[o].v);
}
inline void SET(int o,int k)
{
tr[o].sum=tr[o].sz*k;
tr[o].lsum=tr[o].sub=tr[o].rsum=max(0,k)*tr[o].sz;
tr[o].setc=tr[o].v=tr[o].maxv=k;
tr[o].rev=0;
}
inline void pushdown(int o)
{
if(tr[o].rev)
{
if(tr[o].ch[0])tr[tr[o].ch[0]].rev^=1;
if(tr[o].ch[1])tr[tr[o].ch[1]].rev^=1;
if(tr[tr[o].ch[0]].setc>-INF)tr[tr[o].ch[0]].rev=0;
if(tr[tr[o].ch[1]].setc>-INF)tr[tr[o].ch[1]].rev=0;
swap(tr[tr[o].ch[0]].lsum,tr[tr[o].ch[0]].rsum);
swap(tr[tr[o].ch[1]].lsum,tr[tr[o].ch[1]].rsum);
swap(tr[o].ch[0],tr[o].ch[1]);
tr[o].rev=0;
}
if(tr[o].setc>-INF)
{
if(tr[o].ch[0])SET(tr[o].ch[0],tr[o].setc);
if(tr[o].ch[1])SET(tr[o].ch[1],tr[o].setc);
tr[o].setc=-INF;
}
}
void build(int& o,int l,int r)
{
int mid=(l+r)>>1;
tr[o=++tot]=newnode(a[mid]);
if(l<mid)
{
build(tr[o].ch[0],l,mid-1);
tr[tr[o].ch[0]].fa=o;
}
if(r>mid)
{
build(tr[o].ch[1],mid+1,r);
tr[tr[o].ch[1]].fa=o;
}
update(o);
}
inline void rotate(int o)
{
int p=tr[o].fa,gp=tr[p].fa;
if(!p)return ;
int d=chd(o);
tr[p].ch[d]=tr[o].ch[d^1];
tr[tr[o].ch[d^1]].fa=p;
if(gp)tr[gp].ch[chd(p)]=o;
tr[o].fa=gp;
tr[o].ch[d^1]=p;
tr[p].fa=o;
update(p);
}
int kth(int o,int k)
{
while(1)
{
pushdown(o);
if(tr[tr[o].ch[0]].sz+1==k)return o;
if(k<=tr[tr[o].ch[0]].sz)o=tr[o].ch[0];
else
{
k-=tr[tr[o].ch[0]].sz+1;
o=tr[o].ch[1];
}
}
return o;
}
void splay(int& rt,int k)
{
int o=kth(rt,k);
while(tr[o].fa)
{
int p=tr[o].fa;
if(tr[p].fa)chd(o)==chd(p)?rotate(p):rotate(o);
rotate(o);
}
update(o);//在rotate里面就不要update(o)了,在这里update(o)
rt=o;
}
void chop(int rt,int p,int& ltr,int& rtr)//此处rt不要传引用
{
splay(rt,p+1);
rtr=tr[rt].ch[1];
tr[rtr].fa=0;
tr[rt].ch[1]=0;
ltr=rt;
update(ltr);
update(rtr);
}
void split(int l,int r,int& ltr,int& mid,int& rtr)
{
chop(rt,r,mid,rtr);
chop(mid,l-1,ltr,mid);
}
void merge(int ltr,int rtr,int& nrt)
{
splay(ltr,tr[ltr].sz);
tr[ltr].ch[1]=rtr;
tr[rtr].fa=ltr;
update(rtr);
update(ltr);
nrt=ltr;
}
void flip(int l,int r)
{
int ltr,mid,rtr;
split(l,r,ltr,mid,rtr);
tr[mid].rev^=1;
swap(tr[mid].lsum,tr[mid].rsum);
merge(ltr,mid,rt);
merge(rt,rtr,rt);
}
void same(int l,int r,int k)
{
int ltr,mid,rtr;
split(l,r,ltr,mid,rtr);
SET(mid,k);
merge(ltr,mid,rt);
merge(rt,rtr,rt);
}
int qsum(int l,int r)
{
int ltr,mid,rtr,ans;
split(l,r,ltr,mid,rtr);
ans=tr[mid].sum;
merge(ltr,mid,rt);
merge(rt,rtr,rt);
return ans;
}
void remove(int l,int r)
{
int ltr,mid,rtr;
split(l,r,ltr,mid,rtr);
merge(ltr,rtr,rt);
}
int main()
{
read(n);
read(m);
for(int i=1;i<=n;i++)read(a[i]);
tr[0].maxv=-INF;
a[0]=a[n+1]=-INF;
build(rt,0,n+1);
for(int i=1;i<=m;i++)
{
char op[10];
scanf("%s",op);
if(op[0]=='I')
{
int l,len,ltr,rtr;
read(l);
read(len);
if(len)
{
chop(rt,l,ltr,rtr);
while(len--)
{
int x;
read(x);
tr[++tot]=newnode(x);
merge(ltr,tot,ltr);
}
merge(ltr,rtr,rt);
}
}
else if(op[0]=='D')
{
int l,len,r;
read(l);
read(len);
r=l+len-1;
if(len)remove(l,r);
}
else if(op[0]=='M'&&op[2]=='K')
{
int l,len,k,r;
read(l);
read(len);
r=l+len-1;
read(k);
if(len)same(l,r,k);
}
else if(op[0]=='R')
{
int l,len,r;
read(l);
read(len);
r=l+len-1;
if(len)flip(l,r);
}
else if(op[0]=='M'&&op[2]=='X')
{
int ans=tr[rt].sub;
printf("%d\n",ans?ans:tr[rt].maxv);
}
else if(op[0]=='G')
{
int l,len,r;
read(l);
read(len);
r=l+len-1;
if(!len)printf("0\n");
else printf("%d\n",qsum(l,r));
}
}
}