Splay求助
查看原帖
Splay求助
37084
Yemaster楼主2021/1/10 20:24

调了半天了,WA #6-#10

应该是sze维护出错了,但找不到哪错了

#include <algorithm>
#include <climits>
#include <cstdio>
#include <cstring>
#include <iostream>
#define MaxN 1000005

// Splay-tree node pool, indexed by node id; id 0 stands for "no node".
int lt[MaxN];   // left child of each node
int rt[MaxN];   // right child of each node
int fa[MaxN];   // parent of each node (0 for the root)
int val[MaxN];  // key stored in each node
int sze[MaxN];  // number of nodes in the subtree rooted here
int root = 0;   // id of the current root (0 when the tree is empty)
int tot = 0;    // last node id handed out; ids are never reused

// Is x the right child of its parent?
inline bool Wrt(int x)
{
    const int parent = fa[x];
    return x == rt[parent];
}

// Single rotation that lifts x one level, above its parent y.
// x takes y's slot under the grandparent z. y becomes x's child on the side
// opposite x's old position. x's inner subtree (mid) moves over to y.
// Subtree sizes are recomputed for y first (now lower), then for x.
inline void Rot(int x)
{
    const int y = fa[x];
    const int z = fa[y];
    const bool xIsLeft = (lt[y] == x);
    const int mid = xIsLeft ? rt[x] : lt[x];

    // Hang x where y used to be under z (z == 0 means y was the root).
    fa[x] = z;
    if (z)
    {
        if (lt[z] == y)
            lt[z] = x;
        else
            rt[z] = x;
    }

    // y drops below x; x's inner child becomes y's child on x's old side.
    fa[y] = x;
    if (xIsLeft)
    {
        rt[x] = y;
        lt[y] = mid;
    }
    else
    {
        lt[x] = y;
        rt[y] = mid;
    }
    if (mid)
        fa[mid] = y;

    sze[y] = 1 + sze[lt[y]] + sze[rt[y]];
    sze[x] = 1 + sze[lt[x]] + sze[rt[x]];
}

inline void Splay(int x, const int &tar)
{
    for (int f; fa[x] ^ tar; Rot(x))
    {
        if (fa[f = fa[x]] ^ tar)
            Rot(Wrt(x) ^ Wrt(f) ? x : f);
    }
    if (!tar)
        root = x;
    //printf("root=%d\n", root);
}

inline int Find(int x)
{
    int u = root;
    while (u)
    {
    	//printf("u=%d val[u]=%d x=%d\n", u, val[u], x);
        if (val[u] == x)
            break;
        else if (val[u] < x)
            u = rt[u];
        else
            u = lt[u];
    }
    //printf("u=%d\n", u);
    if (u)
        Splay(u, 0);
    return u;
}

// Insert key x as a new node (duplicates allowed; equal keys descend right),
// splay it to the root, and return its id.
//
// Ancestor sizes are deliberately not incremented on the way down. Every
// ancestor of the new leaf is rotated during Splay(id, 0), and Rot recomputes
// sizes from children whose sizes are already correct, so all sizes are
// right once the splay finishes.
inline int Insert(int x)
{
    int u = root, f = 0;
    bool goRight = false;   // side of f where the new node attaches
    while (u)
    {
        f = u;
        goRight = !(x < val[u]);
        u = goRight ? rt[u] : lt[u];
    }
    int id = ++tot;
    val[id] = x;
    fa[id] = f;
    sze[id] = 1;
    // In an empty tree there is no parent to link to. Writing lt[0]/rt[0]
    // here would corrupt the null sentinel.
    if (f)
    {
        if (goRight)
            rt[f] = id;
        else
            lt[f] = id;
    }
    Splay(id, 0);
    return id;   // declared int: falling off the end is UB in C++
}

// Merge two detached subtrees where every key under x precedes every key
// under y. The maximum node of x's tree is splayed to the top (becoming the
// global root), so it has no right child, and y is attached there.
inline void Join(int x, int y)
{
    fa[x] = 0;
    fa[y] = 0;

    int mx = x;
    for (; rt[mx]; mx = rt[mx])
    {
    }
    Splay(mx, 0);

    rt[mx] = y;
    fa[y] = mx;
    sze[mx] = 1 + sze[lt[mx]] + sze[rt[mx]];
}

inline void Delete(int x)
{
    int n = Find(x);
    if (!n)
    	return;
	//printf("sze[root]=%d\n", sze[root]);
    Splay(n, 0);
    if (!lt[n] || !rt[n]) {
    	fa[root = lt[n] + rt[n]] = 0;
    }
	else {
		Join(lt[n], rt[n]);	
	}
	lt[n] = rt[n] = 0;
	sze[n] = 0;
}

// Return the key with 1-based rank rk in sorted order, and splay its node to
// the root. rk must lie in [1, size of tree]. For any other rk (or an empty
// tree) the walk falls off the tree and 0 is returned. The original had no
// return on that path, which is undefined behavior for an int function.
inline int GetByRank(int rk)
{
    int u = root;
    while (u)
    {
        if (rk <= sze[lt[u]])
        {
            u = lt[u];
        }
        else
        {
            rk -= sze[lt[u]] + 1;
            if (rk == 0)
            {
                Splay(u, 0);
                return val[u];
            }
            u = rt[u];
        }
    }
    return 0;
}

inline int GetRank(int x)
{
    int y = Find(x);
    //printf("y=%d\n", y);
    int rank = sze[lt[y]] + 1;
    while (GetByRank(rank) > x)
    	--rank;
    while (GetByRank(rank) < x)
    	++rank;
    return rank;
}

inline int GetPre(int x)
{
    int u = root;
    int ans = -2147473647;
    while (u)
    {
        if (val[u] >= x)
        {
            u = lt[u];
        }
        else
        {
            if (val[u] > ans)
                ans = val[u];
            u = rt[u];
        }
    }
    return ans;
}

inline int GetBak(int x)
{
    int u = root;
    int ans = 2147473647;
    while (u)
    {
        if (val[u] <= x)
        {
            u = rt[u];
        }
        else
        {
            if (val[u] < ans)
                ans = val[u];
            u = lt[u];
        }
    }
    return ans;
}

// Read n operations "tp x" and apply them to the splay tree:
//   1 insert x, 2 delete one x, 3 print rank of x, 4 print key of rank x,
//   5 print predecessor of x, 6 print successor of x.
// The node arrays are globals with static storage duration, so they start
// zeroed and need no memset.
int main()
{
    int n;
    if (scanf("%d", &n) != 1)
        return 0;   // no input: nothing to do (n would be uninitialized)
    while (n-- > 0)
    {
        int tp, x;
        if (scanf("%d%d", &tp, &x) != 2)
            break;   // truncated input: stop instead of using garbage
        switch (tp)
        {
        case 1:
            Insert(x);
            break;
        case 2:
            Delete(x);
            break;
        case 3:
            printf("%d\n", GetRank(x));
            break;
        case 4:
            printf("%d\n", GetByRank(x));
            break;
        case 5:
            printf("%d\n", GetPre(x));
            break;
        case 6:
            printf("%d\n", GetBak(x));
            break;
        default:
            break;   // unknown operation codes are ignored, as before
        }
    }
    return 0;
}
2021/1/10 20:24
加载中...