Skip to content

虚树

模板

void insert(int u) {
    int lca=getlca(u,stk[top]);
    while(top>1 && id[stk[top-1]]>=id[lca]) add(stk[top],stk[top-1],1),add(stk[top-1],stk[top],1),--top;//dep[stk[top-1]]>=dep[lca]也行
    if(lca!=stk[top]) add(stk[top],lca,1),add(lca,stk[top],1),stk[top]=lca;
    stk[++top]=u;
}
void build() {
    cnt[1]=-1;
    sort(a+1,a+m+1,cmp);
    stk[top=1]=a[1];
    for(int i=2; i<=m; ++i) insert(a[i]);
    while(top>1) add(stk[top],stk[top-1],1),add(stk[top-1],stk[top],1),--top;
    rt=stk[1];
}
P6572

树上差分+虚树

对于每个询问,建出虚树后,首先要找出深度最大的公共根,防止将不用的答案记在1(整棵树的根)上,导致出错

hack 数据便是第一个样例

6 3 2
1 3
2 3
3 4
6 4
4 5
4 1 3 2 5
2 6 3
2 3 2

方法:

void getroot(int u,int p){
    if(vis[u])siz[u]=1;
    else siz[u]=0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(i>cnt)continue;
        if(v==p)continue;
        getroot(v,u);
        siz[u]+=siz[v];
    }
    if(root==1 && siz[u]>=s)root=u;//这一步的解释:如果一个点涵盖所有关键点且不是根,则它一定是最小的根(因为更新答案是在回溯时,所以按照从下往上的顺序)
    return;
}
之后每次深搜到一个点,就向差分数组记录一下,子节点与当前点之间所有的边+1,做完一组询问清零,最后差分数组加起来就是答案

P3233

~~关于虚树模板写假了调了一上午这件事~~

这道题及其~~恶心~~复杂,需要\(6\)\(dfs\)

  1. 预处理\(lca,dep,size\)

倍增处理\(fa\),因为后面除了求\(lca\)还会用到。

  1. 求出虚树中每个虚树节点子树中距离此节点最近的"关键节点"

所有关键节点都在虚树上,所以第一遍从下而上\(dfs\),可以求出每个节点在它子树中距离它最近的关键节点(可以是自己)。

  1. 求出虚树中距离每个虚树节点最近的"关键节点"

第二遍从上而下求出子树外贡献。

注意,我们维护的是一个二元组\(g[u]\),第一关键字为距离,第二关键字为编号。

  1. 求出虚树上每个节点距离最近的点编号,以及节点上的贡献。

最近节点就是\(num[u]=g[u].second\),而节点上的贡献是指:对于虚树节点\(u\),其所有没有关键点的子树,都将归\(num[u]\)管理。

所以,我们遍历虚树时,可以用\(u\)所在的子树总结点数减去有关键点的子树大小。

有关键点的子树大小需要在原树上倍增求出对于\(v\)所在子树的根\(up[v]\),这时\(up[v]\)\(u\)下面,减掉\(sz[up[v]]\)即可。

  1. 求出虚树上每条边的贡献

对于每条边两端点\(u,v\),有两种情况:

一是\(num[u]=num[v]\),这种情况这条边上所有节点都归\(num[u]\)管即可。

二是\(num[u]\not =num[v]\),这种情况先求出中间点的深度,再从\(v\)倍增找出中间点标号即可。

具体地,从\(num[u]\)\(num[v]\)的链长度\(L\)=\(dep[num[v]]-dep[u]+g[u].first\).

那么中间点的深度=\(\frac{dep[num[v]]-L} 2\).

  1. 清空虚树及点标记

需要清空\(vis[i],ans[i],up[i],num[i]\).

至此,这道题才算做完。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
using namespace std;
typedef pair<int,int> PI;
const int N=3e5+10,K=20,INF=0x3f3f3f3f;
int fa[N][K],dep[N],head[2][N],vis[N],a[N],b[N],sz[N],stk[N],ans[N],id[N],num[N],up[N];
int n,q,u,v,cnt[2],tmp,m,top,tot,rt;
PI g[N];
struct edge {
    int v,nxt;
} e[2][N<<1];
void add(int u,int v,int i) {
    e[i][++cnt[i]].v=v,e[i][cnt[i]].nxt=head[i][u],head[i][u]=cnt[i];
}
bool cmp(int a,int b) {
    return id[a]<id[b];
}
int read1(){
    int x=0;char ch=getchar();
    while(ch<'0' || ch>'9') ch=getchar();
    while(ch>='0' && ch<='9') x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
    return x;
}
void write1(int x){
    if(x>9) write1(x/10);
    putchar(x%10+'0');  
}
void dfs0(int u,int p) {
    id[u]=++tot;fa[u][0]=p;sz[u]=1;
    for(int i=1; i<K; ++i) fa[u][i]=fa[fa[u][i-1]][i-1];
    for(int i=head[0][u]; ~i; i=e[0][i].nxt) {
        int v=e[0][i].v;
        if(v==p) continue;
        dep[v]=dep[u]+1;
        dfs0(v,u);
        sz[u]+=sz[v];
    }
}
void dfs1(int u,int p) {
    if(vis[u]) g[u]=make_pair(0,u);
    else g[u]=make_pair(INF,0);
    for(int i=head[1][u]; ~i; i=e[1][i].nxt) {
        int v=e[1][i].v;
        if(v==p) continue;
        dfs1(v,u);
        g[u]=min(g[u],make_pair(g[v].first+dep[v]-dep[u],g[v].second));
    }
}
void dfs2(int u,int p,int d,int x) {
    PI tmp=make_pair(d,x);
    if(tmp<g[u]) g[u]=tmp;
    else d=g[u].first,x=g[u].second;
    for(int i=head[1][u]; ~i; i=e[1][i].nxt) {
        int v=e[1][i].v;
        if(v==p) continue;
        dfs2(v,u,d+dep[v]-dep[u],x);
    }
}
void dfs3(int u,int p) {
    num[u]=g[u].second;
    ans[num[u]]+=sz[u];
    for(int i=head[1][u]; ~i; i=e[1][i].nxt) {
        int v=e[1][i].v;
        if(v==p) continue;
        int k=v;
        for(int j=K-1;j>=0;--j) if(fa[k][j] && dep[fa[k][j]]>dep[u]) k=fa[k][j];
        ans[num[u]]-=sz[up[v]=k];
        dfs3(v,u);
    }
}
void dfs4(int u,int p){
    for(int i=head[1][u]; ~i; i=e[1][i].nxt) {
        int v=e[1][i].v;
        if(v==p) continue;
        if(num[v]==num[u]) ans[num[u]]+=(sz[up[v]]-sz[v]);
        else{
            int dis=dep[num[v]]+dep[u]-g[u].first;
            dis=dis&1?dis+1>>1:(num[v]<num[u]?dis>>1:(dis>>1)+1);
            int k=v;
            for(int j=K-1;j>=0;--j) if(fa[k][j] && dep[fa[k][j]]>=dis) k=fa[k][j];
            ans[num[u]]+=sz[up[v]]-sz[k];
            ans[num[v]]+=sz[k]-sz[v];
        }
        dfs4(v,u);
    }
}
void dfs5(int u,int p){//clear
    up[u]=num[u]=0;
    for(int i=head[1][u]; ~i; i=e[1][i].nxt) {
        int v=e[1][i].v;
        if(v==p) continue;
        dfs5(v,u);
    }
    head[1][u]=-1;
}
int getlca(int u,int v) {
    if(dep[u]<dep[v]) swap(u,v);
    int k=dep[u]-dep[v];
    for(int i=K-1; i>=0; --i) if(k&(1<<i)) u=fa[u][i];
    if(v==u) return u;
    for(int i=K-1; i>=0; --i) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
    return fa[u][0];
}
void insert(int u) {
    int lca=getlca(u,stk[top]);
    while(top>1 && id[stk[top-1]]>=id[lca]) add(stk[top],stk[top-1],1),add(stk[top-1],stk[top],1),--top;
    if(lca!=stk[top]) add(stk[top],lca,1),add(lca,stk[top],1),stk[top]=lca;
    stk[++top]=u;
}
void build() {
    cnt[1]=-1;
    sort(a+1,a+m+1,cmp);
    stk[top=1]=a[1];
    for(int i=2; i<=m; ++i) insert(a[i]);
    while(top>1) add(stk[top],stk[top-1],1),add(stk[top-1],stk[top],1),--top;
    rt=stk[1];
}
int main() {
    memset(head,-1,sizeof head);
    cnt[0]=-1;
    scanf("%d",&n);
    for(int i=1; i<n; ++i) scanf("%d%d",&u,&v),add(u,v,0),add(v,u,0);
    dfs0(1,0);
    scanf("%d",&q);
    for(int i=1; i<=q; ++i) {
        scanf("%d",&m);
        for(int j=1; j<=m; ++j) scanf("%d",&a[j]),b[j]=a[j],vis[b[j]]=1;
        build();
        dfs1(rt,0);
        dfs2(rt,0,g[rt].first,g[rt].second);
        dfs3(rt,0);
        dfs4(rt,0);
        ans[num[rt]]+=sz[1]-sz[rt];
        for(int j=1; j<=m; ++j) printf("%d ",ans[b[j]]),ans[b[j]]=vis[b[j]]=0;printf("\n");
        dfs5(rt,0);
    }
    return 0;
}