刚学Treap,WA11个点,求助
查看原帖
刚学Treap,WA11个点,求助
305532
mango09楼主2021/7/13 11:50

感觉没什么问题,对着书看了一遍,没看出来,求大佬帮忙!

#include <iostream>
#include <cstdio>
#include <cstdlib>
#define re register
#define Google namespace
#define Cookies std
using Google Cookies;

inline int read()
{
	re int x = 0, f = 0;
	re char c = getchar();
	while (c < '0' || c > '9')
	{
		f |= c == '-';
		c = getchar();
	}
	while (c >= '0' && c <= '9')
	{
		x = (x << 3) + (x << 1) + (c ^ 48);
		c = getchar();
	}
	return f ? -x : x;
}

const int MAXN = 1e5 + 5;
const int INF = 0x7fffffff;

struct treap
{
	int l, r, val, dat, cnt, siz;
}t[MAXN];

int rt, tot;

int add(int val)
{
	t[++tot] = treap{0, 0, val, rand(), 1, 1};
	return tot;
}

void pushup(int pos)
{
	t[pos].siz = t[t[pos].l].siz + t[t[pos].r].siz + t[pos].cnt;
}

void build()
{
	add(-INF), add(INF);
	rt = 1;
	t[1].r = 2;
	pushup(rt);
}
 
void zig(int &pos)
{
	int lson = t[pos].l;
	t[pos].l = t[lson].r, t[lson].r = pos;
	pos = lson;
}

void zag(int &pos)
{
	int rson = t[pos].r;
	t[pos].r = t[rson].l, t[rson].l = pos;
	pos = rson;
}

void insert(int &pos, int x)
{
	if (!pos)
	{
		pos = add(x);
		return;
	}
	if (x == t[pos].val)
	{
		t[pos].cnt++;
		pushup(pos);
		return;
	}
	else if (x < t[pos].val)
	{
		insert(t[pos].l, x);
		if (t[pos].dat < t[t[pos].l].dat) //不满足,右旋 
		{
			zig(pos);
		}
	}
	else
	{
		insert(t[pos].r, x);
		if (t[pos].dat < t[t[pos].r].dat)
		{
			zag(pos);
		}
	}
	pushup(pos);
}

void remove(int &pos, int x)
{
	if (!pos)
	{
		return;
	}
	if (x == t[pos].val)
	{
		if (t[pos].cnt >= 2)
		{
			t[pos].cnt--; //减少副本数 
			pushup(pos);
			return;
		}
		else if (t[pos].l || t[pos].r) //向下旋转 
		{
			if (!t[pos].r || t[t[pos].l].dat > t[t[pos].r].dat)
			{
				zig(pos);
				remove(t[pos].r, x);
			}
			else
			{
				zag(pos);
				remove(t[pos].l, x);
			}
			pushup(pos);
		}
		else
		{
			pos = 0; //叶节点删除 
		}
		return;
	}
	else if (x < t[pos].val)
	{
		remove(t[pos].l, x);
	}
	else
	{
		remove(t[pos].r, x);
	}
	pushup(pos);
}

int getrank(int pos, int x)
{
	if (!pos)
	{
		return 0;
	}
	if (x == t[pos].val)
	{
		return t[t[pos].l].siz + 1;
	}
	else if (x < t[pos].val)
	{
		return getrank(t[pos].l, x);
	}
	else
	{
		return getrank(t[pos].r, x) + t[t[pos].l].siz + t[pos].cnt;
	}
}

int getval(int pos, int x)
{
	if (!pos)
	{
		return INF;
	}
	if (x <= t[t[pos].l].siz)
	{
		return getval(t[pos].l, x);
	}
	else if (x <= t[t[pos].l].siz + t[pos].cnt)
	{
		return t[pos].val;
	}
	else
	{
		return getval(t[pos].r, x - t[t[pos].l].siz - t[pos].cnt);
	}
}

int pre(int x)
{
	int ans = 1, pos = rt;
	while (pos)
	{
		if (x == t[pos].val)
		{
			if (t[pos].l)
			{
				pos = t[pos].l;
				while (t[pos].r)
				{
					pos = t[pos].r; //左子树往上走 
				}
				ans = pos;
			}
			break;
		}
		if (x > t[pos].val && t[pos].val > t[ans].val)
		{
			ans = pos;
		}
		if (x < t[pos].val)
		{
			pos = t[pos].l;
		}
		else
		{
			pos = t[pos].r;
		}
	}
	return t[ans].val;
}

int nxt(int x)
{
	int ans = 2, pos = rt;
	while (pos)
	{
		if (x == t[pos].val)
		{
			if (t[pos].r)
			{
				pos = t[pos].r;
				while (t[pos].l)
				{
					pos = t[pos].l;
				}
				ans = pos;
			}
			break;
		}
		if (x < t[pos].val && t[pos].val < t[ans].val)
		{
			ans = pos;
		}
		if (x < t[pos].val)
		{
			pos = t[pos].l;
		}
		else
		{
			pos = t[pos].r;
		}
	}
	return t[ans].val;
}

int main()
{
	build();
	int n = read();
	while (n--)
	{
		int op = read(), x = read();
		switch (op)
		{
			case 1:
				insert(rt, x);
				break;
			case 2:
				remove(rt, x);
				break;
			case 3:
				printf("%d\n", getrank(rt, x) - 1);
				break;
			case 4:
				printf("%d\n", getval(rt, x + 1));
				break;
			case 5:
				printf("%d\n", pre(x));
				break;
			case 6:
				printf("%d\n", nxt(x));
				break;
		}
	}
	return 0;
}
2021/7/13 11:50
加载中...