Skip to content

树的直径

有一个重要性质:到树上的点的距离最大的点一定是直径的两个端点之一。

P1099 & 2491

对于每个直径上的路径来说,两个端点的最远距离一定是到对应的直径端点的距离(注意是对应的直径端点,不能经过路径)。而这个最大值不一定就是最终答案,因为路径上的其他点也有贡献。

\(f[i]\)表示不经过直径的,到\(i\)号点最远的距离。

那么路径上的点就是这些点的\(f[i]\)

又因为每次将一个点归并为路径后,它的答案只会由到直径变为\(f[i]\),即不会再变大,所以可以证明路径长度越长越好,即刚好小于等于\(s\)

这样,就可以用尺取法+单调队列,每次枚举一个左端点\(i\),计算出长度\(\leq s\)的右端点\(j\),通过单调队列求出区间\(f[i]\)的最大值,与两端点的直径距离比较即可。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#define int long long 
using namespace std;
const int N=4e5+10,INF=1e13;
int n,s,u,v,w,cnt,top,maxn,S,T,D,tmp;
int head[N],dis[N],d[N],stk[N],ans[N],fa[N],f[N],vis[N];
struct edge{
    int v,w,nxt;
}e[N<<1];
struct node{
    int num,dis;
    node(int Num,int Dis){
        num=Num,dis=Dis;
    }
};
void add(int u,int v,int w){
    e[++cnt].v=v,e[cnt].w=w,e[cnt].nxt=head[u],head[u]=cnt;
}
void dfs(int u,int p,int &a){
    if(dis[u]>maxn)maxn=dis[u],a=u;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;if(v==p)continue;
        dis[v]=dis[u]+w;
        dfs(v,u,a);
    }
}
void _dfs(int u,int p){
    stk[++top]=u;
    if(u==T){
        ans[0]=top;
        for(int i=1;i<=top;++i)ans[i]=stk[i];
        return; 
    }
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;if(v==p)continue;
        fa[v]=u;
        _dfs(v,u);
    }
    --top;
}
void __dfs(int u,int p){
    maxn=max(maxn,dis[u]);
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v,w=e[i].w;if(v==p || vis[v])continue;
        dis[v]=dis[u]+w;
        __dfs(v,u);
    }
}
void init(){
    for(int i=T;i;i=fa[i])vis[i]=1;
    for(int i=T;i;i=fa[i]){// get every node on diameter T~S
        maxn=0;dis[i]=0;__dfs(i,0);f[i]=maxn;
    }
}
inline int Abs(int x){return x>=0?x:-x;}
void two_pointers(){
    deque<node> q;
    int i=T,j=T,lst=T,ans=INF,tmp=0;
    for(i=T;i;i=fa[i]){
        while(fa[j] && Abs(d[i]-d[fa[j]])<=s)j=fa[j];   
        for(int k=fa[lst];k!=fa[j];k=fa[k]){
            while(!q.empty() && q.back().dis<f[k]) q.pop_back();
            q.push_back(node(k,f[k]));
        }
        tmp=0;
        if(!q.empty())tmp=q.front().dis;
        tmp=max(tmp,max(Abs(d[T]-d[i]),Abs(d[j]-d[S])));
        ans=min(ans,tmp);

        if(!q.empty() && q.front().num==i)q.pop_front();//注意一定是只有num==i时才需要弹出左端点。
        lst=j;
    }
    printf("%lld",ans);
}
signed main(){
    memset(head,-1,sizeof head);cnt=-1;
    scanf("%lld%lld",&n,&s);
    for(int i=1;i<n;++i){
        scanf("%lld%lld%lld",&u,&v,&w);
        add(u,v,w),add(v,u,w);
    }
    maxn=0,S=0;
    dis[1]=0;dfs(1,0,S);
    maxn=0,T=0;
    dis[S]=0;dfs(S,0,T);

    D=maxn;
    top=0;_dfs(S,0);
    for(int i=1;i<=ans[0];++i)d[ans[i]]=dis[ans[i]];
    init();
    two_pointers();
    return 0;
} 
~~STL居然没爆炸!?看来线性复杂度的题做起来还是STL比较香~~