替罪羊树MLE求解
查看原帖
替罪羊树MLE求解
238500
yaorz楼主2021/10/16 13:50

用替罪羊树做的，结果 MLE 了，想请问一下为什么会 MLE？

评测结果

#include <bits/stdc++.h>
using namespace std;
// Capacity of the node pool a[] and the rebuild buffer rt[]. insert() takes a
// fresh slot (++cnt) for every new node and never reuses purged slots.
const int N = 100010;
// Scapegoat balance factor used by judge() to decide when to rebuild.
const double alpha = 0.75;
// One scapegoat-tree node. Nodes live in the static pool a[]; index 0 acts
// as the null child, and the code relies on a[0]'s fields staying zero.
struct Node {
	int ls;    // left child index (0 = none)
	int rs;    // right child index (0 = none)
	int size;  // subtree weight maintained by update() and checked by judge()
	int val;   // key stored at this node
	int tim;   // multiplicity of val at this node; 0 means lazily deleted
	int res;   // number of stored values in the subtree, counting multiplicity
};
Node a[N];
int n, id, x;   // number of operations; current operation code and argument
int rt[N];      // in-order buffer rt[1..tot] filled by unfold() for rebuild()
int tot;        // number of entries currently in rt[]
int cnt;        // number of node slots handed out so far by insert()
int root;       // index of the tree root (0 = empty tree)
// Recompute node t's cached subtree aggregates from its children.
//   size: number of tree nodes in the subtree, including lazily deleted
//         ones (tim == 0). Tree depth depends on node count, so this is the
//         quantity judge()'s balance test should measure.
//   res:  number of stored values in the subtree, counting multiplicity;
//         rank_down/rank_up/rank navigate by it.
// Previously size was computed exactly like res, so judge()'s
// "res < alpha * size" condition could never hold, and empty nodes weighed 0
// in the balance test.
void update(int t) {
	a[t].size = a[a[t].ls].size + a[a[t].rs].size + 1;
	a[t].res = a[a[t].ls].res + a[a[t].rs].res + a[t].tim;
}
// Append the live nodes (tim != 0) of subtree t to rt[++tot] in in-order
// (i.e. sorted) order. The walk along right children is done by looping
// instead of recursing; the visiting order is unchanged.
void unfold(int t) {
	while (t != 0) {
		unfold(a[t].ls);
		if (a[t].tim != 0) {
			rt[++tot] = t;
		}
		t = a[t].rs;
	}
}
// Build a perfectly balanced subtree from the nodes rt[l .. r-1] (half-open
// range, already in sorted order) and return its root index, or 0 when the
// range is empty.
int rebuild(int l, int r) {
	if(l >= r) return 0;
	int mid = (l + r) >> 1;
	// The left half is [l, mid). The original code recursed on [1, mid), which
	// re-linked nodes lying outside this range: whenever l > 1 the result
	// contains shared subtrees and cycles, so later recursive traversals never
	// terminate and overflow the stack (the reported MLE).
	a[rt[mid]].ls = rebuild(l, mid);
	a[rt[mid]].rs = rebuild(mid + 1, r);
	update(rt[mid]);
	return rt[mid];
}
// Rebuild the subtree rooted at t into a balanced shape, dropping lazily
// deleted nodes; t is rewritten to the new subtree root (0 if nothing is left).
void balance(int &t) {
	tot = 0;                        // empty the in-order buffer rt[]
	unfold(t);                      // rt[1..tot] = live nodes, sorted
	const int first = 1;
	const int past_last = tot + 1;  // rebuild() takes a half-open range
	t = rebuild(first, past_last);
}
// Decide whether node t must be rebuilt: true when either child's size
// exceeds alpha * size(t), or when the stored-value count res drops below
// alpha * size(t).
bool judge(int t) {
	const double limit = alpha * a[t].size;
	const int heavier = max(a[a[t].ls].size, a[a[t].rs].size);
	return heavier > limit || a[t].res < limit;
}
// Remove one copy of x from the subtree rooted at t. Deletion is lazy: only
// the node's multiplicity tim is decremented; balance() later drops nodes
// whose tim is 0. t may be replaced by balance() on the way back up.
void del(int &t, int x) {
	if(t == 0) return;  // x is not in the tree: nothing to remove
	if(a[t].val == x) {
		// The node may already be empty (tim == 0) if x was deleted before;
		// never let tim go negative, or res counts would be corrupted.
		if(a[t].tim > 0) a[t].tim--;
	} else if(x < a[t].val) {
		del(a[t].ls, x);
	} else {
		del(a[t].rs, x);
	}
	update(t);  // recomputes res from scratch, so no manual res-- is needed
	if(judge(t)) balance(t);
}
// Insert one copy of x into the subtree rooted at t, rebalancing on the way
// back up. t is a reference to the parent's child slot (or to root itself),
// so assigning it links a newly created node into the tree. On the very
// first insertion that slot is root, so no separate root fix-up is needed.
void insert(int &t, int x) {
	if(t == 0) {
		t = ++cnt;  // take the next unused pool slot
		a[t].val = x;
		a[t].ls = a[t].rs = 0;
		a[t].tim = a[t].size = a[t].res = 1;
		return;
	}
	if(a[t].val == x) a[t].tim++;  // duplicate key: bump multiplicity
	else if(x < a[t].val) insert(a[t].ls, x);
	else insert(a[t].rs, x);
	update(t);
	if(judge(t)) balance(t);
}
// Count the stored values strictly smaller than x in the subtree rooted at t.
int rank_down(int t, int x) {
	int smaller = 0;
	while (t != 0) {
		const int left_count = a[a[t].ls].res;
		if (a[t].tim != 0 && a[t].val == x) {
			return smaller + left_count;
		}
		if (x < a[t].val) {
			t = a[t].ls;
		} else {
			// Everything on the left plus this node's copies are < x.
			smaller += left_count + a[t].tim;
			t = a[t].rs;
		}
	}
	return smaller;
}
// Return 1 + (number of stored values <= x) in the subtree rooted at t,
// i.e. the rank of the first value greater than x.
int rank_up(int t, int x) {
	int at_most = 0;
	while (t != 0) {
		const int left_count = a[a[t].ls].res;
		if (a[t].tim != 0 && a[t].val == x) {
			return at_most + left_count + a[t].tim + 1;
		}
		if (x < a[t].val) {
			t = a[t].ls;
		} else {
			at_most += left_count + a[t].tim;
			t = a[t].rs;
		}
	}
	return at_most + 1;
}
// Return the x-th smallest stored value (1-based, counting multiplicity) in
// the subtree rooted at t. A node with no children returns its own value.
int rank(int t, int x) {
	for (;;) {
		const int lc = a[t].ls;
		const int rc = a[t].rs;
		if (lc == rc) {
			return a[t].val;
		}
		const int left_count = a[lc].res;
		if (x <= left_count) {
			t = lc;
			continue;
		}
		const int through_here = left_count + a[t].tim;
		if (x <= through_here) {
			return a[t].val;
		}
		x -= through_here;
		t = rc;
	}
}
// Read n operations "id x" and dispatch each one:
// 1 insert, 2 delete, 3 rank of x, 4 x-th smallest, 5 predecessor,
// anything else successor.
int main() {
	cin >> n;
	while(n--) {
		scanf("%d%d", &id, &x);
		switch(id) {
		case 1:
			insert(root, x);
			break;
		case 2:
			del(root, x);
			break;
		case 3:
			printf("%d\n", rank_down(root, x) + 1);
			break;
		case 4:
			printf("%d\n", rank(root, x));
			break;
		case 5:
			printf("%d\n", rank(root, rank_down(root, x)));
			break;
		default:
			printf("%d\n", rank(root, rank_up(root, x)));
			break;
		}
	}
	return 0;
}
2021/10/16 13:50
加载中...