线段树求助
查看原帖
线段树求助
171487
cmll02楼主2020/6/6 21:27

闰土

#include <string.h>
#include <stdio.h>
inline unsigned long long read()
{
	unsigned long long num = 0; char c = getchar();
	while (c<48 || c>57)c = getchar();
	while (c >= 48 && c <= 57)num = (num << 3) + (num << 1) + (c ^ 48), c = getchar();
	return num;
}
template<typename T, typename V>
inline T max(T x, V y){ return x > y ? x : y; }
template<typename T, typename V>
inline T min(T x, V y){ return x < y ? x : y; }
template<typename T = int>
struct SegmentResult{
	T sum;
	SegmentResult(T s = 0) :sum(s){}
	SegmentResult operator=(SegmentResult q)
	{
		sum = q.sum;
		return *this;
	}
}; int P;
template<typename T = int>
class SegmentTree
{
public:
	static const int MAXN = 300086;
//private:
	T *a, *sum, *addv, *setv, NOSETVALUE;
	//T a[20], sum[20], addv[20], setv[20];
	int n, m;
	T GetLength(T t)
	{
		T p = 1;
		while (p < t)p *= 2;
		return p;
	}
public:
	SegmentTree(int n = 8) :n(GetLength(n)), m(n)
	{
		sum = new T[2 * this->n];
		addv = new T[2 * this->n];
		a = new T[2 * this->n];
		memset(sum, 0, sizeof(T) * 2 * this->n);
		memset(addv, 0, sizeof(T) * 2 * this->n);
		memset(a, 0, sizeof(T) * 2 * this->n);
		setv = new T[2 * this->n + 1];
		for (int i = 0; i <= 2 * n; i++)setv[i] = 1;
	}
	~SegmentTree(){ delete[] a; delete[] sum; delete[] addv; delete[] setv;}
private:
	T _sum;
	inline T read()
	{
		T num = 0; int f = 1; char c = getchar();
		while (c < 48 || c > 57){ if (c == '-')f = -1; c = getchar(); }
		while (c >= 48 && c <= 57)num = num * 10 + (c ^ 48), c = getchar();
		return num * f;
	}
	void readA()
	{
		for (int i = 1; i <= m; i++)
			a[i] = read();
		for (int i = m + 1; i <= n; i++)a[i] = 0;
	}
	void InitSum()
	{
		for (int i = 0; i < n; i++)sum[i + n] = a[i + 1], addv[i + n] = a[i + 1]/*, setv[i + n] = 0*/;
		for (int i = n - 1; i; i--)sum[i] = sum[i * 2] + sum[i * 2 + 1], addv[i] = 0;
	}
	void maintain(int o, int L, int R)
	{
		int lc = o * 2, rc = o * 2 + 1;
		sum[o] = 0;
		if (R > L)
		{
			sum[o] = sum[rc] + sum[lc];
			sum[o] *= setv[o];
		}
		sum[o] += addv[o] * (R - L + 1);
	}
	void Add(int o, int l, int r, int el, int er, T x)
	{
		if (el <= l&&er >= r)
		{
			addv[o] += x;
			addv[o] %= P;
		}
		else
		{
			int M = l + (r - l) / 2;
			if (M >= el)Add(o * 2, l, M, el, er, x);
			if (M < er)Add(o * 2 + 1, M + 1, r, el, er, x);
		}
		maintain(o, l, r);
	}
	void pushdown(int o, int l, int r)
	{
		if (r == l)return;
		int lc = o * 2, rc = o * 2 + 1;
		setv[lc] *= setv[o];
		setv[lc] %= P;
		setv[rc] *= setv[o];
		setv[rc] %= P;
		setv[o] = 1;
		if (addv[o])
		{
			addv[lc] += addv[o];
			addv[lc] %= P;
			addv[rc] += addv[o];
			addv[rc] %= P;
			addv[o] = 0;
		}
		int M = (l + r) / 2;
		maintain(lc, l, M); maintain(rc, M + 1, r);
	}
	void Set(int o, int l, int r, int el, int er, T v)
	{
		if (el <= l && r <= er)
		{
			setv[o] *= v;
			setv[o] %= P;
			addv[o] *= v;
			addv[o] %= P;
			maintain(o, l, r);
			return;
		}
		else pushdown(o, l, r);
		int lc = o * 2, rc = o * 2 + 1, M = l + (r - l) / 2;
		if (M >= el)Set(lc, l, M, el, er, v);
		if (M < er)Set(rc, M + 1, r, el, er, v);
		maintain(o, l, r);
	}
	SegmentResult<T> segmentresult = { 0 };
	void Query(int o, int l, int r, int el, int er, T add, T set)
	{
		if (el <= l&&er >= r)
		{
			segmentresult.sum += sum[o] * set % P + add*(r - l + 1) % P;
			segmentresult.sum %= P;
			return;
		}
		int M = l + (r - l) / 2;
		if (M >= el)Query(o * 2, l, M, el, er, add + addv[o], set*setv[o]);
		if (M < er)Query(o * 2 + 1, M + 1, r, el, er, add + addv[o], set*setv[o]);
	}
	bool SETED = 0;
public:
	/*请确保重载了对int型的乘法和加法(本函数使用快读)*/
	void scan(){ readA(); }
	void init(){ SETED = 1; InitSum(); }
	void add(int l, int r, T x){ if (!SETED)init(); Add(1, 1, n, l, r, x); }
	void set(int l, int r, T v){ if (!SETED)init(); Set(1, 1, n, l, r, v); }
	SegmentResult<T> query(int l, int r){ if (!SETED)init(); segmentresult.sum = 0; Query(1, 1, n, l, r, 0, 1); return segmentresult; }
	T& operator[](int index){ return (T&)a[index]; }
};
int main()
{
	int n;
	SegmentTree<unsigned long long> t(n = read());
	P = read();
	t.scan();
	int T = read();
	while (T--)
	{/*
		for (int i = 1; i <= n; i++)
		{
			printf("%d ", t.query(i, i).sum);
		}
		puts("");*/
		int i = read(), l = read(), r = read();
		if (i == 1)
		{
			int q = read();
			t.set(l, r, q);
		}
		else if (i == 2)
		{
			int q = read();
			t.add(l, r, q);
		}
		else
		{
			printf("%lld\n", t.query(l, r).sum%P);
		}
	}
	getchar();
	return 0;
}

T+WA

2020/6/6 21:27
加载中...