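# My code is written in Python, but it gets TLE even under PyPy. I'm not sure whether the data is just too heavy for Python or the problem is in my own code...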
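# Persistent non-rotating (split/merge) treap. Node priorities form a min-heap: the node with the smaller `heap` value is kept on top.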
from random import random
import sys
# The treap's split/merge helpers are recursive, so lift the interpreter's
# recursion cap (10 ** 9 - 1 == 999999999) well above any depth they reach.
sys.setrecursionlimit(10 ** 9 - 1)
class NonRotatedNode:
    """A node of the persistent treap (NonRotatedTreap).

    Attributes:
        bst: the stored key; an in-order walk yields keys in sorted order.
        heap: the node's random priority; a parent's value never exceeds
            its children's (min-heap).
        size: number of nodes in the subtree rooted here.
        left, right: child nodes, or None.

    Nodes are shared between versions, so NonRotatedTreap copies a node
    instead of changing it once it is reachable from a stored root.
    """

    # Path copying allocates many nodes; slots drop the per-instance __dict__,
    # which reduces memory and speeds attribute access.
    __slots__ = ("bst", "heap", "size", "left", "right")

    def __init__(self, bst, heap=None):
        self.bst = bst
        # Test against None, not truthiness: random() can return 0.0, and a
        # copied 0.0 priority must not be silently replaced by a new one.
        self.heap = random() if heap is None else heap
        self.size = 1
        self.left = None
        self.right = None


class NonRotatedTreap:
    """Persistent sorted multiset backed by a non-rotating (split/merge) treap.

    ``roots`` holds one root per operation performed so far (None = empty
    tree). Every public method appends exactly one root, so the i-th call
    (1-based) produces ``roots[i - 1]``. A ``version`` argument is an index
    into ``roots``; any index outside ``range(len(roots))`` (e.g. -1) denotes
    the empty tree.

    Updates copy only the nodes on the affected root-to-leaf path(s), so older
    versions are never modified. Queries walk the tree without allocating.
    """

    def __init__(self):
        self.roots = []

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _size(x):
        """Return the subtree size of ``x``; None counts as 0."""
        return x.size if x is not None else 0

    @staticmethod
    def _copy(x):
        """Return a shallow copy of ``x`` (same key, priority, children, size)."""
        node = NonRotatedNode(x.bst, x.heap)
        node.left = x.left
        node.right = x.right
        node.size = x.size
        return node

    def _root(self, version):
        """Return the root stored for ``version``, or None if it is out of range."""
        if 0 <= version < len(self.roots):
            return self.roots[version]
        return None

    @classmethod
    def _merge(cls, l, r):
        """Join two treaps where every key of ``l`` is <= every key of ``r``.

        Neither input is modified; the nodes along the merge spine are copied.
        """
        if l is None:
            return r
        if r is None:
            return l
        if l.heap < r.heap:
            # l has the smaller priority, so it stays on top (min-heap).
            node = cls._copy(l)
            node.right = cls._merge(l.right, r)
        else:
            node = cls._copy(r)
            node.left = cls._merge(l, r.left)
        node.size = l.size + r.size
        return node

    @classmethod
    def _split(cls, x, index):
        """Split ``x`` into ``(keys <= index, keys > index)`` without modifying it."""
        if x is None:
            return None, None
        if index < x.bst:
            if x.left is None:
                return None, x  # every key here is > index: reuse x unchanged
            l, r = cls._split(x.left, index)
            node = cls._copy(x)
            node.left = r
            node.size -= cls._size(l)
            return l, node
        if x.right is None:
            return x, None  # every key here is <= index: reuse x unchanged
        l, r = cls._split(x.right, index)
        node = cls._copy(x)
        node.right = l
        node.size -= cls._size(r)
        return node, r

    @classmethod
    def _insert(cls, x, node):
        """Return a new root for ``x`` with the fresh, unlinked ``node`` added.

        Descends while ``x`` keeps a smaller-or-equal priority than ``node``,
        copying that path; at the first subtree ``node`` outranks, the subtree
        is split around ``node.bst`` and its halves hung under ``node``. A key
        equal to existing ones is placed after them in sorted order.
        """
        if x is None:
            return node
        if node.heap < x.heap:
            node.left, node.right = cls._split(x, node.bst)
            node.size = x.size + 1
            return node
        copy = cls._copy(x)
        if node.bst < x.bst:
            copy.left = cls._insert(x.left, node)
        else:
            copy.right = cls._insert(x.right, node)
        copy.size += 1
        return copy

    @classmethod
    def _remove(cls, x, key):
        """Remove one occurrence of ``key`` from ``x``.

        Returns ``(new_root, removed)``. When ``key`` is absent the original
        ``x`` is returned unchanged (no nodes are copied).
        """
        if x is None:
            return None, False
        if key == x.bst:
            return cls._merge(x.left, x.right), True
        if key < x.bst:
            child, removed = cls._remove(x.left, key)
            if not removed:
                return x, False
            node = cls._copy(x)
            node.left = child
        else:
            child, removed = cls._remove(x.right, key)
            if not removed:
                return x, False
            node = cls._copy(x)
            node.right = child
        node.size -= 1
        return node, True

    # ------------------------------------------------------------ public API

    def find(self, version, index):
        """Look up ``index`` in ``version``; record an unchanged copy of it.

        Returns the stored key equal to ``index`` or False when absent. A
        stored 0 is falsy, so callers should test the result with ``is False``.
        """
        root = self._root(version)
        self.roots.append(root)
        x = root
        while x is not None:
            if index == x.bst:
                return x.bst
            x = x.left if index < x.bst else x.right
        return False

    def insert(self, version, index):
        """Record ``version`` plus one more occurrence of ``index``."""
        root = self._root(version)
        self.roots.append(self._insert(root, NonRotatedNode(index)))

    def remove(self, version, index):
        """Record ``version`` minus one occurrence of ``index``.

        Returns True if an occurrence was removed, False if ``index`` was
        absent (the new version then equals ``version``).
        """
        new_root, removed = self._remove(self._root(version), index)
        self.roots.append(new_root)
        return removed

    def kth(self, version, k):
        """Return the ``k``-th smallest key (1-based) of ``version``, or False
        when ``k`` is outside ``1..size``. Records an unchanged copy."""
        x = self._root(version)
        self.roots.append(x)
        if x is None or not 1 <= k <= x.size:
            return False
        size = self._size
        while True:
            left_size = size(x.left)
            if k <= left_size:
                x = x.left
            elif k == left_size + 1:
                return x.bst
            else:
                k -= left_size + 1
                x = x.right

    def rank(self, version, index):
        """Return 1 + the number of keys in ``version`` strictly less than
        ``index``. Records an unchanged copy."""
        x = self._root(version)
        self.roots.append(x)
        size = self._size
        smaller = 0
        while x is not None:
            if x.bst < index:
                smaller += size(x.left) + 1
                x = x.right
            else:
                x = x.left
        return smaller + 1

    def prev_index(self, version, index, default=False):
        """Return the largest key in ``version`` strictly less than ``index``,
        or ``default`` (False unless given) if none. Records an unchanged copy."""
        x = self._root(version)
        self.roots.append(x)
        best = default
        while x is not None:
            if x.bst < index:
                best = x.bst
                x = x.right
            else:
                x = x.left
        return best

    def next_index(self, version, index, default=False):
        """Return the smallest key in ``version`` strictly greater than
        ``index``, or ``default`` (False unless given) if none. Records an
        unchanged copy."""
        x = self._root(version)
        self.roots.append(x)
        best = default
        while x is not None:
            if x.bst > index:
                best = x.bst
                x = x.left
            else:
                x = x.right
        return best
def oj():
    """Judge driver: read every operation from stdin and print query answers.

    Input (whitespace-separated): a count ``n`` followed by ``n`` triples
    ``v opt x``. Each operation reads version ``v`` (0 = the empty tree) and
    creates the next version. ``opt`` selects: 1 insert, 2 remove, 3 rank,
    4 k-th smallest, 5 predecessor, 6 successor. Answers to 3-6 are printed
    one per line.
    """
    # Assumed judge convention for "no predecessor / no successor"
    # (e.g. Luogu P3835 uses -(2**31 - 1) and 2**31 - 1) -- TODO confirm
    # against the problem statement; printing "False" is never a valid answer.
    no_prev = -2147483647
    no_next = 2147483647

    # One bulk read instead of input() per line: per-line reads and per-answer
    # print() calls dominate the runtime for large inputs, and split() with no
    # argument tolerates any whitespace (split(' ') stopped on tabs/double spaces).
    tokens = sys.stdin.buffer.read().split()
    if not tokens:
        return
    # Never read past the data actually supplied, even if n overstates it.
    count = min(int(tokens[0]), (len(tokens) - 1) // 3)

    nrt = NonRotatedTreap()
    answers = []
    emit = answers.append
    for i in range(1, 3 * count + 1, 3):
        # roots[k] holds version k + 1, so version v maps to index v - 1;
        # v == 0 maps to -1, which the treap treats as the empty tree.
        ver = int(tokens[i]) - 1
        order = int(tokens[i + 1])
        num = int(tokens[i + 2])
        if order == 1:
            nrt.insert(ver, num)
        elif order == 2:
            nrt.remove(ver, num)
        elif order == 3:
            emit(nrt.rank(ver, num))
        elif order == 4:
            emit(nrt.kth(ver, num))
        elif order == 5:
            res = nrt.prev_index(ver, num)
            # `is False`, not `not res`: a predecessor of 0 is a real answer.
            emit(no_prev if res is False else res)
        elif order == 6:
            res = nrt.next_index(ver, num)
            emit(no_next if res is False else res)
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")
# Run the judge loop only when executed as a script, not when imported.
if __name__ == "__main__":
    oj()