点分治
主要思想
因为分治后,我们希望层数尽量小,也就是选择所有子树尽量平衡的点。所以,我们贪心的取重心。可以证明,因为重心的子树不超过\(\frac n2\),所以层数不超过\(O(\log_2n)\)
那么对于每一层的每个重心,我们需要求出当前重心所在子树的所有距离,可以用桶记录;而统计点对时,这里面包括了两个点在同一颗子树的情况(不是一条简单路径),那么我们运用容斥原理,减去分裂出的子树中的贡献即可。
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=2e4+10,K=2e6+10;
int ans[K],dis[N],head[N],sz[N],vis[N],maxp[N],d[N];
int n,m,cnt,u,v,w,k,tot;
struct node{
int v,w,nxt;
}e[N<<1];
void add(int u,int v,int w){
e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
void __dfs(int u,int p){
sz[u]=1;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
__dfs(v,u);
sz[u]+=sz[v];
}
}
void dfs(int u,int p,int S,int &rt,int &tmp){
maxp[u]=0;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
dfs(v,u,S,rt,tmp);
maxp[u]=max(maxp[u],sz[v]);
}
maxp[u]=max(maxp[u],S-sz[u]);
if(maxp[u]<tmp) tmp=maxp[u],rt=u;
}
void _dfs(int u,int p){
d[++tot]=dis[u];
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
dis[v]=dis[u]+w;
_dfs(v,u);
}
}
void solve(int u,int len,int t){
dis[u]=len;tot=0;
_dfs(u,0);
for(int i=1;i<=tot;++i)
for(int j=1;j<=tot;++j)
if(i!=j && d[i]+d[j]<K) ans[d[i]+d[j]]+=t;
}
void div(int u){
solve(u,0,1);vis[u]=1;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(vis[v])continue;
solve(v,w,-1);
int rt=0,tmp=K;
__dfs(v,u);
dfs(v,u,sz[u],rt,tmp);
div(rt);
}
}
int main(){
memset(head,-1,sizeof head);cnt=-1;
scanf("%d%d",&n,&m);
for(int i=1;i<n;++i) scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
int rt=0,tmp=K;
__dfs(1,0);
dfs(1,0,n,rt,tmp);
div(rt);
for(int i=1;i<=m;++i){
scanf("%d",&k);
if(ans[k]>0 || k==0) printf("Yes\n");
else printf("No\n");
}
//for(int i=0;i<=10;++i) cout<<ans[i]<<" ";
return 0;
}
~~然后你会发现,bzoj1316:Accepted,luoguP3806:Time Limit Exceeded~~
下面是常见的问题.
1) 有关错误的求重心对复杂度的影响
每次求中心之前一定要先\(dfs\)一遍整颗子树,求出大小后再用这个节点数取更新重心。不然可能会出现这样的情况:
//对于这样一颗树
11//n
6 7 1//u v w
6 8 1
7 9 1
7 10 1
8 11 1
1 2 1
1 3 1
2 4 1
2 5 1
3 6 1
2) 有关常规做法错误的复杂度
本来分治的层数是严格小于\(O(\log n)\)这个上界的,每一层最多遍历一遍所有节点,所以复杂度\(O(n\log n)\).但是,朴素的统计有个\(O(n^2)\)的统计,对于菊花图这样子节点巨多的数据一定会T.所以,我们可以采用双指针的方式统计。
先\(dfs\)一边求出当前子树中节点深度\(d[i]\),所有节点标号\(a[i]\),以及节点处在哪个子树中\(b[i]\)。
对于深度从小到大排序(不要忘记加入重心\(u\)),然后对于不同的询问,\(l=1,r=tot\)地扫一遍,注意要满足\(b[a[l]]\not=b[a[r]]\),去掉了容斥。分治+排序复杂度\(O(n\log^2n)\),分治+双指针复杂度\(O(nm\log n)\),总\(O(n\log^2n + nm\log n)\)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e4+10,K=1e7+10;
int query[N],head[N],sz[N],vis[N],maxp[N],d[N],a[N],b[N];
int n,m,cnt,u,v,w,k,tot;
bool ans[K];
bool cmp(int x,int y){
return d[x]<d[y];
}
struct node{
int v,w,nxt;
}e[N<<1];
void add(int u,int v,int w){
e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
void __dfs(int u,int p){
sz[u]=1;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
__dfs(v,u);
sz[u]+=sz[v];
}
}
void dfs(int u,int p,int S,int &rt,int &tmp){
maxp[u]=0;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
dfs(v,u,S,rt,tmp);
maxp[u]=max(maxp[u],sz[v]);
}
maxp[u]=max(maxp[u],S-sz[u]);
if(maxp[u]<tmp) tmp=maxp[u],rt=u;
}
void _dfs(int u,int p,int dis,int P){
a[++tot]=u;d[u]=dis;b[u]=P;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
_dfs(v,u,dis+w,P);
}
}
void solve(int u){
tot=0;
a[++tot]=u;
b[u]=u;d[u]=0;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(vis[v])continue;
_dfs(v,u,w,v);
}
sort(a+1,a+tot+1,cmp);
for(int i=1;i<=m;++i){
int l=1,r=tot;
if(ans[i]) continue;
while(l<r){
//cout<<u<<" "<<i<<": "<<a[l]<<" "<<a[r]<<" "<<d[a[l]]<<" "<<d[a[r]]<<endl;
if(d[a[l]]+d[a[r]]>query[i]) --r;
else if(d[a[l]]+d[a[r]]<query[i]) ++l;
else if(b[a[l]]==b[a[r]]){
if(d[a[r]]==d[a[r-1]]) --r;
else ++l;
}
else {
ans[i]=1;break;
}
}
}
}
void div(int u){
vis[u]=1;solve(u);
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(vis[v])continue;
int rt=0,tmp=K;
__dfs(v,u);
dfs(v,u,sz[v],rt,tmp);
div(rt);
}
}
int main(){
memset(head,-1,sizeof head);cnt=-1;
scanf("%d%d",&n,&m);
for(int i=1;i<n;++i) scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
for(int i=1;i<=m;++i){
scanf("%d",&query[i]);
if(!query[i]) ans[i]=1;
}
int rt=0,tmp=K;
__dfs(1,0);
dfs(1,0,sz[1],rt,tmp);
div(rt);
for(int i=1;i<=m;++i){
if(ans[i]) printf("AYE\n");
else printf("NAY\n");
}
return 0;
}
/*
7 1
1 6 13
6 3 9
3 5 7
4 1 3
2 4 20
4 7 2
10
6 4
1 2 5
1 3 7
1 4 1
3 5 2
3 6 3
0
8
13
14
11 1
6 7 1
6 8 1
7 9 1
7 10 1
8 11 1
1 2 1
1 3 1
2 4 1
2 5 1
3 6 1
2
*/
P4178
这道题只有一个询问,然而却变成\(\leq k\)地点对数量。这时我们就可以用容斥。当然,双指针是必需的。
对于每次分治计算时,还是将\(a[i]\)按照\(d[a[i]]\)排序.然后,双指针\(l=1,r=tot\)从两头向中间扫。每次满足\(d[a[l]]+d[a[r]]<=k\)时就\(ans+=r-l\).而这些重复的点对用容斥即可解决。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=4e4+10,K=2e4+10,INF=1e8+10;
struct edge{
int v,w,nxt;
}e[N<<1];
int head[N],sz[N],vis[N],maxp[N],a[N],b[N],d[N];
int cnt,n,u,v,w,tot,ans,k;
bool cmp(int x,int y){
return d[x]<d[y];
}
void add(int u,int v,int w){
e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
void get_size(int u,int p){
sz[u]=1;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
get_size(v,u);
sz[u]+=sz[v];
}
}
void get_root(int u,int p,int S,int &rt,int &tmp){
maxp[u]=0;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
get_root(v,u,S,rt,tmp);
maxp[u]=max(maxp[u],sz[v]);
}
maxp[u]=max(maxp[u],S-sz[u]);
if(maxp[u]<tmp) tmp=maxp[u],rt=u;
}
void get_dis(int u,int p,int dis){
a[++tot]=u;
d[u]=dis;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
get_dis(v,u,dis+w);
}
}
int calc(int u,int len){
tot=0;get_dis(u,0,len);
sort(a+1,a+tot+1,cmp);
int l=1,r=tot,res=0;
while(l<r){
while(d[a[l]]+d[a[r]]>k && l<r) --r;
res+=r-l;
++l;
}
return res;
}
void div(int u){
vis[u]=1;
ans+=calc(u,0);
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(vis[v]) continue;
ans-=calc(v,w);
int rt=0,tmp=INF;
get_size(v,u);
get_root(v,u,sz[v],rt,tmp);
div(rt);
}
}
int main(){
memset(head,-1,sizeof head);cnt=-1;
scanf("%d",&n);
for(int i=1;i<n;++i) scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
scanf("%d",&k);ans=0;
int rt=0,tmp=INF;
get_size(1,0);
get_root(1,0,sz[1],rt,tmp);
div(rt);
printf("%d",ans);
return 0;
}
P2634
树形\(dp\)可能会更简单一些,不过点分治也可做。
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=2e4+10,INF=1e8+10;
struct edge{
int v,w,nxt;
}e[N<<1];
int cnt,ans,n,u,v,w;
int gcd(int a,int b){
if(a<b) swap(a,b);
return !b?a:gcd(b,a%b);
}
int head[N],vis[N],sz[N],maxp[N],f[N],d[N];
void add(int u,int v,int w){
e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
void get_size(int u,int p){
sz[u]=1;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
get_size(v,u);
sz[u]+=sz[v];
}
}
void get_root(int u,int p,int S,int &rt,int &tmp){
maxp[u]=0;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
get_root(v,u,S,rt,tmp);
maxp[u]=max(maxp[u],sz[v]);
}
maxp[u]=max(maxp[u],S-sz[u]);
if(maxp[u]<tmp) tmp=maxp[u],rt=u;
}
void get_dis(int u,int p,int dis){
d[u]=dis%3;
++f[d[u]];
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
get_dis(v,u,(dis+w)%3);
}
}
int calc(int u,int len){
f[0]=f[1]=f[2]=0;
get_dis(u,0,len);
return f[0]*f[0]+f[1]*f[2]*2;
}
void div(int u){
vis[u]=1;ans+=calc(u,0);
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(vis[v]) continue;
ans-=calc(v,w);
int rt=0,tmp=INF;
get_size(v,u);
get_root(v,u,sz[v],rt,tmp);
div(rt);
}
}
int main(){
memset(head,-1,sizeof head);cnt=-1;
scanf("%d",&n);
for(int i=1;i<n;++i) scanf("%d%d%d",&u,&v,&w),add(u,v,w),add(v,u,w);
int rt=0,tmp=INF;
get_size(1,0);
get_root(1,0,sz[1],rt,tmp);
div(rt);
int g=gcd(ans,n*n);
printf("%d/%d",ans/g,n*n/g);
return 0;
}
P4149
点分治的裸题,与P3806相似,也是用双指针维护一下即可。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define int long long
using namespace std;
const int N=2e5+10,INF=1e14;
int n,k,u,v,rt,tmp,minn,tot,cnt,w;
int head[N],sz[N],vis[N],maxp[N],d[N],a[N],b[N],c[N];
struct node{
int v,w,nxt;
}e[N<<1];
bool cmp(int x,int y){
return d[x]<d[y];
}
void add(int u,int v,int w){
e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
void get_size(int u,int p){
sz[u]=1;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
get_size(v,u);
sz[u]+=sz[v];
}
}
void get_dis(int u,int p,int dis,int dep,int P){
a[++tot]=u;d[u]=dis;b[u]=P;c[u]=dep;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
get_dis(v,u,dis+w,dep+1,P);
}
}
void get_root(int u,int p,int S,int &rt,int &tmp){
maxp[u]=0;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==p || vis[v])continue;
get_root(v,u,S,rt,tmp);
maxp[u]=max(maxp[u],sz[v]);
}
maxp[u]=max(maxp[u],S-sz[u]);
if(maxp[u]<tmp) tmp=maxp[u],rt=u;
}
void calc(int u){
tot=0;a[++tot]=u;d[u]=0;b[u]=u;c[u]=0;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(vis[v]) continue;
get_dis(v,u,w,1,v);
}
sort(a+1,a+tot+1,cmp);
int l=1,r=tot;
while(l<r){
if(d[a[l]]+d[a[r]]>k) --r;
else if(d[a[l]]+d[a[r]]<k) ++l;
else{
if(b[a[l]]!=b[a[r]]) minn=min(minn,c[a[l]]+c[a[r]]);
if(d[a[r]]==d[a[r-1]]) --r;
else ++l;
}
}
}
void div(int u){
vis[u]=1;calc(u);
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(vis[v])continue;
rt=0,tmp=INF;
get_size(v,u);
get_root(v,u,sz[v],rt,tmp);
div(rt);
}
return;
}
signed main(){
memset(head,-1,sizeof head);cnt=-1;minn=INF;
scanf("%lld%lld",&n,&k);
for(int i=1;i<n;++i) scanf("%lld%lld%lld",&u,&v,&w),add(u+1,v+1,w),add(v+1,u+1,w);
rt=0,tmp=INF;
get_size(1,0);
get_root(1,0,sz[1],rt,tmp);
div(rt);
if(minn==INF) printf("-1");
else printf("%lld",minn);
return 0;
}
~~然后你会发现,WA#7~~
看了讨论区才明白,是排序出了问题,改成return (d[x]==d[y])?c[x]<c[y]:d[x]<d[y]
就行了。
考虑一个中间状态的\(hack\)数据。
d: 3 3 ... 5 5
c: 6 4 ... 7 3
说白了,在\(d[i]\)相等的情况下,指针不会回退,此时必须要将\(c[i]\)排序。