调了半天了，测试点 #6-#10 WA。
应该是 sze 维护出错了，但找不到哪里错了。
#include <algorithm>
#include <climits>
#include <cstdio>
#include <cstring>
#include <iostream>
// Capacity of the node pool. Node 0 is the "no node" sentinel, so at most
// MaxN - 1 real nodes can be allocated.
enum { MaxN = 1000005 };
// Per-node fields indexed by node id: left child, right child, parent,
// stored key and subtree size. A child/parent value of 0 means "none".
int lt[MaxN], rt[MaxN], fa[MaxN], val[MaxN], sze[MaxN];
int root = 0; // id of the current root; 0 when the tree is empty
int tot = 0;  // number of node ids handed out so far (ids are never reused)
// True when x is the right child of its parent fa[x].
inline bool Wrt(int x)
{
    const int parent = fa[x];
    return x == rt[parent];
}
// Single rotation: lifts x one level above its parent while preserving the
// in-order sequence, then recomputes the subtree sizes of the two nodes whose
// children changed (the old parent first, because it now sits below x).
// x must have a parent (fa[x] != 0).
inline void Rot(int x)
{
    const int parent = fa[x];
    const int grand = fa[parent];
    const bool fromLeft = (lt[parent] == x);
    // The subtree on x's inner side switches over to become parent's child.
    const int moved = fromLeft ? rt[x] : lt[x];

    // x takes parent's place under grand (grand == 0: parent was the root).
    fa[x] = grand;
    if (grand != 0) {
        if (lt[grand] == parent)
            lt[grand] = x;
        else
            rt[grand] = x;
    }

    // parent goes to the opposite side of x; `moved` fills the slot x left.
    if (fromLeft) {
        rt[x] = parent;
        lt[parent] = moved;
    } else {
        lt[x] = parent;
        rt[parent] = moved;
    }
    fa[parent] = x;
    if (moved != 0)
        fa[moved] = parent;

    sze[parent] = 1 + sze[lt[parent]] + sze[rt[parent]];
    sze[x] = 1 + sze[lt[x]] + sze[rt[x]];
}
inline void Splay(int x, const int &tar)
{
for (int f; fa[x] ^ tar; Rot(x))
{
if (fa[f = fa[x]] ^ tar)
Rot(Wrt(x) ^ Wrt(f) ? x : f);
}
if (!tar)
root = x;
//printf("root=%d\n", root);
}
// Searches for a node whose value equals x. On success that node is splayed
// to the root and its id is returned; on a miss the tree is left untouched
// and 0 is returned.
inline int Find(int x)
{
    int cur = root;
    while (cur != 0 && val[cur] != x)
        cur = (x < val[cur]) ? lt[cur] : rt[cur];
    if (cur != 0)
        Splay(cur, 0);
    return cur;
}
// Inserts one copy of x (duplicates allowed; equal keys descend to the right),
// splays the new node to the root and returns its id.
// Previously `dir` was read uninitialized when the tree was empty, which also
// wrote the new id into lt[0]/rt[0] and corrupted the null sentinel; and the
// function fell off the end without returning a value (UB in C++).
inline int Insert(int x)
{
    int parent = 0;
    int goRight = 0;
    for (int u = root; u != 0; ) {
        parent = u;
        goRight = !(x < val[u]);
        u = goRight ? rt[u] : lt[u];
    }

    const int id = ++tot;
    val[id] = x;
    fa[id] = parent;
    lt[id] = rt[id] = 0;
    sze[id] = 1;

    // Only link into a real parent; an empty tree must not touch node 0.
    if (parent != 0) {
        if (goRight)
            rt[parent] = id;
        else
            lt[parent] = id;
    }

    // Splaying to the root recomputes sze on every node of the insertion path,
    // so ancestors' sizes need no separate update here.
    Splay(id, 0);
    return id;
}
// Merges two subtrees after detaching them from their parent: the maximum
// node of x's tree is splayed to the root (leaving it with no right child)
// and y is hung as its right child. Expects every value under x to precede
// every value under y; Delete calls it with two non-zero subtrees.
inline void Join(int x, int y)
{
    fa[y] = 0;
    fa[x] = 0;

    int mx = x;
    for (; rt[mx] != 0; mx = rt[mx]) {
    }
    Splay(mx, 0);

    fa[y] = mx;
    rt[mx] = y;
    sze[mx] = 1 + sze[lt[mx]] + sze[rt[mx]];
}
inline void Delete(int x)
{
int n = Find(x);
if (!n)
return;
//printf("sze[root]=%d\n", sze[root]);
Splay(n, 0);
if (!lt[n] || !rt[n]) {
fa[root = lt[n] + rt[n]] = 0;
}
else {
Join(lt[n], rt[n]);
}
lt[n] = rt[n] = 0;
sze[n] = 0;
}
// Returns the value of the rk-th smallest element (1-based, duplicates counted
// separately) and splays that node to the root.
// When rk is outside [1, sze[root]] it returns INT_MIN; previously control
// fell off the end of this non-void function, which is undefined behavior.
inline int GetByRank(int rk)
{
    if (rk < 1 || rk > sze[root])
        return INT_MIN;

    int u = root;
    while (u != 0) {
        const int leftSize = sze[lt[u]];
        if (rk <= leftSize) {
            u = lt[u];
        } else if (rk == leftSize + 1) {
            Splay(u, 0);
            return val[u];
        } else {
            rk -= leftSize + 1;
            u = rt[u];
        }
    }
    return INT_MIN; // only reachable if subtree sizes are inconsistent
}
inline int GetRank(int x)
{
int y = Find(x);
//printf("y=%d\n", y);
int rank = sze[lt[y]] + 1;
while (GetByRank(rank) > x)
--rank;
while (GetByRank(rank) < x)
++rank;
return rank;
}
inline int GetPre(int x)
{
int u = root;
int ans = -2147473647;
while (u)
{
if (val[u] >= x)
{
u = lt[u];
}
else
{
if (val[u] > ans)
ans = val[u];
u = rt[u];
}
}
return ans;
}
inline int GetBak(int x)
{
int u = root;
int ans = 2147473647;
while (u)
{
if (val[u] <= x)
{
u = rt[u];
}
else
{
if (val[u] < ans)
ans = val[u];
u = lt[u];
}
}
return ans;
}
// Reads n operations "tp x" and dispatches them:
//   1 insert x, 2 delete one x, 3 rank of x, 4 value of rank x,
//   5 predecessor of x, 6 successor of x. Unknown tp values are ignored.
int main()
{
    // The node arrays have static storage and start zeroed; no memset needed.
    int n;
    if (scanf("%d", &n) != 1)
        return 1; // no operation count: nothing to do
    // `n-- > 0` also guards against a negative count looping for ~2^31 rounds.
    while (n-- > 0) {
        int tp, x;
        // Stop on truncated/malformed input instead of using uninitialized values.
        if (scanf("%d%d", &tp, &x) != 2)
            break;
        switch (tp) {
        case 1:
            Insert(x);
            break;
        case 2:
            Delete(x);
            break;
        case 3:
            printf("%d\n", GetRank(x));
            break;
        case 4:
            printf("%d\n", GetByRank(x));
            break;
        case 5:
            printf("%d\n", GetPre(x));
            break;
        case 6:
            printf("%d\n", GetBak(x));
            break;
        default:
            break;
        }
    }
    return 0;
}