ORDERSET - Order statistic set

Tác giả: ladpro98

Ngôn ngữ: C++

#include <bits/stdc++.h>

using namespace std;

/// Ordered set of distinct ints implemented as a treap: a binary search tree on
/// `key` that is simultaneously a min-heap on a random `prior`. Every node caches
/// the size of its subtree, which lets rank queries (countSmaller) and
/// order-statistic queries (getKeyByOrder) walk a single root-to-leaf path.
///
/// The treap owns all of its nodes, so it is non-copyable (copying the raw
/// pointers would lead to a double free).
class Treap {
    struct Node {
        int key;
        unsigned prior;  // heap priority: join() keeps the smaller prior on top
        int size;        // number of nodes in the subtree rooted here
        Node *l, *r;

        Node(int key, unsigned prior)
            : key(key), prior(prior), size(1), l(nullptr), r(nullptr) {}
        // A node owns its subtree: deleting it frees every descendant.
        ~Node() { delete l; delete r; }
        Node(const Node &) = delete;
        Node &operator=(const Node &) = delete;
    };

    /// Size of the subtree rooted at x (0 for an empty subtree).
    static int sizeOf(const Node *x) { return x ? x->size : 0; }

    /// Recomputes x->size from its children; no-op on nullptr.
    static void update(Node *x) {
        if (!x) return;
        x->size = sizeOf(x->l) + sizeOf(x->r) + 1;
    }

    /// Merges two treaps, assuming every key in l is smaller than every key in r.
    static Node *join(Node *l, Node *r) {
        if (!l || !r) return l ? l : r;
        if (l->prior < r->prior) {
            l->r = join(l->r, r);
            update(l);
            return l;
        }
        r->l = join(l, r->l);
        update(r);
        return r;
    }

    /// Splits v into l (keys < x) and r (keys >= x).
    static void split(Node *v, int x, Node *&l, Node *&r) {
        if (!v) {
            l = r = nullptr;
            return;
        }
        if (v->key < x) {
            split(v->r, x, v->r, r);
            l = v;
        } else {
            split(v->l, x, l, v->l);
            r = v;
        }
        update(v);
    }

    /// Removes the node holding x from the subtree rooted at v, if present.
    /// Works by descending to the node, so no `x + 1` bound is needed (that
    /// would overflow for x == INT_MAX).
    static bool eraseFrom(Node *&v, int x) {
        if (!v) return false;
        if (v->key == x) {
            Node *victim = v;
            v = join(victim->l, victim->r);
            victim->l = victim->r = nullptr;  // detach so ~Node does not free the merged subtree
            delete victim;
            return true;
        }
        const bool removed = eraseFrom(x < v->key ? v->l : v->r, x);
        if (removed) update(v);
        return removed;
    }

    /// In-order print of the subtree rooted at x, each key followed by a space.
    static void showFrom(const Node *x, std::ostream &os) {
        if (!x) return;
        showFrom(x->l, os);
        os << x->key << ' ';
        showFrom(x->r, os);
    }

    /// Full 32-bit random priority (rand() may only span 0..32767, causing ties).
    unsigned nextPriority() { return static_cast<unsigned>(rng()); }

    Node *root;
    std::mt19937 rng;  // default-seeded, so runs are reproducible

public:
    Treap() : root(nullptr), rng() {}
    ~Treap() { delete root; }
    Treap(const Treap &) = delete;
    Treap &operator=(const Treap &) = delete;

    /// Returns true if x is in the set.
    bool contains(int x) const {
        const Node *v = root;
        while (v && v->key != x) v = x < v->key ? v->l : v->r;
        return v != nullptr;
    }

    /// Inserts x. Returns false (set unchanged) if x was already present.
    bool insert(int x) {
        if (contains(x)) return false;
        // Allocate before restructuring so a bad_alloc leaves the tree intact.
        Node *fresh = new Node(x, nextPriority());
        Node *l, *r;
        split(root, x, l, r);
        root = join(join(l, fresh), r);
        return true;
    }

    /// Removes x. Returns true if it was present.
    bool erase(int x) { return eraseFrom(root, x); }

    /// Returns the k-th smallest key (1-based).
    /// @throws std::out_of_range if k < 1 or k > size().
    int getKeyByOrder(int k) const {
        if (k < 1 || k > size())
            throw std::out_of_range("Treap::getKeyByOrder: order out of range");
        const Node *v = root;
        for (;;) {
            const int leftSize = sizeOf(v->l);
            if (k <= leftSize) {
                v = v->l;
            } else if (k == leftSize + 1) {
                return v->key;
            } else {
                k -= leftSize + 1;
                v = v->r;
            }
        }
    }

    /// Number of keys strictly smaller than x. Does not modify the tree.
    int countSmaller(int x) const {
        int count = 0;
        const Node *v = root;
        while (v) {
            if (v->key < x) {
                count += sizeOf(v->l) + 1;
                v = v->r;
            } else {
                v = v->l;
            }
        }
        return count;
    }

    /// Number of keys in the set.
    int size() const { return sizeOf(root); }

    /// Prints the keys in ascending order as "{k1 k2 ... }" plus a newline.
    void show(std::ostream &os = std::cout) const {
        os << "{";
        showFrom(root, os);
        os << "}\n";
    }
};

// Reads nQuery commands of the form "<cmd> <x>" and answers them against an
// ordered set of distinct ints:
//   I x  insert x
//   D x  delete x
//   C x  print how many elements are strictly smaller than x
//   any other command (K x): print the x-th smallest element (1-based), or
//        "invalid" if no such element exists
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    Treap S;
    int nQuery;
    if (!(std::cin >> nQuery)) return 0;
    // `> 0` so a negative count does not loop (and overflow) forever.
    while (nQuery-- > 0) {
        char cmd;
        int x;
        // Stop on truncated/malformed input instead of acting on uninitialized values.
        if (!(std::cin >> cmd >> x)) break;
        if (cmd == 'I')
            S.insert(x);
        else if (cmd == 'D')
            S.erase(x);
        else if (cmd == 'C')
            std::cout << S.countSmaller(x) << '\n';
        else if (x < 1 || S.size() < x)  // orders are 1-based; x <= 0 has no answer
            std::cout << "invalid\n";
        else
            std::cout << S.getKeyByOrder(x) << '\n';
    }
    return 0;
}

Download