splay 全 WA 求助
查看原帖
splay 全 WA 求助
118196
zimujun楼主2021/2/17 16:32

RT

平衡树板子 A 了,修改了输入规则之后交到这里全 WA

#include "bits/stdc++.h"

namespace Basic {
  template <typename Temp> inline void read(Temp & res) {
    Temp fh = 1; res = 0; char ch = getchar();
    for(; !isdigit(ch); ch = getchar()) if(ch == '-') fh = -1;
    for(; isdigit(ch); ch = getchar()) res = (res << 3) + (res << 1) + (ch ^ '0');
    res = res * fh;
  }
}
using namespace Basic;
using namespace std;

namespace Splay {
  #define ls a[t].ch[0]
  #define rs a[t].ch[1]
  #define root a[0].ch[1]
  const int Maxn = 1e5 + 5;
  int ncnt;

  struct Node {
    int data, cnt, sum, ch[2], fa; 
  } a[Maxn];
  
  inline int getid(int t) {
    return a[a[t].fa].ch[1] == t;
  }
  
  inline void update(int t) {
    a[t].sum = a[ls].sum + a[rs].sum + a[t].cnt;
  }
  
  inline void rotate(int t) {
    int f1 = a[t].fa, f2 = a[f1].fa, k1 = getid(t), k2 = getid(f1);
    a[f2].ch[k2] = t; a[t].fa = f2;
    a[f1].ch[k1] = a[t].ch[k1 ^ 1]; a[a[t].ch[k1 ^ 1]].fa = f1;
    a[t].ch[k1 ^ 1] = f1, a[f1].fa = t;
    update(f1); update(t);
  }
  
  inline void splay(int t, int to) {
    int f1 = a[t].fa, f2 = a[f1].fa, k1 = getid(t), k2 = getid(f1);
    while(f1 ^ to) {
      if(f2 == to) rotate(t);
      else if(k1 ^ k2) {
        rotate(t); rotate(t);
      } else {
        rotate(f1); rotate(t);
      }
      f1 = a[t].fa, f2 = a[f1].fa, k1 = getid(t), k2 = getid(f1);
    }
  }
  
  inline int find(int x) {
    int t = root;
    while(t) {
      if(x == a[t].data) return t;
      else if(x < a[t].data) t = a[t].ch[0];
      else t = a[t].ch[1];
    }
    return 0;
  }
  
  inline void Newnode(int x, int f, int id) {
    ncnt++;
    a[ncnt].data = x; a[ncnt].cnt = a[ncnt].sum = 1;
    a[ncnt].fa = f; a[ncnt].ch[0] = a[ncnt].ch[1] = 0;
    a[f].ch[id] = ncnt;
    splay(ncnt, 0);
  }
  
  inline void Insert(int x) {
    if(ncnt == 0) {
      Newnode(x, 0, 1);
      return;
    }
    int t = root;
    while(t) {
      a[t].sum++;
      if(x == a[t].data) {
        a[t].cnt++;
        return;
      } else if(x < a[t].data) {
        if(a[t].ch[0]) t = a[t].ch[0];
        else {
          Newnode(x, t, 0);
          return;
        }
      } else {
        if(a[t].ch[1]) t = a[t].ch[1];
        else {
          Newnode(x, t, 1);
          return;
        }
      }
    }
  }
  
  inline int upper(int x) {
    int t = find(x); splay(t, 0);
    if(!a[t].ch[1]) return 0;
    t = a[t].ch[1];
    while(a[t].ch[0]) t = a[t].ch[0];
    return t;
  }
  inline int lower(int x) {
    int t = find(x); splay(t, 0);
    if(!a[t].ch[0]) return 0;
    t = a[t].ch[0];
    while(a[t].ch[1]) t = a[t].ch[1];
    return t;
  }
  
  inline void Delete(int x) {
    int t = find(x);
    splay(t, 0);
    a[t].cnt--;
    if(!a[t].cnt) {
      int lid = lower(x), uid = upper(x);
      if(lid) {
        splay(lid, t);
        int L = a[t].ch[0], R = a[t].ch[1];
        a[R].fa = L; a[L].ch[1] = R;
        a[L].fa = 0; a[0].ch[1] = L;
        update(L);
      } else if(uid) {
        splay(uid, t);
        int L = a[t].ch[0], R = a[t].ch[1];
        a[L].fa = R; a[R].ch[0] = L;
        a[R].fa = 0; a[0].ch[1] = R;
        update(R);
      } else ncnt = 0;
    }
  }
  
  inline int find_lower(int x) {
    int t = find(x);
    if(t) {
      int res = lower(x);
      return a[res].data;
    } else {
      Insert(x);
      int res = lower(x);
      if(res == 0) return -0x7fffffff; 
      Delete(x);
      splay(res, 0);
      return a[res].data;
    }
  }
  inline int find_upper(int x) {
    int t = find(x);
    if(t) {
      int res = upper(x);
      return a[res].data;
    } else {
      Insert(x);
      int res = upper(x);
      if(res == 0) return 0x7fffffff;
      Delete(x);
      splay(res, 0);
      return a[res].data;  
    }
  }
  
  inline int kth(int k) {
    int t = a[0].ch[1];
    while(t) {
      if(a[a[t].ch[0]].sum + 1 <= k && k <= a[a[t].ch[0]].sum + a[t].cnt) {
        splay(t, 0);
        return a[t].data;
      } else if(k <= a[a[t].ch[0]].sum) {
        t = a[t].ch[0];
      } else if(k > a[a[t].ch[0]].sum + a[t].cnt) {
        k -= a[a[t].ch[0]].sum + a[t].cnt;
        t = a[t].ch[1];
      }
    }
    return 0;
  }
  
  inline int Rank(int x) {
    int t = find(x);
    if(!t) return 0;
    splay(t, 0);
    return a[a[t].ch[0]].sum + 1;
  }
  #undef ls
  #undef rs
  #undef root
}

using namespace Splay;

int n, m, o, x, y;

int main() {
  read(m);
  while(m--) {
    read(o);
    if(o == 5) {
      read(x); Insert(x);
    } else if(o == 666) {
      read(x); Delete(x);
    } else if(o == 1) {
      read(x); printf("%d\n", Rank(x));
    } else if(o == 2) {
      read(x); printf("%d\n", kth(x));
    } else if(o == 4) {
      read(x); printf("%d\n", find_upper(x));
    } else if(o == 3){
      read(x); printf("%d\n", find_lower(x));
    }
  }
}
2021/2/17 16:32
加载中...