玄学
查看原帖
玄学
341177
江南__楼主2020/6/19 17:42
#include <cstdio> 
#define SZ 2000009 
#define INF 2147483647 
using namespace std; 

int n, va[SZ], ch[SZ][2], fa[SZ], num[SZ], sum[SZ], cnt = 0; 
int q[SZ], sq = 0, rt; 

int read(); 
void update(int x) 
	{sum[x] = num[x] + sum[ch[x][0]] + sum[ch[x][1]]; } 
void build() {
	rt = ++cnt; va[rt] = INF; 
	ch[rt][0] = ++cnt; fa[cnt] = rt; 
	va[cnt] = -INF; 
} 
int get_pre(int u, int nva) {
	if (!u) return 0; 
	if (va[u] >= nva) return get_pre(ch[u][0], nva); 
	int ret = get_pre(ch[u][1], nva); 
	return ret ? ret : u; 
} 
int get_pst(int u, int nva) {
	if (!u) return 0; 
	if (va[u] <= nva) return get_pst(ch[u][1], nva); 
	int ret = get_pst(ch[u][0], nva); 
	return ret ? ret : u; 
} 
void rotate(int x) {
	int y = fa[x], z = fa[y]; 
	int l = ch[y][1] == x, r = l ^ 1; 
	ch[y][l] = ch[x][r]; ch[x][r] = y; 
	if (ch[y][l]) fa[ch[y][l]] = y; 
	fa[y] = x; fa[x] = z; 
	update(y), update(x); 
	if (!z) return ; 
	l = ch[z][1] == y; 
	ch[z][l] = x; 
	update(z); 
} 
void Splay(int x, int ufa) {
	while (fa[x] != ufa) {
		int y = fa[x], z = fa[y]; 
		if (z != ufa) ch[y][1] == x ^ ch[z][1] == y ? rotate(x) : rotate(y); 
		rotate(x); 
	} 
	rt = ufa ? rt : x; 
} 
void insert(int nva) {
	int y = get_pre(rt, nva + 1); 
	if (va[y] == nva) {
		Splay(y, 0); 
		++sum[y], ++num[y]; 
		return ; 
	} 
	int z = get_pst(rt, nva); 
	Splay(z, 0); 
	Splay(y, z); 
	int x = sq ? q[sq--] : ++cnt; 
	va[x] = nva; ++num[x], ++sum[x]; 
	fa[x] = y; ch[y][1] = x; 
	update(y), update(z); 
} 
void delt(int nva) {
	int y = get_pre(rt, nva); 
	int z = get_pst(rt, nva); 
	Splay(z, 0), Splay(y, z); 
	int x = ch[y][1]; 
	if (num[x] > 1) {
		--num[x], --sum[x]; 
		update(y), update(x); 
		return ; 
	} 
	ch[y][1] = 0; fa[x] = 0; 
	num[x] = sum[x] = va[x] = 0; 
	ch[x][0] = ch[x][1] = 0; 
	q[++sq] = x; 
	update(y), update(z); 
} 
int get_rk(int nva) {
	int x = get_pre(rt, nva + 1); 
	Splay(x, 0); 
	return sum[ch[x][0]] + 1; 
} 
int get_va(int u, int x) {
	if (sum[ch[u][0]] < x && (sum[ch[u][0]] + num[u]) >= x) return va[u]; 
	if (sum[ch[u][0]] < x) return get_va(ch[u][1], x - sum[ch[u][0]] - num[u]); 
	else return get_va(ch[u][0], x); 
} 


int main() {
//	freopen("a.in", "r", stdin); 
//	freopen("a.out", "w", stdout); 
	n = read(); build(); 
	while (n--) {
		int opt = read(), x = read(); 
		if (opt == 1) insert(x); 
		else if (opt == 2) delt(x); 
		else if (opt == 3) printf("%d\n", get_rk(x)); 
		else if (opt == 4) printf("%d\n", get_va(rt, x)); 
		else if (opt == 5) printf("%d\n", va[get_pre(rt, x)]); 
		else printf("%d\n", va[get_pst(rt, x)]); 
	} 
	return 0; 
} 

int read() {
	int x = 0, f = 1; char c = getchar(); 
	while (c > '9' || c < '0') f = (c == '-') ? -1 : f, c = getchar(); 
	while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); 
	return x * f; 
} 

上面这份代码是T了最后一个点的, 但是如果把81行get_rk函数里的int x = get_pre(rt, nva + 1)改成 int x = get_pst(rt, nva - 1)就过了, 这是为什么?

2020/6/19 17:42
加载中...