虚树
模板
P6572void insert(int u) { int lca=getlca(u,stk[top]); while(top>1 && id[stk[top-1]]>=id[lca]) add(stk[top],stk[top-1],1),add(stk[top-1],stk[top],1),--top;//dep[stk[top-1]]>=dep[lca]也行 if(lca!=stk[top]) add(stk[top],lca,1),add(lca,stk[top],1),stk[top]=lca; stk[++top]=u; } void build() { cnt[1]=-1; sort(a+1,a+m+1,cmp); stk[top=1]=a[1]; for(int i=2; i<=m; ++i) insert(a[i]); while(top>1) add(stk[top],stk[top-1],1),add(stk[top-1],stk[top],1),--top; rt=stk[1]; }
树上差分+虚树
对于每个询问,建出虚树后,首先要找出深度最大的公共根,防止将不用的答案记在1(整棵树的根)上,导致出错
hack 数据便是第一个样例
6 3 2
1 3
2 3
3 4
6 4
4 5
4 1 3 2 5
2 6 3
2 3 2
方法:
void getroot(int u,int p){
if(vis[u])siz[u]=1;
else siz[u]=0;
for(int i=head[u];~i;i=e[i].nxt){
int v=e[i].v;
if(i>cnt)continue;
if(v==p)continue;
getroot(v,u);
siz[u]+=siz[v];
}
if(root==1 && siz[u]>=s)root=u;//这一步的解释:如果一个点涵盖所有关键点且不是根,则它一定是最小的根(因为更新答案是在回溯时,所以按照从下往上的顺序)
return;
}
P3233
~~关于虚树模板写假了调了一上午这件事~~
这道题及其~~恶心~~复杂,需要\(6\)次\(dfs\)。
- 预处理\(lca,dep,size\)
倍增处理\(fa\),因为后面除了求\(lca\)还会用到。
- 求出虚树中每个虚树节点子树中距离此节点最近的"关键节点"
所有关键节点都在虚树上,所以第一遍从下而上\(dfs\),可以求出每个节点在它子树中距离它最近的关键节点(可以是自己)。
- 求出虚树中距离每个虚树节点最近的"关键节点"
第二遍从上而下求出子树外贡献。
注意,我们维护的是一个二元组\(g[u]\),第一关键字为距离,第二关键字为编号。
- 求出虚树上每个节点距离最近的点编号,以及节点上的贡献。
最近节点就是\(num[u]=g[u].second\),而节点上的贡献是指:对于虚树节点\(u\),其所有没有关键点的子树,都将归\(num[u]\)管理。
所以,我们遍历虚树时,可以用\(u\)所在的子树总结点数减去有关键点的子树大小。
有关键点的子树大小需要在原树上倍增求出对于\(v\)所在子树的根\(up[v]\),这时\(up[v]\)在\(u\)下面,减掉\(sz[up[v]]\)即可。
- 求出虚树上每条边的贡献
对于每条边两端点\(u,v\),有两种情况:
一是\(num[u]=num[v]\),这种情况这条边上所有节点都归\(num[u]\)管即可。
二是\(num[u]\not =num[v]\),这种情况先求出中间点的深度,再从\(v\)倍增找出中间点标号即可。
具体地,从\(num[u]\)到\(num[v]\)的链长度\(L\)=\(dep[num[v]]-dep[u]+g[u].first\).
那么中间点的深度=\(\frac{dep[num[v]]-L} 2\).
- 清空虚树及点标记
需要清空\(vis[i],ans[i],up[i],num[i]\).
至此,这道题才算做完。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<utility>
using namespace std;
typedef pair<int,int> PI;
const int N=3e5+10,K=20,INF=0x3f3f3f3f;
int fa[N][K],dep[N],head[2][N],vis[N],a[N],b[N],sz[N],stk[N],ans[N],id[N],num[N],up[N];
int n,q,u,v,cnt[2],tmp,m,top,tot,rt;
PI g[N];
struct edge {
int v,nxt;
} e[2][N<<1];
void add(int u,int v,int i) {
e[i][++cnt[i]].v=v,e[i][cnt[i]].nxt=head[i][u],head[i][u]=cnt[i];
}
bool cmp(int a,int b) {
return id[a]<id[b];
}
int read1(){
int x=0;char ch=getchar();
while(ch<'0' || ch>'9') ch=getchar();
while(ch>='0' && ch<='9') x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return x;
}
void write1(int x){
if(x>9) write1(x/10);
putchar(x%10+'0');
}
void dfs0(int u,int p) {
id[u]=++tot;fa[u][0]=p;sz[u]=1;
for(int i=1; i<K; ++i) fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=head[0][u]; ~i; i=e[0][i].nxt) {
int v=e[0][i].v;
if(v==p) continue;
dep[v]=dep[u]+1;
dfs0(v,u);
sz[u]+=sz[v];
}
}
void dfs1(int u,int p) {
if(vis[u]) g[u]=make_pair(0,u);
else g[u]=make_pair(INF,0);
for(int i=head[1][u]; ~i; i=e[1][i].nxt) {
int v=e[1][i].v;
if(v==p) continue;
dfs1(v,u);
g[u]=min(g[u],make_pair(g[v].first+dep[v]-dep[u],g[v].second));
}
}
void dfs2(int u,int p,int d,int x) {
PI tmp=make_pair(d,x);
if(tmp<g[u]) g[u]=tmp;
else d=g[u].first,x=g[u].second;
for(int i=head[1][u]; ~i; i=e[1][i].nxt) {
int v=e[1][i].v;
if(v==p) continue;
dfs2(v,u,d+dep[v]-dep[u],x);
}
}
void dfs3(int u,int p) {
num[u]=g[u].second;
ans[num[u]]+=sz[u];
for(int i=head[1][u]; ~i; i=e[1][i].nxt) {
int v=e[1][i].v;
if(v==p) continue;
int k=v;
for(int j=K-1;j>=0;--j) if(fa[k][j] && dep[fa[k][j]]>dep[u]) k=fa[k][j];
ans[num[u]]-=sz[up[v]=k];
dfs3(v,u);
}
}
void dfs4(int u,int p){
for(int i=head[1][u]; ~i; i=e[1][i].nxt) {
int v=e[1][i].v;
if(v==p) continue;
if(num[v]==num[u]) ans[num[u]]+=(sz[up[v]]-sz[v]);
else{
int dis=dep[num[v]]+dep[u]-g[u].first;
dis=dis&1?dis+1>>1:(num[v]<num[u]?dis>>1:(dis>>1)+1);
int k=v;
for(int j=K-1;j>=0;--j) if(fa[k][j] && dep[fa[k][j]]>=dis) k=fa[k][j];
ans[num[u]]+=sz[up[v]]-sz[k];
ans[num[v]]+=sz[k]-sz[v];
}
dfs4(v,u);
}
}
void dfs5(int u,int p){//clear
up[u]=num[u]=0;
for(int i=head[1][u]; ~i; i=e[1][i].nxt) {
int v=e[1][i].v;
if(v==p) continue;
dfs5(v,u);
}
head[1][u]=-1;
}
int getlca(int u,int v) {
if(dep[u]<dep[v]) swap(u,v);
int k=dep[u]-dep[v];
for(int i=K-1; i>=0; --i) if(k&(1<<i)) u=fa[u][i];
if(v==u) return u;
for(int i=K-1; i>=0; --i) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
void insert(int u) {
int lca=getlca(u,stk[top]);
while(top>1 && id[stk[top-1]]>=id[lca]) add(stk[top],stk[top-1],1),add(stk[top-1],stk[top],1),--top;
if(lca!=stk[top]) add(stk[top],lca,1),add(lca,stk[top],1),stk[top]=lca;
stk[++top]=u;
}
void build() {
cnt[1]=-1;
sort(a+1,a+m+1,cmp);
stk[top=1]=a[1];
for(int i=2; i<=m; ++i) insert(a[i]);
while(top>1) add(stk[top],stk[top-1],1),add(stk[top-1],stk[top],1),--top;
rt=stk[1];
}
int main() {
memset(head,-1,sizeof head);
cnt[0]=-1;
scanf("%d",&n);
for(int i=1; i<n; ++i) scanf("%d%d",&u,&v),add(u,v,0),add(v,u,0);
dfs0(1,0);
scanf("%d",&q);
for(int i=1; i<=q; ++i) {
scanf("%d",&m);
for(int j=1; j<=m; ++j) scanf("%d",&a[j]),b[j]=a[j],vis[b[j]]=1;
build();
dfs1(rt,0);
dfs2(rt,0,g[rt].first,g[rt].second);
dfs3(rt,0);
dfs4(rt,0);
ans[num[rt]]+=sz[1]-sz[rt];
for(int j=1; j<=m; ++j) printf("%d ",ans[b[j]]),ans[b[j]]=vis[b[j]]=0;printf("\n");
dfs5(rt,0);
}
return 0;
}