求助,WA三个点
查看原帖
求助,WA三个点
308157
hansqin（楼主） 2021/8/9 17:27

开启 O2 优化（“吸氧”）后，WA 了第 2、9、10 三个测试点，下面是代码：

#include<bits/stdc++.h>
using namespace std;
// Capacity of the value array; it must exceed the largest n used.
// The segment tree below sizes its node arrays as N<<2.
constexpr int N = 114514;
using ll = long long;
// Initial values, stored 1-based in a[1..n] (a[0] is never read).
ll a[N];
// n: number of elements, m: number of operations, mod: modulus for all sums.
ll n, m, mod;
struct Segment_Tree{
	#define L(x) x<<1
	#define R(x) x<<1|1
	int x,y;
	ll k,tr[N<<2],tag[N<<2],tag2[N<<2];
	void push_up(int T){
		tr[T]=tr[L(T)]+tr[R(T)];
	}
	void push_down(int T,int l,int r){
		int mid=l+r>>1;
		tag2[L(T)]=(tag2[L(T)]*tag2[T])%mod;
		tag[L(T)]=(tag[L(T)]*tag2[T])%mod;
		tr[L(T)]=(tr[L(T)]*tag2[T])%mod;
		tag2[R(T)]=(tag2[R(T)]*tag2[T])%mod;
		tag[R(T)]=(tag[R(T)]*tag2[T])%mod;
		tr[R(T)]=(tr[R(T)]*tag2[T])%mod;	
		tag2[T]=1;
		tag[L(T)]=(tag[L(T)]+tag[T])%mod;
		tr[L(T)]=(tr[L(T)]+tag[T]*(mid-l+1))%mod;
		tag[R(T)]=(tag[R(T)]+tag[T])%mod;
		tr[R(T)]=(tr[R(T)]+tag[T]*(r-mid))%mod;	
		tag[T]=0;
	}
	void build(int T,int l,int r){
		tag[T]=0;
		tag2[T]=1;
		if(l==r){
			tr[T]=a[l];
			return ;
		}
		int mid=l+r>>1;
		build(L(T),l,mid);
		build(R(T),mid+1,r);
		push_up(T);
	}
	void add(int T, int l, int r){
		push_down(T,l,r);
		if(l>=x&&r<=y){
			tag[T]=k%mod;
			tr[T]+=(r-l+1)*k%mod;
			return ;
		}
		int mid=l+r>>1;
		if(y<=mid) add(L(T),l,mid);
		else if(x>mid) add(R(T),mid+1,r);
		else add(L(T),l,mid),add(R(T),mid+1,r);
		push_up(T);
	}
	void mul(int T,int l,int r)
	{
		
		if(l>=x&&r<=y)
		{
			tr[T]=(tr[T]*k)%mod;
			tag[T]=(tag[T]*k)%mod;
			tag2[T]=(tag2[T]*k)%mod;
			return ; 
		}
		push_down(T,l,r);
		int mid=l+r>>1;
		if(y<=mid) mul(L(T),l,mid);
		else if(x>mid) mul(R(T),mid+1,r);
		else mul(L(T),l,mid),mul(R(T),mid+1,r);
		push_up(T);
	}
	ll ask(int T,int l,int r){
		push_down(T,l,r);
		if(l>=x&&r<=y) return tr[T];
		int mid=l+r>>1; 
		if(y<=mid) return ask(L(T),l,mid);
		else if(x>mid) return ask(R(T),mid+1,r);
		else return ask(L(T),l,mid) + ask(R(T),mid+1,r);
	}
	void solve(){
		build(1,1,n);
		while(m--){
			int opt;
			cin>>opt>>x>>y;
			if(opt==1){
				cin>>k;
				mul(1,1,n);
			} 
			else if(opt==2){
				cin>>k;
				add(1,1,n);
			}
			else cout<<ask(1,1,n)%mod<<endl;
		}
	}
	#undef L
	#undef R
}str; 
int main()
{
	cin>>n>>m>>mod;
	for(int i=1;i<=n;i++)cin>>a[i];
	str.solve();
	return 0;
}
2021/8/9 17:27
加载中...