萌新求助：无旋 Treap 树套树只得 90 分
查看原帖
萌新求助,无旋Treap树套树90分
310525
老莽莽穿一切（楼主） 2021/8/18 14:02
#include<bits/stdc++.h>
using namespace std;

// N: bound for the 1-indexed sequence array (n is presumably <= 5e4 — TODO confirm limits).
// maxv: upper end of the value range that getKth's binary search explores.
// inf: INT_MAX (0x7fffffff); -inf / inf are returned when no predecessor / successor exists.
const int N = 5e4 + 10, maxv = 1e8, inf = 0x7fffffff;

// One node of a non-rotating (split/merge) treap, stored in the global pool `trp`.
//   tr  : the stored value (BST key)
//   hp  : random heap priority; merge keeps the smaller priority on top
//   siz : number of nodes in this subtree
//   l, r: pool indices of the children (0 = none)
struct Treap {
	int tr, hp, siz, l, r;

	// The default priority multiplies two rand() results in unsigned arithmetic and masks to
	// 31 bits. `rand() * rand()` in signed int overflows (undefined behaviour) whenever
	// RAND_MAX is 2^31-1, as with glibc. When RAND_MAX is small the product already fits in
	// 31 bits, so the value is unchanged there.
	Treap(int tr = 0,
	      int hp = static_cast<int>((static_cast<unsigned>(std::rand()) * static_cast<unsigned>(std::rand())) & 0x7fffffffu),
	      int siz = 1, int l = 0, int r = 0)
		: tr(tr), hp(hp), siz(siz), l(l), r(r) {}
};

// T: number of queries; n: sequence length; tot: index of the last allocated treap node
// (index 0 is used as the null node).
int T, n, tot;
// a: current sequence values (1-indexed); tr: treap root index for each segment-tree node.
int a[N], tr[N << 2];
// Node pool shared by every inner treap; nodes removed by del are never reclaimed.
Treap trp[N << 6];

// Inner treap operations (their int arguments are treap root indices into trp).
void split(int, int&, int&, int), merge(int&, int, int), ins(int&, int), del(int&, int);
int newNode(int), kth(int, int), getLess(int, int), getFront(int, int), getBack(int, int);

// Segment temporarily stands for the (node, left, right) parameters of the outer segment tree.
#define Segment int, int, int

// Outer segment-tree operations; each segment node holds a treap of the values in its range.
void build(Segment), mdf(Segment, int, int);
int getLess(Segment, int, int, int), getFront(Segment, int, int, int), getBack(Segment, int, int, int);

#undef Segment

// Root expands to the arguments addressing the whole segment tree: node 1 covering [1, n].
#define Root 1, 1, n

inline int getKth(int l, int r, int k) {
	int L = 1, R = maxv;
	while (L < R) {
		int mid = L + R + 1 >> 1;
		if (getLess(Root, l, r, mid) < k) {
			L = mid;
		} else {
			R = mid - 1;
		}
	}
	return L;
}

// Reads n, the query count T and the sequence a[1..n], builds the structure, then answers:
//   1 x y z : rank of z in a[x..y] (number of smaller values + 1)
//   2 x y z : z-th smallest value in a[x..y]
//   3 x y   : set a[x] = y
//   4 x y z : largest value < z in a[x..y] (-inf if none)
//   5 x y z : smallest value > z in a[x..y] (inf if none)
int main() {
	// trp[0] is the null node; it must report size 0 (the Treap constructor defaults siz to 1).
	trp[0].siz = 0;
	scanf("%d%d", &n, &T);
	for (int idx = 1; idx <= n; ++idx) {
		scanf("%d", a + idx);
	}
	build(Root);
	while (T --) {
		int op, x, y, z;
		scanf("%d", &op);
		if (op == 1) {
			scanf("%d%d%d", &x, &y, &z);
			printf("%d\n", getLess(Root, x, y, z) + 1);
		} else if (op == 2) {
			scanf("%d%d%d", &x, &y, &z);
			printf("%d\n", getKth(x, y, z));
		} else if (op == 3) {
			scanf("%d%d", &x, &y);
			// mdf removes the old value a[x] from each treap, so a[x] is updated afterwards.
			mdf(Root, x, y);
			a[x] = y;
		} else if (op == 4) {
			scanf("%d%d%d", &x, &y, &z);
			printf("%d\n", getFront(Root, x, y, z));
		} else if (op == 5) {
			scanf("%d%d%d", &x, &y, &z);
			printf("%d\n", getBack(Root, x, y, z));
		}
	}
}

// Inside the treap helpers below, ls / rs name the children of the node indexed by p.
#define ls trp[p].l
#define rs trp[p].r

// Allocates the next pool slot, fills it with a fresh node holding v (default priority,
// size 1, no children) and returns its index.
inline int newNode(int v) {
	tot += 1;
	trp[tot] = Treap(v);
	return tot;
}
// Splits the treap rooted at p by value: x receives the nodes with tr <= v, y the nodes with
// tr > v. Subtree sizes along the visited path are recomputed.
// The recursive calls pass p's own child link as an output reference, so p's links are
// rewritten in place while the recursion unwinds.
inline void split(int p, int&x, int&y, int v) {
	if (!p) {
		x = y = 0;
		return;
	}
	if (trp[p].tr <= v) {
		// p and its left subtree go to x; the <= v part of the right subtree becomes p's new
		// right child, the rest goes to y.
		x = p; split(rs, rs, y, v);
	} else {
		// p and its right subtree go to y; the > v part of the left subtree becomes p's new
		// left child, the rest goes to x.
		y = p; split(ls, x, ls, v);
	}
	trp[p].siz = trp[ls].siz + trp[rs].siz + 1;
}
// Merges treaps x and y into p; every value in x is assumed to be <= every value in y.
// The root with the smaller priority (hp) goes on top, with ties favouring x, so the priorities
// form a min-heap. An empty side (index 0) makes p the other side.
inline void merge(int&p, int x, int y) {
	if (!x || !y) {
		p = x | y;
		return;
	}
	if (trp[x].hp <= trp[y].hp) {
		// x stays the root; merge y into its right spine.
		p = x; merge(rs, rs, y);
	} else {
		// y stays the root; merge x into its left spine.
		p = y; merge(ls, x, ls);
	}
	trp[p].siz = trp[ls].siz + trp[rs].siz + 1;
}
// Returns the value of the k-th node (1-based, in-order) of the treap rooted at p.
// k is assumed to lie in [1, size of the treap].
inline int kth(int p, int k) {
	for (;;) {
		const int leftSize = trp[ls].siz;
		if (k <= leftSize) {
			p = ls;
			continue;
		}
		k -= leftSize + 1;
		if (k == 0) {
			return trp[p].tr;
		}
		p = rs;
	}
}
// Inserts value v (duplicates allowed) into the treap rooted at p: split off the values > v,
// append a new node after the <= v part, then reattach the larger values.
inline void ins(int&p, int v) {
	int bigger = 0;
	split(p, p, bigger, v);
	const int fresh = newNode(v);
	merge(p, p, fresh);
	merge(p, p, bigger);
}
inline void del(int&p, int v) {
	int x, y;
	split(p, p, y, v); split(p, x, p, v - 1);
	merge(p, ls, rs);
	merge(p, x, p); merge(p, p, y);
}
// Read-only treap queries. They walk down from the root instead of splitting and re-merging,
// so the treap is never restructured.
// The previous split/merge versions received the root by value. Whenever re-merging produced a
// different root node (e.g. two nodes sharing a priority), the caller's stored root (tr[...])
// silently went stale and later operations corrupted the tree.
// All three rely on the in-order sequence being sorted (as ins/del maintain) and on
// trp[0].siz == 0.

// Returns how many values in the treap rooted at p are strictly less than v.
inline int getLess(int p, int v) {
	int res = 0;
	while (p) {
		if (trp[p].tr < v) {
			// p and its entire left subtree are < v.
			res += trp[trp[p].l].siz + 1;
			p = trp[p].r;
		} else {
			p = trp[p].l;
		}
	}
	return res;
}
// Returns the largest value < v in the treap rooted at p, or -inf if there is none.
inline int getFront(int p, int v) {
	int res = -inf;
	while (p) {
		if (trp[p].tr < v) {
			// Candidate; anything larger that is still < v lies to the right.
			res = trp[p].tr;
			p = trp[p].r;
		} else {
			p = trp[p].l;
		}
	}
	return res;
}
// Returns the smallest value > v in the treap rooted at p, or inf if there is none.
inline int getBack(int p, int v) {
	int res = inf;
	while (p) {
		if (trp[p].tr > v) {
			// Candidate; anything smaller that is still > v lies to the left.
			res = trp[p].tr;
			p = trp[p].l;
		} else {
			p = trp[p].r;
		}
	}
	return res;
}

#undef ls
#undef rs

// Outer segment-tree helpers: node p covers positions [L, R]; its children are 2p and 2p+1.
#define Segment int p, int L, int R
#define ls (p << 1)
#define rs (ls | 1)
#define mid (L + R >> 1)
#define Ls ls, L, mid
#define Rs rs, mid + 1, R

// Builds segment node p over [L, R]: every a[L..R] is inserted into this node's treap, then
// both halves are built unless the range is a single position.
inline void build(Segment) {
	for (int pos = L; pos <= R; ++pos) {
		ins(tr[p], a[pos]);
	}
	if (L != R) {
		build(Ls);
		build(Rs);
	}
}
// Replaces the old value a[pos] with v in every segment node on the path from node p down to
// the leaf for pos. a[pos] must still hold the old value when this is called.
inline void mdf(Segment, int pos, int v) {
	for (;;) {
		del(tr[p], a[pos]);
		ins(tr[p], v);
		if (L == R) {
			return;
		}
		// Descend into the half containing pos (the parameters are local copies).
		if (pos <= mid) {
			p = ls;
			R = mid;
		} else {
			p = rs;
			L = mid + 1;
		}
	}
}
// Counts the values in a[l..r] that are strictly less than v, summing the treap counts of the
// segment nodes that exactly cover the query range.
inline int getLess(Segment, int l, int r, int v) {
	if (l <= L && R <= r) {
		return getLess(tr[p], v);
	}
	const int fromLeft = l <= mid ? getLess(Ls, l, r, v) : 0;
	const int fromRight = mid < r ? getLess(Rs, l, r, v) : 0;
	return fromLeft + fromRight;
}
// Largest value < v among a[l..r], or -inf if none: the maximum over the covering nodes.
inline int getFront(Segment, int l, int r, int v) {
	if (l <= L && R <= r) {
		return getFront(tr[p], v);
	}
	const int fromLeft = l <= mid ? getFront(Ls, l, r, v) : -inf;
	const int fromRight = mid < r ? getFront(Rs, l, r, v) : -inf;
	return max({-inf, fromLeft, fromRight});
}
// Smallest value > v among a[l..r], or inf if none: the minimum over the covering nodes.
inline int getBack(Segment, int l, int r, int v) {
	if (l <= L && R <= r) {
		return getBack(tr[p], v);
	}
	const int fromLeft = l <= mid ? getBack(Ls, l, r, v) : inf;
	const int fromRight = mid < r ? getBack(Rs, l, r, v) : inf;
	return min({inf, fromLeft, fromRight});
}

#undef Segment
#undef ls
#undef rs
#undef mid
#undef Ls
#undef Rs
2021/8/18 14:02
加载中...