Skip to content

主席树

3939

注意动态开点的写法以及修改操作:修改只需要改前一个点,后一个点维护的值不变.

#include<iostream>
#include<cstdio>
#include<cstring>

using namespace std;
const int N=3e5+10;
int rt[N],ls[N*60],rs[N*60],a[N],tre[N*60];
int n,m,op,l,r,c,x,cnt;

void change(int &x,int x1,int l,int r,int pos,int k) {
    x=++cnt;
    rs[x]=rs[x1];//注意
    ls[x]=ls[x1];
    tre[x]=tre[x1]+k;
    if(l==r) {
        return;
    }
    int mid=l+r>>1;
    if(mid>=pos) {
        change(ls[x],ls[x1],l,mid,pos,k);
    } else {
        change(rs[x],rs[x1],mid+1,r,pos,k);
    }
    return;
}
int abs1(int x){
    return x>0?x:-x;
}
int query(int x1,int x2,int l,int r,int pos) {
    if(l==r){
        return abs1(tre[x2]-tre[x1]);   
    }
    int mid=l+r>>1;
    if(mid>=pos) {
        return query(ls[x1],ls[x2],l,mid,pos);
    } else {
        return query(rs[x1],rs[x2],mid+1,r,pos);
    }
}
int read1(){
    int x=0,f=1;
    char ch=getchar();
    while(ch>'9' || ch<'0'){
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch<='9' && ch>='0'){
        x=(x<<1)+(x<<3)+ch-'0';
        ch=getchar();
    }
    return x*f;
}
void write1(int x){
    if(x<0)putchar('-'),x=-x;
    if(x>9)write1(x/10);
    putchar(x%10+'0');
    return;
}
int main() {
    cnt=0;
    n=read1(),m=read1();
    for(int i=1; i<=n; ++i)a[i]=read1();
    for(int i=1; i<=n; ++i){
        change(rt[i],rt[i-1],0,N+1,a[i],1);
    }   
    for(int i=1; i<=m; ++i) {
        op=read1();
        if(op==1) {
            l=read1(),r=read1(),c=read1();
            int ans=query(rt[l-1],rt[r],0,N+1,c);
            write1(ans);
            putchar('\n');
        } else {
            x=read1();
            change(rt[x],rt[x],0,N+1,a[x],-1);
            change(rt[x],rt[x],0,N+1,a[x+1], 1);
            swap(a[x],a[x+1]);
        }
   }
    return 0;
}

P6166 & P1383

两个题差不多,都有三种操作:输入,输出和撤回。

注意撤回可以撤回撤回。~~???~~

对于每个根节点,我们可以记录一个num数组,代表这个版本的字符串中有几个字符;

所以有:

    num[rt[tot]]=num[rt[tot-1]]+1

那么对于输入操作,就直接在num[tot]的位置插入字母即可。在递归到最后一层时在val数组中记录。(l==r)

输出操作注意是查询p+1,因为p从0开始。

撤回操作直接赋值根节点即可。

code time:

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=1e6+10,M=N*21;
int n,cnt,tot,p,u;
int ls[M],rs[M],val[M],num[N],rt[N];
char ch[5];
void add(int &now,int pre,int l,int r,int pos,int k){
    now=++cnt;
    ls[now]=ls[pre];
    rs[now]=rs[pre];
    if(l==r){
        val[now]=k;
        return;
    }
    int mid=l+r>>1;
    if(pos<=mid)add(ls[now],ls[pre],l,mid,pos,k);
    else add(rs[now],rs[pre],mid+1,r,pos,k);
    return;
}
int query(int x,int l,int r,int pos){
    if(l==r)return val[x];
    int mid=l+r>>1;
    if(pos<=mid)return query(ls[x],l,mid,pos);
    else return query(rs[x],mid+1,r,pos);
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<=n;++i){
        scanf("%s",ch);
        if(ch[0]=='T'){
            scanf("%s",ch);
            int c=ch[0]-'a'+1;  
            ++tot;
            num[tot]=num[tot-1]+1;
            add(rt[tot],rt[tot-1],1,N,num[tot],c);
        }else if(ch[0]=='P'){
            scanf("%d",&p);
            int q=query(rt[tot],1,N,p+1);
            printf("%c\n",q+'a'-1);
        }else{
            scanf("%d",&u);
            tot++;
            rt[tot]=rt[tot-u-1];
            num[tot]=num[tot-u-1];
        }
    }
    return 0;
}

P2633 & SP10628

这两道题很好,有很多细节值得说说。

首先,看到在树上求两点间内容的问题,就不难想到树剖lca,而且是第k小,就要用主席树,而这道题也是一样。

只不过每次跳top时累加的不再是具体的值,而是一个区间两端对应的根节点。

这样,在整体查询时,将所有右端点加起来,所有左端点减下去,就可以组成由两点间所有权值组成的主席树了。

  • 一些细节:

1) 树剖时还是注意,dfs序要先重儿子再轻儿子。

2) 第k小仿照treap的写法,每次向右搜都将k-tmp;

3) 权值是0\~MAXINT,所以主席树数组别开小了,至少31倍.

4) b1,b2数组记录的时由dfs序得出的根节点序号,所以可以直接用。

5) 注意所有左端点都是深度低的点,并且因为差分,dfs序要-1;

6) 建主席树按照dfs序建。

code time:

#include<iostream>
#include<cstdio>
#include<cstring>
#define int long long 
using namespace std;
const int N=1e5+10,M=N*40,C=3e9;
int rt[N],ls[M],rs[M],num[M];
int n,m,cnt,tot,cur,u,v,k;
int a[N],head[N],sz[N],fa[N],son[N],dep[N],top[N],dfn[N],b1[N],b2[N],id[N];
struct node{
    int v,nxt;
}e[N<<1];
void add(int u,int v){
    e[++cnt].v=v;
    e[cnt].nxt=head[u];
    head[u]=cnt;
}
void dfs1(int u,int p){
    sz[u]=1;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(v==p)continue;
        dep[v]=dep[u]+1;
        fa[v]=u;
        dfs1(v,u);
        sz[u]+=sz[v];
        if(sz[v]>sz[son[u]])son[u]=v;
    }
}
void dfs2(int u,int t){
    top[u]=t;
    id[u]=++tot;
    dfn[tot]=u;
    if(son[u])dfs2(son[u],t);
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(v==fa[u] || v==son[u])continue;
        dfs2(v,v);
    }
}
void add(int &now,int pre,int l,int r,int pos){
    now=++cur;
    ls[now]=ls[pre];
    rs[now]=rs[pre];
    num[now]=num[pre]+1;
    if(l==r){
        return;
    }
    int mid=l+r>>1;
    if(pos<=mid)add(ls[now],ls[pre],l,mid,pos);
    else add(rs[now],rs[pre],mid+1,r,pos);

}
int query(int l,int r,int k){
    if(l==r)return l;
    int tmp=0,mid=l+r>>1;
    for(int i=1;i<=b1[0];++i){
        tmp+=num[ls[b1[i]]];
    }
    for(int i=1;i<=b2[0];++i){
        tmp-=num[ls[b2[i]]];
    }
    if(k<=tmp){
        for(int i=1;i<=b1[0];++i){
            b1[i]=ls[b1[i]];
        }
        for(int i=1;i<=b2[0];++i){
            b2[i]=ls[b2[i]];
        }
        return query(l,mid,k);
    }else{
        for(int i=1;i<=b1[0];++i){
            b1[i]=rs[b1[i]];
        }
        for(int i=1;i<=b2[0];++i){
            b2[i]=rs[b2[i]];
        }
        return query(mid+1,r,k-tmp);
    }
}
int get1(int u,int v,int k){
    b1[0]=0;
    b2[0]=0;
    while(top[u]!=top[v]){
        if(dep[top[u]]<dep[top[v]])swap(u,v);
        b1[++b1[0]]=rt[id[u]];
        b2[++b2[0]]=rt[id[top[u]]-1];
        u=fa[top[u]];
    }
    if(dep[u]<dep[v])swap(u,v);
    b1[++b1[0]]=rt[id[u]];
    b2[++b2[0]]=rt[id[v]-1];
    return query(0,C,k);
}
signed main(){
    memset(head,-1,sizeof head);
    cnt=-1;
    scanf("%lld%lld",&n,&m);
    for(int i=1;i<=n;++i){
        scanf("%lld",&a[i]);
    }
    for(int i=1;i<n;++i){
        scanf("%lld%lld",&u,&v);
        add(u,v);
        add(v,u);
    }
    tot=0;
    dfs1(1,1);
    dfs2(1,1);
    cur=0;
    for(int i=1;i<=tot;++i){
        add(rt[i],rt[i-1],0,C,a[dfn[i]]);
    }
    int last=0;
    for(int i=1;i<=m;++i){
        scanf("%lld%lld%lld",&u,&v,&k);
        u^=last;
        int tmp=get1(u,v,k);
        printf("%lld\n",tmp);
        last=tmp;   
    }
    return 0;
}
~~Q.如何增加做题量?A.双倍经验!~~