ORDERSET - Order statistic set

Tác giả: skyvn97

Ngôn ngữ: C++

#include<cstdio>
#include<cassert>
// Inclusive counting loop: i takes every value from a to b. (Not used in the
// code visible in this file.)
#define FOR(i,a,b) for (int i=(a);i<=(b);i=i+1)
// Half-open counting loop: i takes every value from 0 to n-1. Note the loop
// body is only the single statement following the macro.
#define REP(i,n) for (int i=0;i<(n);i=i+1)
// Text printed by value() when the requested rank does not exist.
const char inv[]="invalid";
// Splay-tree node. nnode caches the size of the subtree rooted at this node
// (the node itself plus all descendants). It is not maintained automatically:
// calculate() must be called whenever left/right change.
struct node {
	int val,nnode;
	node *parent,*left,*right;
	// Initialize every member so a default-constructed node is never read in
	// an indeterminate state (reading uninitialized members is undefined).
	node() : val(0),nnode(1),parent(NULL),left(NULL),right(NULL) {}
	// A fresh, detached single-node subtree holding x.
	node(const int &x) : val(x),nnode(1),parent(NULL),left(NULL),right(NULL) {}
	// Recomputes nnode from the children's cached sizes; the children must
	// already be up to date.
	void calculate(void) {
		int l=(left==NULL)?0:left->nnode;
		int r=(right==NULL)?0:right->nnode;
		nnode=l+r+1;
	}
	// True iff a is non-NULL and is this node's left child.
	bool isleft(const node *a) const {
		return (a!=NULL && left==a);
	}
	// True iff a is non-NULL and is this node's right child.
	bool isright(const node *a) const {
		return (a!=NULL && right==a);
	}
};
// Root of the single global splay tree that stores the set; NULL means the
// set is empty.
node *root;
// Allocates a new node holding x into the slot a, but only when the slot is
// currently empty; an occupied slot is left untouched.
void create(node *&a,const int &x) {
	if (a==NULL)
		a=new node(x);
}
// Hangs b under a: dir==1 selects the left slot, any other value the right
// slot. When a is NULL there is no parent, so b becomes the tree root and its
// parent pointer is cleared. b may be NULL, which just empties the slot.
// Subtree sizes are NOT refreshed here; callers call calculate() themselves.
void link(node *a,node *b,int dir) {
	if (a!=NULL) {
		node *&slot=(dir==1)?a->left:a->right;
		slot=b;
		if (b!=NULL) b->parent=a;
		return;
	}
	root=b;
	if (b!=NULL) b->parent=NULL;
}
// Debug dump: prints the subtree rooted at a in pre-order. Each node appears
// as "val|size" indented two spaces per depth level, followed by "left" and
// "right" marker lines that introduce its child subtrees.
void treeview(node *a,int lev) {
	if (a==NULL) return;
	for (int d=0;d<lev;d++) printf("  ");
	printf("%d|%d\n",a->val,a->nnode);
	for (int d=0;d<lev;d++) printf("  ");
	printf("left\n");
	treeview(a->left,lev+1);
	for (int d=0;d<lev;d++) printf("  ");
	printf("right\n");
	treeview(a->right,lev+1);
}
// Rotates the edge between a and its right child to the left. The right child
// (pivot) takes a's position under a's former parent, or becomes the root when
// a had no parent. a becomes pivot's left child, and pivot's old left subtree
// becomes a's right subtree.
// Requires a->right != NULL. nnode values are left stale for the caller to fix.
void left_rotation(node *a) {
	node *up=a->parent;
	int side=(up==NULL)?0:(up->isleft(a)?1:2);
	node *pivot=a->right;
	link(a,pivot->left,2);
	link(pivot,a,1);
	link(up,pivot,side);
}
// Mirror image of left_rotation. a's left child (pivot) takes a's position
// under a's former parent, or becomes the root. a becomes pivot's right child,
// and pivot's old right subtree becomes a's left subtree.
// Requires a->left != NULL. nnode values are left stale for the caller to fix.
void right_rotation(node *a) {
	node *up=a->parent;
	int side=(up==NULL)?0:(up->isleft(a)?1:2);
	node *pivot=a->left;
	link(a,pivot->right,1);
	link(pivot,a,2);
	link(up,pivot,side);
}
// Moves node a to the root using zig, zig-zig and zig-zag rotation steps.
// The rotation helpers do not maintain nnode, so after each step this function
// recomputes the sizes of every node whose children changed. The calculate()
// calls run bottom-up (lower nodes first, a last).
void splay(node *a) {
	while (a->parent!=NULL) {
		if (a->parent->parent==NULL) {
			// Zig: the parent is the root, so one rotation lifts a to the top.
			// The old parent ends up as a's child on the rotated side.
			if (a->parent->isleft(a)) {
				right_rotation(a->parent);
				a->right->calculate();				
			}
			else {
				left_rotation(a->parent);
				a->left->calculate();
			}
		}
		else {
			if (a->parent->isleft(a)) {
				if (a->parent->parent->isleft(a->parent)) {
					// Zig-zig (left-left): rotate the grandparent first, then the
					// parent. Afterwards the old grandparent is a->right->right
					// and the old parent is a->right.
					right_rotation(a->parent->parent);
					right_rotation(a->parent);
					a->right->right->calculate();
					a->right->calculate();
				}
				else {
					// Zig-zag (a is a left child, parent is a right child).
					// Afterwards the old grandparent is a->left and the old
					// parent is a->right.
					right_rotation(a->parent);
					left_rotation(a->parent);
					a->left->calculate();
					a->right->calculate();
				}
			}
			else {
				if (a->parent->parent->isright(a->parent)) {
					// Zig-zig (right-right): mirror of the left-left case.
					left_rotation(a->parent->parent);
					left_rotation(a->parent);
					a->left->left->calculate();
					a->left->calculate();
				}
				else {
					// Zig-zag (a is a right child, parent is a left child).
					// Afterwards the old parent is a->left and the old
					// grandparent is a->right.
					left_rotation(a->parent);
					right_rotation(a->parent);
					a->right->calculate();
					a->left->calculate();
				}
			}
		}
		// Both children of a are now up to date, so a's own size can be fixed.
		a->calculate();
	}
}
void find(const int &x) {
	if (root==NULL) return;
	node *p=root;
	while (true) {
		if (p->val==x) break;
		if (p->val>x) {
			if (p->left!=NULL) p=p->left;
			else break;
		}
		else {
			if (p->right!=NULL) p=p->right;
			else break;
		}
	}
	splay(p);
}
void split(const int &x,node *&a) {
	if (root==NULL) {
		a=NULL;
		return;
	}
	find(x);
	a=root->right;
	if (a!=NULL) a->parent=NULL;
	root->right=NULL;
	root->calculate();
}
// Attaches tree a to the right of the current tree. The maximum node is splayed
// to the root and a becomes its right child. a then simply becomes the tree if
// the current tree is empty.
// Callers must pass a tree whose keys all exceed the current maximum,
// otherwise BST order is violated.
void merge(node *a) {
	if (root!=NULL) {
		node *mx=root;
		while (mx->right!=NULL) mx=mx->right;
		splay(mx);
		link(mx,a,2);
		mx->calculate();
	}
	else link(NULL,a,1);
}
// Adds x to the set; duplicates are ignored. In both cases the node that holds
// x ends up splayed to the root.
void insert(const int &x) {
	node *p=root;
	node *q=NULL;   // parent of the slot where x would be attached
	int dir=0;      // 1 = left slot of q, 2 = right slot of q
	while (p!=NULL) {
		if (p->val==x) {
			// Already present. Still splay: returning without restructuring
			// after a deep walk lets repeated duplicate inserts cost O(n) each.
			splay(p);
			return;
		}
		q=p;
		if (p->val>x) {
			p=p->left;
			dir=1;
		}
		else {
			p=p->right;
			dir=2;
		}
	}
	create(p,x);
	link(q,p,dir);
	splay(p);
}
void erase(const int &x) {
	//printf("Erase %d\n",x);
	node *p,*l;
	split(x,p);
	if (root==NULL || root->val!=x) {
		merge(p);
		return;
	}
	l=root->left;
	delete root;
	link(NULL,l,1);
	merge(p);
}
// Prints the k-th smallest stored value (k is 1-based) and splays that node to
// the root. Prints "invalid" when k is outside [1, size of the set].
void value(const int &k) {
	// k<1 must be rejected too. Otherwise the descent below keeps moving left
	// until it dereferences a NULL child.
	if (root==NULL || k<1 || root->nnode<k) {
		printf("%s\n",inv);
		return;
	}
	int rem=0;     // number of values known to be smaller than p's subtree
	node *p=root;
	while (true) {
		int nleft=(p->left==NULL)?0:p->left->nnode;
		int rank=rem+nleft+1;   // 1-based rank of p within the whole set
		if (rank==k) {
			printf("%d\n",p->val);
			splay(p);
			return;
		}
		if (rank>k) p=p->left;
		else {
			p=p->right;
			rem=rank;
		}
	}
}
void count(const int &k) {
	if (root==NULL) {
		printf("0\n");
		return;
	}
	int res=0;
	int nleft;
	node *p=root;
	while (true) {
		if (p->left==NULL) nleft=0; else nleft=p->left->nnode;
		if (p->val<k) {
			res+=nleft+1;
			if (p->right!=NULL) p=p->right;
			else break;
		}
		else {
			if (p->left!=NULL) p=p->left;
			else break;
		}
	}
	printf("%d\n",res);
	splay(p);
}
// Reads the query count q followed by q queries of the form "<type> <value>"
// from stdin, and runs each against the global set:
//   I v : insert v      D v : erase v
//   K v : print the v-th smallest value
//   C v : print how many values are < v
// Unknown query letters are ignored. Input stops at the first malformed or
// missing query instead of acting on uninitialized variables.
void answer(void) {
	root=NULL;
	int q;
	if (scanf("%d",&q)!=1) return;
	REP(i,q) {
		char type;
		int v;
		if (scanf(" %c %d",&type,&v)!=2) break;
		switch (type) {
			case 'I': insert(v); break;
			case 'D': erase(v); break;
			case 'K': value(v); break;
			case 'C': count(v); break;
			default: break;
		}
	}
}
int main()
{
	// All query handling lives in answer(); the exit status is always 0.
	answer();
	return 0;
}

Download