// 代码如下（写得比较粗糙）：因为想练习一下 class，所以写成了这个样子。
#include<cstdio>
// Sentinel answer used by BSTree::pre (as -INF) and BSTree::next (as INF)
// when no neighbouring value exists; numerically equal to INT_MAX (2^31 - 1).
const int INF = 0x7fffffff;
// Number of operations main() reads from stdin.
int n;
// One node of BSTree, holding a distinct value plus bookkeeping counters:
//   val    - the stored key
//   sum    - how many times val has been inserted
//   size   - number of inserted values (duplicates included) in this subtree
//   lc/rc  - left/right child; parent - parent node (NULL for the root)
// All fields are private; only BSTree (declared friend) manipulates them.
class BSTnode {
private:
    int val,sum,size;
    BSTnode *lc,*rc,*parent;
    friend class BSTree;
public:
    // Argument order is (size, sum, value, left, right, parent). The
    // initializer list is written in member-declaration order, which is the
    // order C++ actually initializes members in (avoids -Wreorder confusion).
    BSTnode(int siz=0,int s=0,int k=0,BSTnode *ls=NULL,BSTnode *rs=NULL,BSTnode *par=NULL):
        val(k),sum(s),size(siz),lc(ls),rc(rs),parent(par) {}
};
class BSTree {
private://内部接口
BSTnode *rt;
BSTnode* Max(BSTnode *x) {
while(x->rc!=NULL)x=x->rc;
return x;
}
BSTnode* Min(BSTnode *x) {
while(x->lc!=NULL)x=x->lc;
return x;
}
BSTnode* Search(BSTnode *x,int key) {
while(x!=NULL&&x->val!=key) {
if(x->val>key)x=x->lc;
else x=x->rc;
}
return x;
}
void Insert(BSTnode *&root,BSTnode *k) {
BSTnode *x=root;
BSTnode *y=NULL;
while(x!=NULL) {
x->size++;
y=x;
if(x->val<k->val) {
x=x->rc;
}
else if(x->val==k->val) {
x->sum++;
return;
}
else {
x=x->lc;
}
}
k->parent=y;
if(y==NULL) {
root=k;
} else if(y->val<k->val) {
y->rc=k;
} else y->lc=k;
}
int Find1(BSTnode *x,int key) {
if(x==NULL)return 0;
if(x->val==key)return (x->lc->size)+1;
else if(x->val>key)return Find1(x->lc,key);
else return Find1(x->rc,key)+(x->lc->size)+(x->sum);
}
int Find2(BSTnode *x,int key) {
if(x->sum+x->lc->size>=key)return x->val;
else if(x->val>key)return Find2(x->lc,key);
else return Find2(x->rc,key-(x->sum)-(x->lc->size));
}
BSTnode* Pre(BSTnode *x) {
BSTnode *y=NULL;
if(x->lc!=NULL)return max(x->lc);
y=x->parent;
while(y!=NULL&&x==y->lc) {
x=y;
y=y->parent;
}
return y;
}
BSTnode* Next(BSTnode *x) {
BSTnode *y=NULL;
if(x->rc!=NULL)return min(x->rc);
y=x->parent;
while(y!=NULL&&x==y->rc) {
x=y;
y=y->parent;
}
return y;
}
/* void InOrder(BSTnode *x){
if(x==NULL)return;
InOrder(x->lc);
printf("%d ",x->val);
InOrder(x->rc);
}*/
public://外部接口
BSTree():rt(NULL) {};
int find1(int key) {
return Find1(rt,key);
}
int find2(int key) {
return Find2(rt,key);
}
int pre(BSTnode *key) {
BSTnode* tmp=Pre(key);
if(tmp!=NULL)return tmp->val;
return -INF;
}
int next(BSTnode *key) {
BSTnode* tmp=Next(key);
if(tmp!=NULL)return tmp->val;
return INF;
}
BSTnode* search(int key) {
return Search(rt,key);
}
BSTnode* max(BSTnode *x) {
return Max(x);
}
BSTnode* min(BSTnode *x) {
return Min(x);
}
void insert(int key) {
BSTnode* tmp=new BSTnode(0,0,key,NULL,NULL,NULL);
Insert(rt,tmp);
}
/* void inOrder(){
InOrder(rt);
printf("\n");
}*/
};
int main() {
scanf("%d",&n);
BSTree *tree=new BSTree();
int opt,k;
for(int i=0;i<n;i++) {
scanf("%d%d",&opt,&k);
if(opt==1) {
printf("%d\n",tree->find1(k));
}
else if(opt==2) {
printf("%d\n",tree->find2(k));
}
else if(opt==3) {
printf("%d\n",tree->pre(tree->search(k)));
}
else if(opt==4) {
printf("%d\n",tree->next(tree->search(k)));
}
else if(opt==5) {
tree->insert(k);
}
//tree->inOrder();
}
return 0;
}