// 不太会用乘法的懒标记,求教! (Question: "I'm not sure how to use a lazy tag for multiplication — any help appreciated!")
#include<algorithm>
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
// Competitive-programming shortcut: every `int` below is really `long long`
// (hence `signed main` and %lld in scanf/printf further down).
#define int long long
const int N=5e5+10; // element capacity; the tree array below reserves 4*N nodes
int n,m,mod; // n: number of elements, m: number of operations, mod: modulus for all arithmetic
// One segment-tree node, stored in heap order (children of rt are 2*rt and 2*rt+1).
struct yee
{
int l,r; // inclusive index range [l, r] covered by this node
int v,lazy,la; // v: range sum (mod `mod`); lazy: pending addend; la: pending multiplier (1 = none)
}f[4*N];
// Heap-order index of the left child of node `rt`.
int left(int rt)
{
    const int child=rt+rt;
    return child;
}
// Heap-order index of the right child of node `rt` (one past the left child).
int right(int rt)
{
    const int child=rt+rt+1;
    return child;
}
void push_up(int rt)
{
f[rt].v=f[left(rt)].v+f[right(rt)].v;
}
void push_down(int rt)
{
int lazy=f[rt].lazy,la=f[rt].la;
while(1)//加
{
int l=f[left(rt)].l,r=f[left(rt)].r;
f[left(rt)].v+=(r-l+1)*lazy;
f[left(rt)].v%=mod;
l=f[right(rt)].l,r=f[right(rt)].r;
f[right(rt)].v+=(r-l+1)*lazy;
f[right(rt)].v%=mod;
f[left(rt)].lazy+=lazy;
f[right(rt)].lazy+=lazy;
f[left(rt)].lazy%=mod;
f[right(rt)].lazy%=mod;
f[rt].lazy=0;
break;
}
while(1)//乘
{
int l=f[left(rt)].l,r=f[left(rt)].r;
f[left(rt)].v*=la,f[left(rt)].v+=(la-1)*lazy,f[left(rt)].v%=mod;
f[left(rt)].la*=la,f[left(rt)].la%=mod,f[left(rt)].la=max((long long)1,f[left(rt)].la);
f[right(rt)].v*=la,f[right(rt)].v+=(la-1)*lazy,f[right(rt)].v%=mod;
f[right(rt)].la*=la,f[right(rt)].la%=mod,f[right(rt)].la=max((long long)1,f[right(rt)].la);
f[rt].la=1;
return;
}
}
void build(int rt,int l,int r)
{
f[rt].l=l;
f[rt].r=r;
f[rt].lazy=f[rt].v=0;
f[rt].la=1;
if(l==r)
return;
int mid=(l+r)/2;
build(left(rt),l,mid);
build(right(rt),mid+1,r);
}
void add(int rt,int ql,int qr,int v)//加
{
int l=f[rt].l,r=f[rt].r;
int mid=(l+r)/2;
if(ql<=l&&r<=qr)
{
f[rt].v+=(r-l+1)*v;
f[rt].v%=mod;
f[rt].lazy+=v;
f[rt].lazy%=mod;
return;
}
push_down(rt);
if(mid>=ql)
add(left(rt),ql,qr,v);
if(mid+1<=qr)
add(right(rt),ql,qr,v);
push_up(rt);
}
void u(int rt,int ql,int qr,int v)//乘
{
v%=mod;
int l=f[rt].l,r=f[rt].r;
int mid=(l+r)/2;
if(ql<=l&&r<=qr)
{
f[rt].v*=v;
f[rt].v%=mod;
f[rt].la*=v;
f[rt].la%=mod;
return;
}
push_down(rt);
if(mid>=ql)
u(left(rt),ql,qr,v);
if(mid+1<=qr)
u(right(rt),ql,qr,v);
push_up(rt);
}
int find(int rt,int ql,int qr)
{
int l=f[rt].l,r=f[rt].r,ans=0;
int mid=(l+r)/2;
if(ql<=l&&r<=qr)
return f[rt].v;
push_down(rt);
if(ql<=mid)
ans+=find(left(rt),ql,qr);
if(mid+1<=qr)
ans+=find(right(rt),ql,qr);
push_up(rt);
ans%=mod;
return ans;
}
signed main()
{
scanf("%lld %lld %lld",&n,&m,&mod);
build(1,1,n);
for(int i=1;i<=n;i++)
{
int a;
scanf("%lld",&a);
add(1,i,i,a);
}
while(m--)
{
int o,op,p,p1=0;
scanf("%lld %lld %lld",&o,&op,&p);
if(o==1)
{
scanf("%d",&p1);
u(1,op,p,p1);//根,左右端点和乘的值
}
else if(o==2)
{
scanf("%d",&p1);
add(1,op,p,p1);//根,左右端点和加的值
}
else
printf("%lld\n",find(1,op,p));
}
return 0;
}