线段树合并
update 2023.10.18 ~~爷复活了,回来写总结了~~
模板P4556 雨天的尾巴
要求求出每个节点的出现次数最多的补给。
可以想到,给每个节点开一个\([1,10^5]\)的桶来储存每种补给的出现次数。
然后将每次操作转化为差分:1-u,1-v路径+1,1-lca,1-fa[lca]路径-1.
最后从下到上dfs,将数组加起来,即\(cnt[u][i]=\sum_{v\in son[u]} cnt[v][i]\),就可以得到答案。
但是,\(cnt\)数组太大,没法存储。
故我们有了线段树合并。
每个节点动态开点,维护一个\([1,10^5]\)的权值线段树。
可以知道,每次操作会让4个点分别增加\(O(logn)\)个点。
然后,对于合并操作,我们分别遍历两个线段树。
如果两个线段树都拥有左子节点,则继续遍历两个左儿子;否则,就让第一个线段树的左儿子指向这个儿子;如果到达叶子节点,就将两个叶子的值相加;右儿子同理。这样就可以完成合并了。
可以证明,每次合并时线段树2中遍历到的节点被合并到了新线段树(线段树1)中,则新线段树中这些属于2的点不会再被遍历到;而没遍历到的点被接在了线段树1上,组成了一个船新的、没有遍历过的线段树。所以每个点最多被遍历一次。
而每次合并不会增加新节点,所以空间与时间复杂度均为\(O(nlogn)\)。
至于最大值,我们对于每个线段树节点维护一个种类\(tre[i]\)和出现次数\(sum[i]\)即可。
#include<iostream>
#include<cstring>
#define mid (l+r>>1)
using namespace std;
const int N=1e5+10,K=30,NN=1e5;
struct edge{
int v,nxt;
}e[N<<1];
int cnt,u,v,z,n,m,tot;
int head[N],fa[N][K],dep[N],sum[N<<6],tre[N<<6],ls[N<<6],rs[N<<6],rt[N],ans[N];
void add(int u,int v){
e[++cnt].v=v,e[cnt].nxt=head[u],head[u]=cnt;
}
void pushup(int x){
if(sum[ls[x]]>=sum[rs[x]]){
sum[x]=sum[ls[x]];
tre[x]=tre[ls[x]];
}else{
sum[x]=sum[rs[x]];
tre[x]=tre[rs[x]];
}
}
void change(int &x,int l,int r,int pos,int k){
if(!x) x=++tot;
if(l==r){
sum[x]+=k;
tre[x]=pos;
return;
}
if(pos<=mid) change(ls[x],l,mid,pos,k);
else change(rs[x],mid+1,r,pos,k);
pushup(x);
}
int merge(int a,int b,int l,int r){
if(!a || !b) return a+b;
if(l==r){
sum[a]+=sum[b];
return a;
}
ls[a]=merge(ls[a],ls[b],l,mid);
rs[a]=merge(rs[a],rs[b],mid+1,r);
pushup(a);
return a;
}
void dfs(int u,int p){
fa[u][0]=p;
for(int i=1;i<K;++i){
fa[u][i]=fa[fa[u][i-1]][i-1];
}
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v;
if(v==p) continue;
dep[v]=dep[u]+1;
dfs(v,u);
}
}
int getlca(int u,int v){
if(u==v) return u;
if(dep[u]<dep[v]) swap(u,v);
int k=dep[u]-dep[v];
for(int i=0;i<K;++i){
if(k&(1<<i)) u=fa[u][i];
}
if(u==v) return u;
for(int i=K-1;i>=0;--i){
if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
}
return fa[u][0];
}
void DFS(int u,int p){
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v;
if(v==p) continue;
DFS(v,u);
rt[u]=merge(rt[u],rt[v],1,NN);
}
if(sum[rt[u]]) ans[u]=tre[rt[u]];
}
int main(){
memset(head,-1,sizeof head);
cnt=-1;
scanf("%d%d",&n,&m);
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dfs(1,0);
for(int i=1;i<=m;++i){
scanf("%d%d%d",&u,&v,&z);
int lca=getlca(u,v);
change(rt[u],1,NN,z,1);
change(rt[v],1,NN,z,1);
change(rt[lca],1,NN,z,-1);
change(rt[fa[lca][0]],1,NN,z,-1);
}
DFS(1,0);
for(int i=1;i<=n;++i) printf("%d\n",ans[i]);
return 0;
}
/*
10 10
2 1
3 2
4 3
5 3
6 3
7 4
8 5
9 8
10 3
6 2 6
3 2 6
3 3 2
6 6 6
10 3 3
10 7 1
3 7 4
7 9 5
4 2 4
3 3 6
*/
P5298
设\(f_{i,j}\)为第i个节点出现j的概率,l,r为左右儿子,m为i子树中叶子节点的个数。则:
\(f_{i,j}=f_{l,j} \times (p_i\times\sum_{k=1}^{j-1}f_{r,k}+(1-p_i)\times \sum_{k=j+1}^{m}f_{r,k})+f_{r,j} \times (p_i\times\sum_{k=1}^{j-1}f_{l,k}+(1-p_i)\times \sum_{k=j+1}^{m}f_{l,k})\)
这个\(f\)数组就是要维护的桶。
那么,我们可以在每个节点上,对于权值的总个数开一个权值线段树,维护每种权值的出现概率。
由于概率是分数,那么可以用\(ax\equiv b(\mod p)\)来求出\(x=\frac{b}{a}\)在模意义下的值。
这个式子可以用扩展欧几里得(exgcd)求,也可以根据小费马定理:\(a^{p-1}\equiv 1(\mod p)\)得出等价关系\(a^{p-2}\equiv a^{-1}(\mod p)\)
考虑左边,那么对于括号里的那一坨,可以在合并过程中维护。
假设\(xmul\)是那一坨的值,那么每一搜索左节点,我们就可以更新\((1-p_i)\times \sum_{k=j+1}^{m}f_{r,k}\)的值,因为即将搜索的左节点的权值一定小于全部右节点的权值。搜索右节点同理。
这样,我们在最后对线段树进行一个dfs,记录每个叶子节点的值,即可得到最终答案。
#include<iostream>
#include<algorithm>
#include<cstring>
#define mid (l+r>>1)
using namespace std;
const int N=3e5+10,P=998244353,PP=796898467;
int tre[N<<5],multag[N<<5],ls[N<<5],rs[N<<5],rt[N],ch[N][2],val[N],b[N],cnt[N],ans[N];
int n,fa,tot,id;
int build(){
int x=++id;
ls[x]=rs[x]=tre[x]=0,multag[x]=1;
return x;
}
void pushup(int x){
tre[x]=(tre[ls[x]]+tre[rs[x]])%P;
}
void cal(int x,int v){
tre[x]=(1ll*tre[x]*v)%P;
multag[x]=(1ll*multag[x]*v)%P;
}
void pushdown(int x){
int k=multag[x];
if(ls[x]) cal(ls[x],k);
if(rs[x]) cal(rs[x],k);
multag[x]=1;
}
void change(int &x,int l,int r,int pos,int k){
if(!x) x=build();
if(l==r){
tre[x]+=k;
return;
}
pushdown(x);
if(pos<=mid) change(ls[x],l,mid,pos,k);
else change(rs[x],mid+1,r,pos,k);
pushup(x);
}
int merge(int x,int y,int l,int r,int xmul,int ymul,int p){
//if(!x && !y) return 0;
if(!x){
cal(y,ymul);
return y;
}
if(!y){
cal(x,xmul);
return x;
}
pushdown(x),pushdown(y);
int lsx=tre[ls[x]],lsy=tre[ls[y]],rsx=tre[rs[x]],rsy=tre[rs[y]];
ls[x]=merge(ls[x],ls[y],l,mid,(xmul+(1ll*rsy%P*(1-p+P)%P))%P,(ymul+(1ll*rsx%P*(1-p+P)%P))%P,p);
rs[x]=merge(rs[x],rs[y],mid+1,r,(xmul+(1ll*lsy*p%P))%P,(ymul+(1ll*lsx*p%P))%P,p);
pushup(x);
return x;
}
void dfs(int u){
if(!cnt[u]) change(rt[u],1,tot,val[u],1);
else if(cnt[u]==1) dfs(ch[u][0]),rt[u]=rt[ch[u][0]];
else if(cnt[u]==2) dfs(ch[u][0]),dfs(ch[u][1]),rt[u]=merge(rt[ch[u][0]],rt[ch[u][1]],1,tot,0,0,val[u]);
}
void work(int x,int l,int r){
if(l==r){
ans[l]=tre[x];
return;
}
pushdown(x);
work(ls[x],l,mid);
work(rs[x],mid+1,r);
}
int main(){
id=tot=0;
scanf("%d",&n);
for(int i=1;i<=n;++i){
scanf("%d",&fa);
if(fa) ch[fa][cnt[fa]++]=i;
}
for(int i=1;i<=n;++i) scanf("%d",&val[i]);
for(int i=1;i<=n;++i){
if(cnt[i]){
val[i]=(1ll*val[i]*PP)%P;
}else{
b[++tot]=val[i];
}
}
sort(b+1,b+tot+1);
for(int i=1;i<=n;++i) if(!cnt[i]) val[i]=lower_bound(b+1,b+tot+1,val[i])-b;
dfs(1);
work(rt[1],1,tot);
int res=0;
for(int i=1;i<=tot;++i) res=(res+1ll*ans[i]%P*ans[i]%P*i%P*b[i]%P)%P;
printf("%d\n",res);
return 0;
}