用可持久化的无旋Treap不知道为啥卡了4个TLE
查看原帖
用可持久化的无旋Treap不知道为啥卡了4个TLE
281217
jon666楼主2020/11/10 16:58

我的代码是用 Python 写的，但用 PyPy 提交也会 TLE。不知道是 Python 本身（对这种数据规模）的问题，还是我代码的问题……

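# Persistent non-rotating (split/merge) Treap. Note: _merge makes the node with
# the smaller heap priority the parent, so the priorities form a min-heap.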

from random import random
import sys
# Treap split/merge are recursive and their recursion depth grows with the tree
# depth; with random priorities that depth is expected to stay logarithmic, so
# this limit is only a safeguard.
# NOTE(review): a limit this large lets a runaway recursion overflow the C stack
# and crash the interpreter instead of raising RecursionError — consider a
# smaller bound such as 10**5.
sys.setrecursionlimit(999999999)

class NonRotatedNode:
    """A single treap node.

    ``bst`` is the key that orders nodes in-order, ``heap`` is the random
    priority used to keep the tree balanced, and ``size`` is the number of
    nodes in the subtree rooted here (maintained by the treap operations).
    """

    # No per-instance __dict__: the persistent treap allocates a fresh node for
    # every copy it makes on a split/merge path, so this saves memory and makes
    # attribute access cheaper.
    __slots__ = ("bst", "heap", "size", "left", "right")

    def __init__(self, bst, heap=None):
        self.bst = bst
        # Compare against None explicitly: a copied priority of 0.0 is valid
        # and must not be replaced by a new random value.
        self.heap = heap if heap is not None else random()
        self.size = 1
        self.left = None
        self.right = None


class NonRotatedTreap:
    """Persistent non-rotating (split/merge) treap.

    ``self.roots[i]`` is the root of version ``i``. Every public operation
    reads the tree of ``version`` and appends exactly one root to
    ``self.roots``: the modified tree for ``insert``/``remove``, the unchanged
    tree for queries. A ``version`` outside ``0 <= version < len(self.roots)``
    is treated as the empty tree. Updates copy the nodes they touch instead of
    mutating them, so earlier versions stay intact.
    """

    @classmethod
    def _merge(cls, l, r):
        """Join two treaps where every key of ``l`` precedes every key of ``r``.

        Nodes on the merge path are copied; neither input is modified. The
        node with the smaller ``heap`` priority becomes the parent.
        """
        if l is None:
            return r
        if r is None:
            return l
        size = l.size + r.size
        if l.heap < r.heap:
            new_l = NonRotatedNode(l.bst, l.heap)
            new_l.left = l.left
            new_l.right = cls._merge(l.right, r)
            new_l.size = size
            return new_l
        new_r = NonRotatedNode(r.bst, r.heap)
        new_r.right = r.right
        new_r.left = cls._merge(l, r.left)
        new_r.size = size
        return new_r

    @classmethod
    def _split(cls, x, index):
        """Split treap ``x`` into ``(keys <= index, keys > index)``.

        ``x`` may be ``None`` (empty tree), giving ``(None, None)``. Nodes on
        the split path are copied; ``x`` itself is not modified.
        """
        if x is None:
            return None, None
        if index < x.bst:
            if x.left is None:
                # Nothing can go to the "<= index" side; reuse x unchanged.
                return None, x
            new_x = NonRotatedNode(x.bst, x.heap)
            new_x.right = x.right
            l, r = cls._split(x.left, index)
            new_x.left = r
            new_x.size = x.size - (l.size if l else 0)
            return l, new_x
        if x.right is None:
            # Nothing can go to the "> index" side; reuse x unchanged.
            return x, None
        new_x = NonRotatedNode(x.bst, x.heap)
        new_x.left = x.left
        l, r = cls._split(x.right, index)
        new_x.right = l
        new_x.size = x.size - (r.size if r else 0)
        return new_x, r

    def __init__(self):
        self.roots = []

    def _version_root(self, version):
        """Return the root of ``version``, or ``None`` if that version does not exist."""
        if 0 <= version < len(self.roots):
            return self.roots[version]
        return None

    def find(self, version, index):
        """Return ``index``'s stored key if present in ``version``, else ``False``.

        Always records the unchanged tree as the new version.
        """
        root = self._version_root(version)
        self.roots.append(root)
        x = root
        while x is not None:
            if index == x.bst:
                return x.bst
            x = x.left if index < x.bst else x.right
        return False

    def insert(self, version, index):
        """Create a new version equal to ``version`` plus one copy of ``index``."""
        n = NonRotatedNode(index)
        root = self._version_root(version)
        if root is None:
            self.roots.append(n)
            return
        l, r = self._split(root, index)
        self.roots.append(self._merge(self._merge(l, n), r))

    def remove(self, version, index):
        """Create a new version equal to ``version`` minus one copy of ``index``.

        Returns ``True`` if a key was removed. Otherwise the unchanged tree is
        recorded and ``False`` is returned.
        """
        root = self._version_root(version)
        le, r = self._split(root, index)  # le: keys <= index
        if le is None:
            self.roots.append(root)
            return False
        # m: keys in (index - 1, index] -- for integer keys, exactly == index.
        l, m = self._split(le, index - 1)
        if m is None:
            self.roots.append(root)
            return False
        m = self._merge(m.left, m.right)  # drop a single copy: m's root
        self.roots.append(self._merge(self._merge(l, m), r))
        return True

    def kth(self, version, k):
        """Return the ``k``-th smallest key (1-based), or ``False`` if out of range."""
        x = self._version_root(version)
        self.roots.append(x)
        if x is None or not 1 <= k <= x.size:
            return False
        while True:
            left_size = x.left.size if x.left else 0
            if k == left_size + 1:
                return x.bst
            if k <= left_size:
                x = x.left
            else:
                k -= left_size + 1
                x = x.right

    def rank(self, version, index):
        """Return 1 + the number of keys ``<= index - 1`` (keys ``< index`` for ints).

        Walks down the tree without allocating nodes.
        """
        x = self._version_root(version)
        self.roots.append(x)
        bound = index - 1
        below = 0
        while x is not None:
            if bound < x.bst:
                x = x.left
            else:
                below += (x.left.size if x.left else 0) + 1
                x = x.right
        return below + 1

    def prev_index(self, version, index):
        """Return the largest key ``<= index - 1``, or ``False`` if there is none."""
        x = self._version_root(version)
        self.roots.append(x)
        bound = index - 1
        best = False
        while x is not None:
            if bound < x.bst:
                x = x.left
            else:
                # x qualifies; any larger candidate lies in its right subtree.
                best = x.bst
                x = x.right
        return best

    def next_index(self, version, index):
        """Return the smallest key ``> index``, or ``False`` if there is none."""
        x = self._version_root(version)
        self.roots.append(x)
        best = False
        while x is not None:
            if index < x.bst:
                # x qualifies; any smaller candidate lies in its left subtree.
                best = x.bst
                x = x.left
            else:
                x = x.right
        return best


def oj():
    """Read operations from stdin, apply them to a persistent treap, print answers.

    Input: a first token (the operation count; not used, everything after it
    is read), then triples ``v opt x``. ``v`` is shifted down by one, so the
    input's version ``0`` becomes index ``-1`` -- out of range in
    ``NonRotatedTreap.roots`` and therefore the empty tree -- and version
    ``i`` becomes ``roots[i - 1]``, the tree recorded by the ``i``-th
    operation. Answers to ``opt`` 3..6 are printed one per line.

    NOTE(review): ``kth``/``prev_index``/``next_index`` answer ``False`` when
    no result exists, and this prints the literal text ``False``; confirm that
    is what the judge expects.
    """
    # Read everything at once: per-line input() and per-answer print() are a
    # large constant cost for big inputs. Splitting on any whitespace also
    # tolerates repeated spaces and blank lines, which the old ' '-split plus
    # bare except silently treated as end of input.
    tokens = sys.stdin.buffer.read().split()
    nrt = NonRotatedTreap()
    out = []
    # Skip tokens[0]; stop before any trailing incomplete triple.
    for pos in range(1, len(tokens) - 2, 3):
        ver = int(tokens[pos]) - 1
        order = int(tokens[pos + 1])
        num = int(tokens[pos + 2])
        if order == 1:
            nrt.insert(ver, num)
        elif order == 2:
            nrt.remove(ver, num)
        elif order == 3:
            out.append(nrt.rank(ver, num))
        elif order == 4:
            out.append(nrt.kth(ver, num))
        elif order == 5:
            out.append(nrt.prev_index(ver, num))
        elif order == 6:
            out.append(nrt.next_index(ver, num))
    if out:
        sys.stdout.write("\n".join(map(str, out)) + "\n")


# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    oj()

2020/11/10 16:58
加载中...