Skip to content

线段树合并

update 2023.10.18 ~~爷复活了,回来写总结了~~

模板P4556 雨天的尾巴

要求求出每个节点的出现次数最多的补给。

可以想到,给每个节点开一个\([1,10^5]\)的桶来储存每种补给的出现次数。

然后将每次操作转化为差分:1-u,1-v路径+1,1-lca,1-fa[lca]路径-1.

最后从下到上dfs,将数组加起来,即\(cnt[u][i]=\sum_{v\in son[u]} cnt[v][i]\),就可以得到答案。

但是,\(cnt\)数组太大,没法存储。

故我们有了线段树合并。

每个节点动态开点,维护一个\([1,10^5]\)的权值线段树。

可以知道,每次操作会让4个点分别增加\(O(logn)\)个点。

然后,对于合并操作,我们分别遍历两个线段树。

如果两个线段树都拥有左子节点,则继续遍历两个左儿子;否则,就让第一个线段树的左儿子指向这个儿子;如果到达叶子节点,就将两个叶子的值相加;右儿子同理。这样就可以完成合并了。

可以证明,每次合并时线段树2中遍历到的节点被合并到了新线段树(线段树1)中,则新线段树中这些属于2的点不会再被遍历到;而没遍历到的点被接在了线段树1上,组成了一个船新的、没有遍历过的线段树。所以每个点最多被遍历一次。

而每次合并不会增加新节点,所以空间与时间复杂度均为\(O(nlogn)\)

至于最大值,我们对于每个线段树节点维护一个种类\(tre[i]\)和出现次数\(sum[i]\)即可。

#include<iostream>
#include<cstring>
#define mid (l+r>>1)
using namespace std;
const int N=1e5+10,K=30,NN=1e5;

struct edge{
    int v,nxt;
}e[N<<1];
int cnt,u,v,z,n,m,tot;
int head[N],fa[N][K],dep[N],sum[N<<6],tre[N<<6],ls[N<<6],rs[N<<6],rt[N],ans[N];
void add(int u,int v){
    e[++cnt].v=v,e[cnt].nxt=head[u],head[u]=cnt;
}
void pushup(int x){
    if(sum[ls[x]]>=sum[rs[x]]){
        sum[x]=sum[ls[x]];
        tre[x]=tre[ls[x]];
    }else{
        sum[x]=sum[rs[x]];
        tre[x]=tre[rs[x]];
    }
}
void change(int &x,int l,int r,int pos,int k){
    if(!x) x=++tot;
    if(l==r){
        sum[x]+=k;
        tre[x]=pos;
        return;
    }
    if(pos<=mid) change(ls[x],l,mid,pos,k);
    else change(rs[x],mid+1,r,pos,k);
    pushup(x);
}
int merge(int a,int b,int l,int r){
    if(!a || !b) return a+b;
    if(l==r){
        sum[a]+=sum[b];
        return a;
    }
    ls[a]=merge(ls[a],ls[b],l,mid);
    rs[a]=merge(rs[a],rs[b],mid+1,r);
    pushup(a);
    return a;
}

void dfs(int u,int p){
    fa[u][0]=p;
    for(int i=1;i<K;++i){
        fa[u][i]=fa[fa[u][i-1]][i-1];
    }
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(v==p) continue;
        dep[v]=dep[u]+1;
        dfs(v,u);
    }
}
int getlca(int u,int v){
    if(u==v) return u;
    if(dep[u]<dep[v]) swap(u,v);
    int k=dep[u]-dep[v];
    for(int i=0;i<K;++i){
        if(k&(1<<i)) u=fa[u][i];
    }
    if(u==v) return u;
    for(int i=K-1;i>=0;--i){
        if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];   
    }
    return fa[u][0];
}
void DFS(int u,int p){
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(v==p) continue;
        DFS(v,u);
        rt[u]=merge(rt[u],rt[v],1,NN);
    }
    if(sum[rt[u]]) ans[u]=tre[rt[u]];
}
int main(){
    memset(head,-1,sizeof head);
    cnt=-1;
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;++i){
        scanf("%d%d",&u,&v);
        add(u,v);
        add(v,u);
    }
    dfs(1,0);
    for(int i=1;i<=m;++i){
        scanf("%d%d%d",&u,&v,&z);
        int lca=getlca(u,v);
        change(rt[u],1,NN,z,1);
        change(rt[v],1,NN,z,1);
        change(rt[lca],1,NN,z,-1);
        change(rt[fa[lca][0]],1,NN,z,-1);
    }
    DFS(1,0);
    for(int i=1;i<=n;++i) printf("%d\n",ans[i]);
    return 0;
} 
/*
10 10 
2 1 
3 2 
4 3 
5 3 
6 3 
7 4 
8 5 
9 8 
10 3 
6 2 6 
3 2 6 
3 3 2 
6 6 6 
10 3 3 
10 7 1 
3 7 4 
7 9 5 
4 2 4 
3 3 6
*/

P5298

\(f_{i,j}\)为第i个节点出现j的概率,l,r为左右儿子,m为i子树中叶子节点的个数。则:

\(f_{i,j}=f_{l,j} \times (p_i\times\sum_{k=1}^{j-1}f_{r,k}+(1-p_i)\times \sum_{k=j+1}^{m}f_{r,k})+f_{r,j} \times (p_i\times\sum_{k=1}^{j-1}f_{l,k}+(1-p_i)\times \sum_{k=j+1}^{m}f_{l,k})\)

这个\(f\)数组就是要维护的桶。

那么,我们可以在每个节点上,对于权值的总个数开一个权值线段树,维护每种权值的出现概率。

由于概率是分数,那么可以用\(ax\equiv b(\mod p)\)来求出\(x=\frac{b}{a}\)在模意义下的值。

这个式子可以用扩展欧几里得(exgcd)求,也可以根据小费马定理:\(a^{p-1}\equiv 1(\mod p)\)得出等价关系\(a^{p-2}\equiv a^{-1}(\mod p)\)

考虑左边,那么对于括号里的那一坨,可以在合并过程中维护。

假设\(xmul\)是那一坨的值,那么每一搜索左节点,我们就可以更新\((1-p_i)\times \sum_{k=j+1}^{m}f_{r,k}\)的值,因为即将搜索的左节点的权值一定小于全部右节点的权值。搜索右节点同理。

这样,我们在最后对线段树进行一个dfs,记录每个叶子节点的值,即可得到最终答案。

#include<iostream>
#include<algorithm>
#include<cstring>
#define mid (l+r>>1)
using namespace std;
const int N=3e5+10,P=998244353,PP=796898467;
int tre[N<<5],multag[N<<5],ls[N<<5],rs[N<<5],rt[N],ch[N][2],val[N],b[N],cnt[N],ans[N];
int n,fa,tot,id;
int build(){
    int x=++id;
    ls[x]=rs[x]=tre[x]=0,multag[x]=1;
    return x;
}
void pushup(int x){
    tre[x]=(tre[ls[x]]+tre[rs[x]])%P;
}
void cal(int x,int v){
    tre[x]=(1ll*tre[x]*v)%P;
    multag[x]=(1ll*multag[x]*v)%P;
}
void pushdown(int x){
    int k=multag[x];
    if(ls[x]) cal(ls[x],k);
    if(rs[x]) cal(rs[x],k);
    multag[x]=1;
}
void change(int &x,int l,int r,int pos,int k){

    if(!x) x=build();
    if(l==r){
        tre[x]+=k;
        return;
    }
    pushdown(x);
    if(pos<=mid) change(ls[x],l,mid,pos,k);
    else change(rs[x],mid+1,r,pos,k);
    pushup(x);
}
int merge(int x,int y,int l,int r,int xmul,int ymul,int p){
    //if(!x && !y) return 0;
    if(!x){
        cal(y,ymul);
        return y;
    }
    if(!y){
        cal(x,xmul);
        return x;
    }
    pushdown(x),pushdown(y);
    int lsx=tre[ls[x]],lsy=tre[ls[y]],rsx=tre[rs[x]],rsy=tre[rs[y]];
    ls[x]=merge(ls[x],ls[y],l,mid,(xmul+(1ll*rsy%P*(1-p+P)%P))%P,(ymul+(1ll*rsx%P*(1-p+P)%P))%P,p);
    rs[x]=merge(rs[x],rs[y],mid+1,r,(xmul+(1ll*lsy*p%P))%P,(ymul+(1ll*lsx*p%P))%P,p);
    pushup(x);
    return x;
}

void dfs(int u){
    if(!cnt[u]) change(rt[u],1,tot,val[u],1);
    else if(cnt[u]==1) dfs(ch[u][0]),rt[u]=rt[ch[u][0]];
    else if(cnt[u]==2) dfs(ch[u][0]),dfs(ch[u][1]),rt[u]=merge(rt[ch[u][0]],rt[ch[u][1]],1,tot,0,0,val[u]); 
}
void work(int x,int l,int r){
    if(l==r){
        ans[l]=tre[x];
        return;
    }
    pushdown(x);
    work(ls[x],l,mid);
    work(rs[x],mid+1,r);
}

int main(){
    id=tot=0;
    scanf("%d",&n);
    for(int i=1;i<=n;++i){
        scanf("%d",&fa);
        if(fa) ch[fa][cnt[fa]++]=i;
    }
    for(int i=1;i<=n;++i) scanf("%d",&val[i]);
    for(int i=1;i<=n;++i){
        if(cnt[i]){
            val[i]=(1ll*val[i]*PP)%P;
        }else{
            b[++tot]=val[i];
        }
    }
    sort(b+1,b+tot+1);
    for(int i=1;i<=n;++i) if(!cnt[i]) val[i]=lower_bound(b+1,b+tot+1,val[i])-b;
    dfs(1);
    work(rt[1],1,tot);
    int res=0;
    for(int i=1;i<=tot;++i) res=(res+1ll*ans[i]%P*ans[i]%P*i%P*b[i]%P)%P;
    printf("%d\n",res);
    return 0;
}