Skip to content

Weight Balanced Leafy Tree

简介

重量平衡(重量即子树中叶子的个数),并且数值只存储在叶子上;内部节点只记录子树大小和子树中的最大值(即右儿子的值)

与替罪羊树类似,有一个平衡常数 \(\alpha\)(下面的代码取 \(\alpha = 0.25\))

定义一个节点的平衡指数为左子树重量除以整棵子树的重量

如果平衡指数小于 \(\alpha\) 或大于 \(1-\alpha\)(即较轻一侧子树的重量不足整棵子树的 \(\alpha\)),说明不平衡,要调整

这里的调整指的是“旋转”,而不是像替罪羊树那样重构整棵子树

说是旋转,其实是切割并拼接

以左旋为例(用于右子树过重的情况):将左儿子和右儿子的左儿子合并成新的左儿子,右儿子的右儿子作为新的右儿子,原本的右儿子节点直接丢弃

被丢弃的节点要回收(指针版用 delete 释放,数组版应放回空闲节点池再复用),否则这些节点会一直占用空间,节点总数会超出预先开好的数组大小

但是,一次旋转后可能还是不平衡,这是由于左转时新的左儿子可能过大导致的

这进一步又是由于右儿子的左儿子过大导致的(旋转会将右儿子的左儿子与左儿子合并)

所以要特判。如果右儿子平衡指数小于 \(\frac{1-2\alpha}{1-\alpha}\) 就单旋,否则双旋(右儿子右旋,自己再左旋)

右儿子右旋可以减小右儿子的左儿子大小

右旋是对称的,同理

数组版本:

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10, INF=1e9+10;
const double alpha = 0.25;
int T,n,i;
// Weight-balanced leafy tree (WBLT) stored in a fixed node array.
//
// Every key lives in a leaf. An internal node always has exactly two
// children and caches `size` (number of leaves below it) and `value` (the
// value of its right child, i.e. the largest key in its subtree). Index 0 is
// the null node: its fields stay zero and leaf(0) is true.
//
// Nodes that become unreachable (in rotate, Delete, merges and split) are
// pushed onto `freed` and handed out again by newnode. Without this reuse,
// cnt only ever grows: each insertion takes two slots and each rotation one
// more, even after deletions. With the two sentinels solve() inserts,
// 1e5 insertions alone already need 2e5+3 slots, so a handful of rotations
// would overrun a[N].
struct WBLT{
    struct node{
        int left;
        int right;
        int size;
        int value;
        node(int size, int value) : size(size), value(value) {left = 0, right = 0;}
        node(int left, int right, int size, int value) :
            left(left), right(right), size(size), value(value) {}
        node(){}
    }a[N];
    int cnt;            // largest array index handed out so far
    vector<int> freed;  // recycled indices, reused before cnt grows
    // Allocate a node, preferring a recycled slot.
    int newnode(int x, int y, int size, int value){
        int id;
        if(!freed.empty()){
            id = freed.back();
            freed.pop_back();
        }else{
            assert(cnt + 1 < N);  // node pool exhausted
            id = ++cnt;
        }
        a[id] = node(x, y, size, value);
        return id;
    }
    // Return a node that is no longer referenced to the pool (0 is ignored).
    void recycle(int x){
        if(x) freed.push_back(x);
    }
    int root;
    WBLT(){root = 0, cnt = 0;}
    // Reset to an empty tree. The free list must be cleared too; resetting
    // only root and cnt would leave stale indices in `freed`.
    void clear(){
        root = 0, cnt = 0;
        freed.clear();
    }
    // A node is a leaf iff it has no children (the null node counts as one).
    bool leaf(int x){
        return (!a[x].left && !a[x].right);
    }
    // Recompute size and cached maximum of internal node x from its children.
    void pushup(int x){
        if(leaf(x)) return;
        a[x].value = a[a[x].right].value;
        a[x].size = a[a[x].left].size + a[a[x].right].size;
    }
    // Overwrite node x with a copy of node y (x takes over y's children).
    void copynode(int x, int y){
        a[x] = a[y];
    }
    // New internal node with children x (left) and y (right). No
    // rebalancing is done; callers ensure the result is balanced.
    int merge(int x, int y){
        int z = newnode(x, y, 0, a[x].value);
        pushup(z);
        return z;
    }
    // Single rotation of internal node x. t == 0 rotates left: the left child
    // and right->left are joined into a new left child, right->right becomes
    // the right child and the old right child is recycled. t == 1 mirrors it.
    // x keeps the same leaves, so its size and maximum do not change.
    void rotate(int x, int t){
        if(!t){
            int r = a[x].right;
            a[x].left = merge(a[x].left, a[r].left);
            a[x].right = a[r].right;
            recycle(r);
        }else{
            int l = a[x].left;
            a[x].right = merge(a[l].right, a[x].right);
            a[x].left = a[l].left;
            recycle(l);
        }
    }
    // Restore balance at x, assuming its children are balanced: the lighter
    // side must weigh at least alpha * size(x). If the heavy child's inner
    // grandchild holds >= (1-2a)/(1-a) of the heavy child, a single rotation
    // would leave x unbalanced, so the heavy child is first rotated the other
    // way (double rotation).
    void maintain(int x){
        if(leaf(x)) return;
        if(a[a[x].left].size <= a[a[x].right].size){
            if(a[a[x].left].size >= a[x].size * alpha) return;
            if(a[a[a[x].right].left].size >= a[a[x].right].size * (1 - 2 * alpha) / (1 - alpha)){
                rotate(a[x].right, 1);
            }
            rotate(x, 0);
        }else{
            if(a[a[x].right].size >= a[x].size * alpha) return;
            if(a[a[a[x].left].right].size >= a[a[x].left].size * (1 - 2 * alpha) / (1 - alpha)){
                rotate(a[x].left, 0);
            }
            rotate(x, 1);
        }
    }
    // Insert `value` below x (callers pass root). An empty tree gets a single
    // leaf. When a leaf is reached it becomes an internal node whose two new
    // leaves hold its old key and `value` in sorted order.
    void Insert(int x, int value){
        if(!root){
            root = newnode(0, 0, 1, value);
            return;
        }
        if(leaf(x)){
            a[x].left = newnode(0, 0, 1, min(value, a[x].value));
            a[x].right = newnode(0, 0, 1, max(value, a[x].value));
            pushup(x);
            maintain(x);
            return;
        }
        if(a[a[x].left].value >= value) Insert(a[x].left, value);
        else Insert(a[x].right, value);
        pushup(x);
        maintain(x);
    }
    // Remove one occurrence of `value` below x; fa is x's parent (0 for the
    // root). The search reaches the leftmost leaf whose key is >= value; if
    // that leaf does not hold `value`, nothing is removed. Otherwise the
    // leaf's parent is replaced in place by the leaf's sibling.
    void Delete(int x, int fa, int value){
        if(leaf(x)){
            if(a[x].value != value) return;  // key not present
            if(!fa){  // x is the root and the only leaf: tree becomes empty
                recycle(x);
                root = 0;
                return;
            }
            int sibling = (a[fa].left == x) ? a[fa].right : a[fa].left;
            copynode(fa, sibling);
            recycle(x);
            recycle(sibling);
            pushup(fa);
            maintain(fa);
            return;
        }
        if(a[a[x].left].value >= value) Delete(a[x].left, x, value);
        else Delete(a[x].right, x, value);
        pushup(x);
        maintain(x);
    }
    // 1 + number of keys strictly less than `value`. Assumes some key is
    // >= value (solve() keeps an INF sentinel); otherwise the largest key is
    // not counted.
    int rank(int x, int value){
        if(leaf(x)) return 1;
        if(a[a[x].left].value >= value) return rank(a[x].left, value);
        else return rank(a[x].right, value) + a[a[x].left].size;
    }
    // The rank-th smallest key (1-based) below x. An out-of-range rank stops
    // at a leaf instead of recursing into the null node forever.
    int kth(int x, int rank){
        if(leaf(x) || a[x].size == rank) return a[x].value;
        if(a[a[x].left].size >= rank) return kth(a[x].left, rank);
        else return kth(a[x].right, rank - a[a[x].left].size);
    }
    // Join two balanced trees (all keys of x before all keys of y) into one
    // balanced tree. Both inputs are consumed: nodes that get taken apart are
    // recycled, so callers must not use x or y afterwards.
    int merges(int x, int y){
        if(!x) return y;
        if(!y) return x;
        int total = a[x].size + a[y].size;
        if(min(a[x].size, a[y].size) >= total * alpha) return merge(x, y);
        if(a[x].size >= a[y].size){
            int l = a[x].left, r = a[x].right;
            recycle(x);
            if(a[l].size >= total * alpha) return merges(l, merges(r, y));
            int rl = a[r].left, rr = a[r].right;
            recycle(r);
            return merges(merges(l, rl), merges(rr, y));
        }else{
            int l = a[y].left, r = a[y].right;
            recycle(y);
            if(a[r].size >= total * alpha) return merges(merges(x, l), r);
            int ll = a[l].left, lr = a[l].right;
            recycle(l);
            return merges(merges(x, ll), merges(lr, r));
        }
    }
    // Split tree p into x (first k leaves) and y (the rest). p is consumed;
    // its dismantled internal nodes are recycled.
    void split(int p, int k, int &x, int &y){
        if(!k) return x = 0, y = p, void();
        if(leaf(p)) return x = p, y = 0, void();
        int l = a[p].left, r = a[p].right;
        recycle(p);
        if(k <= a[l].size){
            split(l, k, x, y);
            y = merges(y, r);
        }else{
            split(r, k - a[l].size, x, y);
            x = merges(l, x);
        }
    }
    // Debug dump to stderr: "value size:(leftValue,rightValue)" per node, preorder.
    void traverse(int x){
        if(!x) return;
        cerr<<a[x].value<<" "<<a[x].size<<":(";
        if(a[x].left) cerr<<a[a[x].left].value;
        else cerr<<"NULL";
        cerr<< ",";
        if(a[x].right) cerr<<a[a[x].right].value;
        else cerr<<"NULL";
        cerr<<")"<<'\n';
        traverse(a[x].left);
        traverse(a[x].right);
    }
};
WBLT S;
// Reads n operations from stdin and answers them with the global tree S:
// 1 insert, 2 delete, 3 rank of x, 4 x-th smallest, 5 largest key < x,
// 6 smallest key > x. The sentinels -INF and INF are inserted first, which
// is why every rank is shifted by one.
void solve() {
    scanf("%d",&n);
    S.root = 0;
    S.cnt = 0;
    S.Insert(S.root, INF);
    S.Insert(S.root, -INF);
    for(i = 1; i <= n; ++i){
        int op, x;
        scanf("%d%d", &op, &x);
        switch(op){
            case 1:
                S.Insert(S.root, x);
                break;
            case 2:
                S.Delete(S.root, 0, x);
                break;
            case 3:
                printf("%d\n", S.rank(S.root, x) - 1);
                break;
            case 4:
                printf("%d\n", S.kth(S.root, x + 1));
                break;
            case 5: {
                int below = S.rank(S.root, x) - 1;  // position of the largest key < x
                printf("%d\n", S.kth(S.root, below));
                break;
            }
            case 6: {
                int above = S.rank(S.root, x + 1);  // position of the smallest key > x
                printf("%d\n", S.kth(S.root, above));
                break;
            }
            default:
                break;
        }
    }
}
int main() {
    // All I/O is on stdin/stdout; for local runs redirect them first, e.g.
    // freopen("P3369_8.in", "r", stdin);
    solve();
}

指针版本:

#include<bits/stdc++.h>
using namespace std;
const int N=2e5+10, INF=1e9+10;
const double alpha = 0.25;
int T,n,i;
template<typename T> class WBLT;
// Tree node. Leaves carry the keys; an internal node has two children, its
// `size` is the number of leaves below it and its `value` is the largest key
// in its subtree (copied from the right child).
template<typename T> class node{
public:
    node<T>* left;
    node<T>* right;
    int size;
    T value;
    // Constructors use the injected class name: a template-id such as
    // node<T>(...) is not allowed as a constructor name since C++20.
    node(int size, T value) : left(nullptr), right(nullptr), size(size), value(value) {}
    node(node<T>* left, node<T>* right, int size, T value) :
        left(left), right(right), size(size), value(value) {}
    friend class WBLT<T>;
};
// Weight-balanced leafy tree with heap-allocated nodes.
template<typename T> class WBLT{
public:
    node<T>* root;
    WBLT(){root = nullptr;}
    // A node is a leaf iff it has no children.
    bool leaf(node<T>* x){
        return (!x->left && !x->right);
    }
    // Recompute size and cached maximum of internal node x from its children.
    void pushup(node<T>* x){
        if(leaf(x)) return;
        x->value = x->right->value;
        x->size = x->left->size + x->right->size;
    }
    // Overwrite *x with *y (x takes over y's children).
    void copynode(node<T>* x, node<T>* y){
        *x = *y;
    }
    // New internal node with children x (left) and y (right); no rebalancing.
    node<T>* merge(node<T>* x, node<T>* y){
        node<T>* z = new node<T>(x, y, 0, x->value);
        pushup(z);
        return z;
    }
    // Single rotation of x. t == 0 rotates left: the left child and
    // right->left become a new left child, right->right becomes the right
    // child and the old right child is freed. t == 1 mirrors it. x keeps the
    // same leaves, so its size and maximum do not change.
    void rotate(node<T>* x, int t){
        if(!t){
            x->left = merge(x->left, x->right->left);
            node<T>* newright = x->right->right;
            delete x->right;
            x->right = newright;
        }else{
            x->right = merge(x->left->right, x->right);
            node<T>* newleft = x->left->left;
            delete x->left;
            x->left = newleft;
        }
    }
    // Restore balance at x (children assumed balanced). The lighter side
    // must weigh at least alpha * size(x); if the heavy child's inner
    // grandchild is too large a double rotation is used.
    void maintain(node<T>* x){
        if(leaf(x)) return;
        if(x->left->size <= x->right->size){
            if(x->left->size >= x->size * alpha) return;
            if(x->right->left->size >= x->right->size * (1 - 2 * alpha) / (1 - alpha)){
                rotate(x->right, 1);
            }
            rotate(x, 0);
        }else{
            if(x->right->size >= x->size * alpha) return;
            if(x->left->right->size >= x->left->size * (1 - 2 * alpha) / (1 - alpha)){
                rotate(x->left, 0);
            }
            rotate(x, 1);
        }
    }
    // Insert `value` below x (callers pass root). At a leaf, the leaf becomes
    // an internal node whose two new leaves hold its old key and `value`.
    void Insert(node<T>* x, T value){
        if(!root){
            root = new node<T>(1, value);
            return;
        }
        if(leaf(x)){
            x->left = new node<T>(1, min(value, x->value));
            x->right = new node<T>(1, max(value, x->value));
            pushup(x);
            maintain(x);
            return;
        }
        if(x->left->value >= value) Insert(x->left, value);
        else Insert(x->right, value);
        pushup(x);
        maintain(x);
    }
    // Remove one occurrence of `value` below x; fa is x's parent (nullptr
    // for the root). If the leaf the search reaches does not hold `value`,
    // nothing is removed. Otherwise the parent is replaced by the sibling.
    void Delete(node<T>* x, node<T>* fa, T value){
        if(!x) return;  // empty tree
        if(leaf(x)){
            if(x->value < value || value < x->value) return;  // key not present
            if(!fa){  // x is the root and the only leaf: tree becomes empty
                delete x;
                root = nullptr;
                return;
            }
            node<T>* sibling = (fa->left == x) ? fa->right : fa->left;
            copynode(fa, sibling);
            delete x;
            delete sibling;
            pushup(fa);
            maintain(fa);
            return;
        }
        if(x->left->value >= value) Delete(x->left, x, value);
        else Delete(x->right, x, value);
        pushup(x);
        maintain(x);
    }
    // 1 + number of keys strictly less than `value` (assumes some key is
    // >= value, which the INF sentinel in solve() guarantees).
    int rank(node<T>* x, T value){
        if(!x || leaf(x)) return 1;
        if(x->left->value >= value) return rank(x->left, value);
        else return rank(x->right, value) + x->left->size;
    }
    // The rank-th smallest key (1-based) below x; an out-of-range rank stops
    // at a leaf instead of dereferencing a missing child.
    T kth(node<T>* x, int rank){
        assert(x != nullptr);
        if(leaf(x) || x->size == rank) return x->value;
        if(x->left->size >= rank) return kth(x->left, rank);
        else return kth(x->right, rank - x->left->size);
    }
    // Join two balanced trees (all keys of x before all keys of y). Both are
    // consumed: nodes that get taken apart are deleted.
    node<T>* merges(node<T>* x, node<T>* y){
        if(!x) return y;
        if(!y) return x;
        int total = x->size + y->size;
        if(min(x->size, y->size) >= total * alpha) return merge(x, y);
        if(x->size >= y->size){
            node<T>* l = x->left;
            node<T>* r = x->right;
            delete x;
            if(l->size >= total * alpha) return merges(l, merges(r, y));
            node<T>* rl = r->left;
            node<T>* rr = r->right;
            delete r;
            return merges(merges(l, rl), merges(rr, y));
        }else{
            node<T>* l = y->left;
            node<T>* r = y->right;
            delete y;
            if(r->size >= total * alpha) return merges(merges(x, l), r);
            node<T>* ll = l->left;
            node<T>* lr = l->right;
            delete l;
            return merges(merges(x, ll), merges(lr, r));
        }
    }
    // Split tree p into x (first k leaves) and y (the rest). x and y are
    // out-parameters and must be references, otherwise the results are lost.
    // p is consumed; its dismantled internal nodes are deleted.
    void split(node<T>* p, int k, node<T>*& x, node<T>*& y){
        if(!k) return x = nullptr, y = p, void();
        if(!p) return x = nullptr, y = nullptr, void();
        if(leaf(p)) return x = p, y = nullptr, void();
        node<T>* l = p->left;
        node<T>* r = p->right;
        int leftSize = l->size;
        delete p;
        if(k <= leftSize){
            split(l, k, x, y);
            y = merges(y, r);
        }else{
            split(r, k - leftSize, x, y);
            x = merges(l, x);
        }
    }
    // Debug dump to stderr: "value size:(leftValue,rightValue)" per node, preorder.
    void traverse(node<T>* x){
        if(!x) return;
        cerr<<x->value<<" "<<x->size<<":(";
        if(x->left) cerr<<x->left->value;
        else cerr<<"NULL";
        cerr<< ",";
        if(x->right) cerr<<x->right->value;
        else cerr<<"NULL";
        cerr<<")"<<'\n';
        traverse(x->left);
        traverse(x->right);
    }
};
// Reads n operations from stdin and answers them: 1 insert, 2 delete,
// 3 rank of x, 4 x-th smallest, 5 largest key < x, 6 smallest key > x.
// The sentinels -INF and INF are inserted first, so every rank is shifted
// by one. The tree object is a local, not a leaked heap allocation; its nodes
// are still not freed on exit because WBLT has no destructor.
void solve() {
    scanf("%d",&n);
    WBLT<int> S;
    S.Insert(S.root, INF);
    S.Insert(S.root, -INF);
    for(i=1;i<=n;++i){
        int op,x;
        scanf("%d%d",&op,&x);
        if(op==1){
            S.Insert(S.root, x);
        }else if(op==2){
            S.Delete(S.root, nullptr, x);
        }else if(op==3){
            printf("%d\n",S.rank(S.root, x)-1);
        }else if(op==4){
            printf("%d\n",S.kth(S.root, x + 1));
        }else if(op==5){
            printf("%d\n",S.kth(S.root, S.rank(S.root, x) - 1));
        }else if(op==6){
            printf("%d\n",S.kth(S.root, S.rank(S.root, x + 1)));
        }
    }
}
int main() {
    // Entry point: all work, including I/O on stdin/stdout, happens in solve().
    solve();
}

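有关合并与分裂

合并(merges)两棵树时,若较小一棵的重量不低于总重量的 \(\alpha\),直接新建一个父节点把两棵树连起来;否则拆开较大的那棵:保留它远离另一棵树的那个儿子,把靠近的部分与另一棵递归合并(若远离的儿子也太轻,则像双旋一样把靠近的儿子再拆一层)。分裂(split)按排名递归到包含分界点的儿子,回溯时用合并把另一个儿子拼到对应的一边。两者都会使被拆开的原节点被废弃,同样需要回收。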

~~用指针写这两个就是史~~

~~用这个树写文艺平衡树就是史,建议转去写Splay~~