Skip to content

点分治

主要思想

因为分治后,我们希望层数尽量小,也就是选择所有子树尽量平衡的点。所以,我们贪心的取重心。可以证明,因为重心的子树不超过\(\frac n2\),所以层数不超过\(O(\log_2n)\)

那么对于每一层的每个重心,我们需要求出当前重心所在子树的所有距离,可以用桶记录;而统计点对时,这里面包括了两个点在同一颗子树的情况(不是一条简单路径),那么我们运用容斥原理,减去分裂出的子树中的贡献即可。

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=2e4+10,K=2e6+10;
int ans[K],dis[N],head[N],sz[N],vis[N],maxp[N],d[N];
int n,m,cnt,u,v,w,k,tot;
struct node{
    int v,w,nxt;
}e[N<<1];
void add(int u,int v,int w){
    e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
void __dfs(int u,int p){
    sz[u]=1;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        __dfs(v,u);
        sz[u]+=sz[v];
    }
}
void dfs(int u,int p,int S,int &rt,int &tmp){
    maxp[u]=0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        dfs(v,u,S,rt,tmp);
        maxp[u]=max(maxp[u],sz[v]);
    }
    maxp[u]=max(maxp[u],S-sz[u]);
    if(maxp[u]<tmp) tmp=maxp[u],rt=u;
}
void _dfs(int u,int p){
    d[++tot]=dis[u];
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        dis[v]=dis[u]+w;
        _dfs(v,u); 
    }
}
void solve(int u,int len,int t){
    dis[u]=len;tot=0;
    _dfs(u,0);
    for(int i=1;i<=tot;++i)
        for(int j=1;j<=tot;++j)
            if(i!=j && d[i]+d[j]<K) ans[d[i]+d[j]]+=t;
}
void div(int u){
    solve(u,0,1);vis[u]=1;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(vis[v])continue;
        solve(v,w,-1);
        int rt=0,tmp=K;
        __dfs(v,u);
        dfs(v,u,sz[u],rt,tmp);
        div(rt);
    }
}
int main(){
    memset(head,-1,sizeof head);cnt=-1;
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;++i) scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
    int rt=0,tmp=K;
    __dfs(1,0);
    dfs(1,0,n,rt,tmp);
    div(rt);
    for(int i=1;i<=m;++i){
        scanf("%d",&k);
        if(ans[k]>0 || k==0) printf("Yes\n");
        else printf("No\n");
    }
    //for(int i=0;i<=10;++i) cout<<ans[i]<<" ";

    return 0;
} 

~~然后你会发现,bzoj1316:Accepted,luoguP3806:Time Limit Exceeded~~

下面是常见的问题.

1) 有关错误的求重心对复杂度的影响

每次求中心之前一定要先\(dfs\)一遍整颗子树,求出大小后再用这个节点数取更新重心。不然可能会出现这样的情况:

//对于这样一颗树
11//n
6 7 1//u v w
6 8 1
7 9 1
7 10 1
8 11 1
1 2 1
1 3 1
2 4 1
2 5 1
3 6 1
在分治到\(6\)号点时,因为此时\(sz[3]=7\),所以求出的重心为\(1\);而事实上\(sz[3]=5\),求出重心为\(2\).

2) 有关常规做法错误的复杂度

本来分治的层数是严格小于\(O(\log n)\)这个上界的,每一层最多遍历一遍所有节点,所以复杂度\(O(n\log n)\).但是,朴素的统计有个\(O(n^2)\)的统计,对于菊花图这样子节点巨多的数据一定会T.所以,我们可以采用双指针的方式统计。

\(dfs\)一边求出当前子树中节点深度\(d[i]\),所有节点标号\(a[i]\),以及节点处在哪个子树中\(b[i]\)

对于深度从小到大排序(不要忘记加入重心\(u\)),然后对于不同的询问,\(l=1,r=tot\)地扫一遍,注意要满足\(b[a[l]]\not=b[a[r]]\),去掉了容斥。分治+排序复杂度\(O(n\log^2n)\),分治+双指针复杂度\(O(nm\log n)\),总\(O(n\log^2n + nm\log n)\)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e4+10,K=1e7+10;
int query[N],head[N],sz[N],vis[N],maxp[N],d[N],a[N],b[N];
int n,m,cnt,u,v,w,k,tot;
bool ans[K];
bool cmp(int x,int y){
    return d[x]<d[y];
}
struct node{
    int v,w,nxt;
}e[N<<1];
void add(int u,int v,int w){
    e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
void __dfs(int u,int p){
    sz[u]=1;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        __dfs(v,u);
        sz[u]+=sz[v];
    }
}
void dfs(int u,int p,int S,int &rt,int &tmp){
    maxp[u]=0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        dfs(v,u,S,rt,tmp);
        maxp[u]=max(maxp[u],sz[v]);
    }
    maxp[u]=max(maxp[u],S-sz[u]);
    if(maxp[u]<tmp) tmp=maxp[u],rt=u;
}
void _dfs(int u,int p,int dis,int P){
    a[++tot]=u;d[u]=dis;b[u]=P;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        _dfs(v,u,dis+w,P); 
    }
}
void solve(int u){
    tot=0;
    a[++tot]=u;
    b[u]=u;d[u]=0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(vis[v])continue;
        _dfs(v,u,w,v);  
    }
    sort(a+1,a+tot+1,cmp);
    for(int i=1;i<=m;++i){
        int l=1,r=tot;
        if(ans[i]) continue;
        while(l<r){
            //cout<<u<<" "<<i<<": "<<a[l]<<" "<<a[r]<<" "<<d[a[l]]<<" "<<d[a[r]]<<endl;
            if(d[a[l]]+d[a[r]]>query[i]) --r;
            else if(d[a[l]]+d[a[r]]<query[i]) ++l;
            else if(b[a[l]]==b[a[r]]){
                if(d[a[r]]==d[a[r-1]]) --r;
                else ++l;
            } 
            else {
                ans[i]=1;break;
            }
        }
    }
}
void div(int u){
    vis[u]=1;solve(u);
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(vis[v])continue;
        int rt=0,tmp=K;
        __dfs(v,u);
        dfs(v,u,sz[v],rt,tmp);
        div(rt);
    }
}
int main(){
    memset(head,-1,sizeof head);cnt=-1;
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;++i) scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
    for(int i=1;i<=m;++i){
        scanf("%d",&query[i]);
        if(!query[i]) ans[i]=1;
    }
    int rt=0,tmp=K;
    __dfs(1,0);
    dfs(1,0,sz[1],rt,tmp);
    div(rt);
    for(int i=1;i<=m;++i){
        if(ans[i]) printf("AYE\n");
        else printf("NAY\n");
    }
    return 0;
} 

/*
7 1
1 6 13 
6 3 9 
3 5 7 
4 1 3 
2 4 20 
4 7 2 
10

6 4
1 2 5
1 3 7
1 4 1
3 5 2
3 6 3
0
8
13
14


11 1
6 7 1
6 8 1
7 9 1
7 10 1
8 11 1
1 2 1
1 3 1
2 4 1
2 5 1
3 6 1
2
*/
...然而,容斥真的毫无用武之地吗?

P4178

这道题只有一个询问,然而却变成\(\leq k\)地点对数量。这时我们就可以用容斥。当然,双指针是必需的。

对于每次分治计算时,还是将\(a[i]\)按照\(d[a[i]]\)排序.然后,双指针\(l=1,r=tot\)从两头向中间扫。每次满足\(d[a[l]]+d[a[r]]<=k\)时就\(ans+=r-l\).而这些重复的点对用容斥即可解决。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=4e4+10,K=2e4+10,INF=1e8+10;
struct edge{
    int v,w,nxt;
}e[N<<1];
int head[N],sz[N],vis[N],maxp[N],a[N],b[N],d[N];
int cnt,n,u,v,w,tot,ans,k;
bool cmp(int x,int y){
    return d[x]<d[y];
}
void add(int u,int v,int w){
    e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
void get_size(int u,int p){
    sz[u]=1;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        get_size(v,u);
        sz[u]+=sz[v];
    }
}
void get_root(int u,int p,int S,int &rt,int &tmp){
    maxp[u]=0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        get_root(v,u,S,rt,tmp);
        maxp[u]=max(maxp[u],sz[v]);
    }
    maxp[u]=max(maxp[u],S-sz[u]);
    if(maxp[u]<tmp) tmp=maxp[u],rt=u;
}
void get_dis(int u,int p,int dis){
    a[++tot]=u;
    d[u]=dis;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        get_dis(v,u,dis+w);
    }
}
int calc(int u,int len){
    tot=0;get_dis(u,0,len);
    sort(a+1,a+tot+1,cmp);
    int l=1,r=tot,res=0;
    while(l<r){
        while(d[a[l]]+d[a[r]]>k && l<r) --r;
        res+=r-l;
        ++l;
    }
    return res;
}
void div(int u){
    vis[u]=1;
    ans+=calc(u,0);
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(vis[v]) continue;
        ans-=calc(v,w);
        int rt=0,tmp=INF;
        get_size(v,u);
        get_root(v,u,sz[v],rt,tmp);
        div(rt);
    }
}
int main(){
    memset(head,-1,sizeof head);cnt=-1;
    scanf("%d",&n);
    for(int i=1;i<n;++i) scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
    scanf("%d",&k);ans=0;

    int rt=0,tmp=INF;
    get_size(1,0);
    get_root(1,0,sz[1],rt,tmp);
    div(rt);

    printf("%d",ans);
    return 0;
}

P2634

树形\(dp\)可能会更简单一些,不过点分治也可做。

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=2e4+10,INF=1e8+10;
struct edge{
    int v,w,nxt;
}e[N<<1];
int cnt,ans,n,u,v,w;
int gcd(int a,int b){
    if(a<b) swap(a,b);
    return !b?a:gcd(b,a%b);
}
int head[N],vis[N],sz[N],maxp[N],f[N],d[N];
void add(int u,int v,int w){
    e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}

void get_size(int u,int p){
    sz[u]=1;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        get_size(v,u);
        sz[u]+=sz[v];
    }
}
void get_root(int u,int p,int S,int &rt,int &tmp){
    maxp[u]=0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        get_root(v,u,S,rt,tmp);
        maxp[u]=max(maxp[u],sz[v]);
    }
    maxp[u]=max(maxp[u],S-sz[u]);
    if(maxp[u]<tmp) tmp=maxp[u],rt=u;
}
void get_dis(int u,int p,int dis){
    d[u]=dis%3;
    ++f[d[u]];
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        get_dis(v,u,(dis+w)%3);
    }
}
int calc(int u,int len){
    f[0]=f[1]=f[2]=0;
    get_dis(u,0,len);
    return f[0]*f[0]+f[1]*f[2]*2;
}
void div(int u){
    vis[u]=1;ans+=calc(u,0);
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(vis[v]) continue;
        ans-=calc(v,w);
        int rt=0,tmp=INF;
        get_size(v,u);
        get_root(v,u,sz[v],rt,tmp);
        div(rt);
    }
}
int main(){
    memset(head,-1,sizeof head);cnt=-1;
    scanf("%d",&n);
    for(int i=1;i<n;++i) scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
    int rt=0,tmp=INF;
    get_size(1,0);
    get_root(1,0,sz[1],rt,tmp);
    div(rt);
    int g=gcd(ans,n*n);
    printf("%d/%d",ans/g,n*n/g);
    return 0;
}

P4149

点分治的裸题,与P3806相似,也是用双指针维护一下即可。

#include<iostream>
#include<cstdio> 
#include<cstring>
#include<algorithm>
#define int long long 
using namespace std;
const int N=2e5+10,INF=1e14;

int n,k,u,v,rt,tmp,minn,tot,cnt,w;
int head[N],sz[N],vis[N],maxp[N],d[N],a[N],b[N],c[N];
struct node{
    int v,w,nxt;
}e[N<<1];
bool cmp(int x,int y){
    return d[x]<d[y];
}
void add(int u,int v,int w){
    e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
void get_size(int u,int p){
    sz[u]=1;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        get_size(v,u);
        sz[u]+=sz[v];
    }
}
void get_dis(int u,int p,int dis,int dep,int P){
    a[++tot]=u;d[u]=dis;b[u]=P;c[u]=dep;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        get_dis(v,u,dis+w,dep+1,P);
    }
}
void get_root(int u,int p,int S,int &rt,int &tmp){
    maxp[u]=0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(v==p || vis[v])continue;
        get_root(v,u,S,rt,tmp);
        maxp[u]=max(maxp[u],sz[v]);
    }
    maxp[u]=max(maxp[u],S-sz[u]);
    if(maxp[u]<tmp) tmp=maxp[u],rt=u;
}
void calc(int u){
    tot=0;a[++tot]=u;d[u]=0;b[u]=u;c[u]=0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(vis[v]) continue;
        get_dis(v,u,w,1,v);
    }
    sort(a+1,a+tot+1,cmp);
    int l=1,r=tot;
    while(l<r){
        if(d[a[l]]+d[a[r]]>k) --r;
        else if(d[a[l]]+d[a[r]]<k) ++l;
        else{
            if(b[a[l]]!=b[a[r]]) minn=min(minn,c[a[l]]+c[a[r]]);
            if(d[a[r]]==d[a[r-1]]) --r;
            else ++l;           
        }
    }
}
void div(int u){
    vis[u]=1;calc(u);
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;
        if(vis[v])continue;
        rt=0,tmp=INF;
        get_size(v,u);
        get_root(v,u,sz[v],rt,tmp);
        div(rt);
    }
    return;
}
signed main(){
    memset(head,-1,sizeof head);cnt=-1;minn=INF;
    scanf("%lld%lld",&n,&k);
    for(int i=1;i<n;++i) scanf("%lld%lld%lld",&u,&v,&w),add(u+1,v+1,w),add(v+1,u+1,w);
    rt=0,tmp=INF;
    get_size(1,0);
    get_root(1,0,sz[1],rt,tmp);
    div(rt);

    if(minn==INF) printf("-1");
    else printf("%lld",minn);
    return 0;
}

~~然后你会发现,WA#7~~

看了讨论区才明白,是排序出了问题,改成return (d[x]==d[y])?c[x]<c[y]:d[x]<d[y]就行了。

考虑一个中间状态的\(hack\)数据。

d: 3 3 ... 5 5
c: 6 4 ... 7 3
那么我们会搜到的是\(6+3,6+7,4+7\),却没有搜到\(4+3\),所以出错了。

说白了,在\(d[i]\)相等的情况下,指针不会回退,此时必须要将\(c[i]\)排序。