// ORDERSET - Order statistic set
// Author: skyvn97
// Language: C++
#include<cstdio>
#include<cassert>
// FOR(i,a,b): iterate int i over the closed range [a, b].
#define FOR(i,a,b) for (int i = (a); i <= (b); ++i)
// REP(i,n): iterate int i over the half-open range [0, n).
#define REP(i,n) for (int i = 0; i < (n); ++i)
// Message printed when a K query asks for a rank that does not exist.
const char inv[] = "invalid";
// A node of the splay tree backing the order-statistic set.
//   val               - the stored key
//   nnode             - number of nodes in the subtree rooted here (this node
//                       included); refreshed by calculate() after the children change
//   parent/left/right - tree links, NULL when absent
struct node {
    int val,nnode;
    node *parent,*left,*right;
    // Default-constructed nodes are fully initialized (a detached leaf holding 0)
    // instead of carrying indeterminate members.
    node() : val(0), nnode(1), parent(NULL), left(NULL), right(NULL) {}
    // A detached leaf holding x.
    node(const int &x) : val(x), nnode(1), parent(NULL), left(NULL), right(NULL) {}
    // Recomputes nnode from the children's sizes; the children's own nnode
    // values must already be correct.
    void calculate(void) {
        int l = (left == NULL) ? 0 : left->nnode;
        int r = (right == NULL) ? 0 : right->nnode;
        nnode = l + r + 1;
    }
    // True iff `a` is non-NULL and is this node's left child.
    bool isleft(const node *a) const {
        return a != NULL && left == a;
    }
    // True iff `a` is non-NULL and is this node's right child.
    bool isright(const node *a) const {
        return a != NULL && right == a;
    }
};
node *root = NULL; // root of the splay tree; NULL while the set is empty
// Allocates a new detached node holding x into `a`, but only when `a` is still
// NULL; an already-present node is left untouched.
void create(node *&a,const int &x) {
    if (a == NULL) {
        a = new node(x);
    }
}
// Attaches `b` under `a`: as the left child when dir == 1, otherwise as the
// right child, and points b's parent back at `a`.
// When `a` is NULL, `b` instead becomes the tree root with no parent (dir is
// ignored). `b` may be NULL, which just clears the chosen slot.
void link(node *a,node *b,int dir) {
    if (a != NULL) {
        if (dir == 1) {
            a->left = b;
        } else {
            a->right = b;
        }
        if (b != NULL) b->parent = a;
        return;
    }
    root = b;
    if (b != NULL) b->parent = NULL;
}
// Debug helper: prints the subtree rooted at `a` in pre-order. Each node is
// shown as "val|nnode" indented by `lev` spaces, followed by a "left" and a
// "right" label (same indent) introducing its children one level deeper.
void treeview(node *a,int lev) {
    if (a == NULL) return;
    for (int s = 0; s < lev; ++s) printf(" ");
    printf("%d|%d\n", a->val, a->nnode);
    for (int s = 0; s < lev; ++s) printf(" ");
    printf("left\n");
    treeview(a->left, lev + 1);
    for (int s = 0; s < lev; ++s) printf(" ");
    printf("right\n");
    treeview(a->right, lev + 1);
}
// Rotates `a` to the left. Its right child (the pivot) takes a's place under
// a's former parent (or becomes the root), `a` becomes the pivot's left child,
// and the pivot's old left subtree becomes a's right subtree.
// Subtree sizes (nnode) are NOT recomputed here; callers must do that.
// Precondition: a->right != NULL.
void left_rotation(node *a) {
    node *up = a->parent;
    int side = 0;
    if (up != NULL) side = up->isleft(a) ? 1 : 2;
    node *pivot = a->right;
    link(a, pivot->left, 2);
    link(pivot, a, 1);
    link(up, pivot, side);
}
// Rotates `a` to the right. Its left child (the pivot) takes a's place under
// a's former parent (or becomes the root), `a` becomes the pivot's right child,
// and the pivot's old right subtree becomes a's left subtree.
// Subtree sizes (nnode) are NOT recomputed here; callers must do that.
// Precondition: a->left != NULL.
void right_rotation(node *a) {
    node *up = a->parent;
    int side = 0;
    if (up != NULL) side = up->isleft(a) ? 1 : 2;
    node *pivot = a->left;
    link(a, pivot->right, 1);
    link(pivot, a, 2);
    link(up, pivot, side);
}
// Moves node `a` to the root of the tree with splay steps (zig, zig-zig,
// zig-zag) built from left_rotation/right_rotation.
// The rotations only relink pointers and never touch nnode, so after every
// step this function recomputes nnode bottom-up for the nodes that moved:
// first the node(s) now hanging below `a` (deepest first), then `a` itself.
// Ancestors above the rotated nodes keep the same subtree contents, so their
// sizes stay valid.
void splay(node *a) {
    while (a->parent!=NULL) {
        if (a->parent->parent==NULL) {
            // zig: the parent is the root, so one rotation finishes the splay
            if (a->parent->isleft(a)) {
                right_rotation(a->parent);
                a->right->calculate(); // old parent is now a's right child
            }
            else {
                left_rotation(a->parent);
                a->left->calculate(); // old parent is now a's left child
            }
        }
        else {
            if (a->parent->parent->isleft(a->parent)) {
                // unreachable placeholder comment removed; see branches below
            }
            if (a->parent->isleft(a)) {
                if (a->parent->parent->isleft(a->parent)) {
                    // zig-zig (left-left): rotate grandparent first, then parent.
                    // Afterwards a->right is the old parent and
                    // a->right->right is the old grandparent.
                    right_rotation(a->parent->parent);
                    right_rotation(a->parent);
                    a->right->right->calculate();
                    a->right->calculate();
                }
                else {
                    // zig-zag (a is a left child, parent is a right child):
                    // afterwards a->left is the old grandparent and
                    // a->right is the old parent.
                    right_rotation(a->parent);
                    left_rotation(a->parent);
                    a->left->calculate();
                    a->right->calculate();
                }
            }
            else {
                if (a->parent->parent->isright(a->parent)) {
                    // zig-zig (right-right): mirror of the left-left case.
                    // Afterwards a->left is the old parent and
                    // a->left->left is the old grandparent.
                    left_rotation(a->parent->parent);
                    left_rotation(a->parent);
                    a->left->left->calculate();
                    a->left->calculate();
                }
                else {
                    // zig-zag (a is a right child, parent is a left child):
                    // afterwards a->left is the old parent and
                    // a->right is the old grandparent.
                    left_rotation(a->parent);
                    right_rotation(a->parent);
                    a->right->calculate();
                    a->left->calculate();
                }
            }
        }
        // a's children are now correct, so a's own size can be refreshed
        a->calculate();
    }
}
void find(const int &x) {
if (root==NULL) return;
node *p=root;
while (true) {
if (p->val==x) break;
if (p->val>x) {
if (p->left!=NULL) p=p->left;
else break;
}
else {
if (p->right!=NULL) p=p->right;
else break;
}
}
splay(p);
}
// Splays the node nearest to x (see find) to the root, then cuts off the root's
// right subtree and returns it through `a` as a detached tree.
// Afterwards `root` holds the splayed node plus everything smaller, and `a`
// holds everything larger than root->val. When x is absent, root->val may be
// either smaller or greater than x. On an empty tree `a` becomes NULL.
void split(const int &x,node *&a) {
    if (root != NULL) {
        find(x);
        a = root->right;
        if (a != NULL) a->parent = NULL;
        root->right = NULL;
        root->calculate();
        return;
    }
    a = NULL;
}
// Joins the detached tree `a` onto the right side of the current tree.
// Intended for trees whose values all exceed those of the current tree (as
// produced by split). The current maximum is splayed to the root, so it has no
// right child, and `a` is hung there. An empty current tree simply becomes `a`.
void merge(node *a) {
    if (root == NULL) {
        link(NULL, a, 1);
        return;
    }
    node *mx;
    for (mx = root; mx->right != NULL; mx = mx->right) {
    }
    splay(mx);
    link(mx, a, 2);
    mx->calculate();
}
void insert(const int &x) {
//printf("Insert %d\n",x);
node *p,*q;
int dir=0;
q=NULL;
p=root;
while (p!=NULL) {
if (p->val==x) return;
q=p;
if (p->val>x) {
p=p->left;
dir=1;
}
else {
p=p->right;
dir=2;
}
}
create(p,x);
link(q,p,dir);
splay(p);
}
void erase(const int &x) {
//printf("Erase %d\n",x);
node *p,*l;
split(x,p);
if (root==NULL || root->val!=x) {
merge(p);
return;
}
l=root->left;
delete root;
link(NULL,l,1);
merge(p);
}
// Prints the k-th smallest element of the set (1-based rank), then splays that
// node to the root. Prints "invalid" when k is outside [1, size].
void value(const int &k) {
    // k < 1 must be rejected here: otherwise every comparison below says "go
    // left", the descent runs past a leaf, and p becomes NULL and is dereferenced.
    if (root == NULL || k < 1 || root->nnode < k) {
        printf("%s\n",inv);
        return;
    }
    int skipped = 0; // elements known to be smaller than everything in p's subtree
    node *p = root;
    while (true) {
        int nleft = (p->left == NULL) ? 0 : p->left->nnode;
        int rank = skipped + nleft + 1; // 1-based rank of p within the whole set
        if (rank == k) {
            printf("%d\n",p->val);
            splay(p);
            return;
        }
        if (rank > k) {
            p = p->left;
        } else {
            skipped = rank;
            p = p->right;
        }
    }
}
void count(const int &k) {
if (root==NULL) {
printf("0\n");
return;
}
int res=0;
int nleft;
node *p=root;
while (true) {
if (p->left==NULL) nleft=0; else nleft=p->left->nnode;
if (p->val<k) {
res+=nleft+1;
if (p->right!=NULL) p=p->right;
else break;
}
else {
if (p->left!=NULL) p=p->left;
else break;
}
}
printf("%d\n",res);
splay(p);
}
// Reads the number of queries q from stdin, then q queries "<type> <value>":
//   I v - insert v       D v - erase v
//   K v - print the v-th smallest element ("invalid" if none)
//   C v - print how many elements are less than v
// Unknown query types are read and ignored. Processing stops early if the
// input ends or a number cannot be parsed.
void answer(void) {
    root=NULL;
    int q;
    if (scanf("%d",&q)!=1) return;
    REP(i,q) {
        char type;
        int v;
        // Without these checks, truncated or malformed input would leave
        // type/v uninitialized and the query would act on garbage.
        if (scanf(" %c",&type)!=1) return;
        if (scanf("%d",&v)!=1) return;
        switch (type) {
        case 'I': insert(v); break;
        case 'D': erase(v); break;
        case 'K': value(v); break;
        case 'C': count(v); break;
        default: break;
        }
    }
}
// Program entry point: every query is read, processed and answered by answer().
int main() {
    answer();
    return 0;
}