Skip to content

树形dp总结

统计答案类

P3047

dp[u][k]表示距离u不超过k的点权值之和

dfs1:$ dp[u][k]+=\sum dp[v][k-1]$ 下行,不需要容斥,直接由儿子累加即可

求出u的子树所有不超过k的点权值之和

dfs2:\(dp[v][k]+=dp[u][k-1]-(k>=2)dp[v][k-2]\)上行,需要简单容斥,并且注意要逆序,因为若dp[v][k-2]先更改,则dp[v][k]会出错

求出最终答案

P2986

dp[u]表示每个点的答案

f[u]表示该节点所有子节点到它的距离之和

f[u]=f[v]+w * sz[v]

则有dp[1]=f[1]

那么 dp[v]=dp[u]-sz[v] * w+(sz[1] * w-sz[v] * w)

=dp[u]-2 * sz[v] * w+sz[1] * w

转移时只更改的是与该边有关联的答案。

注意初始状态为dp[1]

CF9D

f[i][j]表示i个节点,深度不超过j的二叉树总数量,则

$f[i][j]=\sum_{k=1}^{i-1}f[k][j-1]* f[i-k-1][j-1] $

P1623

dp+高精度

设f[i][0/1]为最大匹配个数,0表示与儿子不匹配,1表示与其中一个儿子匹配;g[i][0/1]为对应的方案数

转移方程:(有'亿'点复杂)

//无高精度
void dp1(int u,int p){
    int sum=0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(v==p)continue;
        dp1(v,u);
        sum+=max(f[v][0],f[v][1]);
    }
    f[u][0]=sum;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(v==p)continue;
        f[u][1]=max(f[u][1],sum-max(f[v][0],f[v][1])+f[v][0]+1);
    }
    return;
}
void dp2(int u,int p){
    int sum=1;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(v==p)continue;
        dp2(v,u);
        if(f[v][0]<f[v][1]){
            sum*=g[v][1];
        }else if(f[v][0]>f[v][1]){
            sum*=g[v][0];
        }else{
            sum*=(g[v][0]+g[v][1]);         
        }
    }
    g[u][0]=sum;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        int tmp=1;
        if(v==p)continue;
        if(f[v][0]<f[v][1]){
            tmp=g[v][1];
        }else if(f[v][0]>f[v][1]){
            tmp=g[v][0];
        }else{
            tmp=(g[v][0]+g[v][1]);          
        }
        if(f[u][1]==f[u][0]-max(f[v][0],f[v][1])+f[v][0]+1){
            g[u][1]+=(g[u][0]/tmp)*g[v][0];
        }
    }
    return;
}
注意最后的g也要根据f判断,这个坑了我一次 ~~(我甚至连高精度都没写,只拿了60pts)~~

节点选择类(染色类)

P4084

树形dp入门题:)

dp[i][j]表示当前节点为j颜色是的方案数

注意当该节点已经有颜色时,则dp[u][c[u]]=1,其他为0 ;否则dp[u][1,2,3]=1;

转移:

//1.c[u]!=0
for(i;1-3)
if(i!=c[u])sum+=dp[v][i];
dp[u][c[u]]=(dp[u][c[u]]*sum)%P;
//2.c[u]==0
for(i;1-3)
for(j;1-3)
if(j!=i)sum+=dp[v][j];
dp[u][i]=(dp[u][i]*sum)%P;

P2279 非常好。非常恶心

dp[u][0,1,2,3,4]一共0,1,2,3,4五种状态,分别为爷爷,父亲,自己,儿子,孙子

同样要注意上行的状态dp[u][3] 和dp[u][4]

P2899

dp[u][0/1/2]表示三种状态,被自己覆盖,被儿子覆盖,被父亲覆盖

注意顺序时从下往上回溯时搜,所以搜到父亲时儿子已经被更新完,可以直接用

1.dp[u][0]=\(\sum\)min(dp[v][0],dp[v][1],dp[v][2]),因为自己已经被自己覆盖,所以可以取儿子的所有值;

2.dp[u][2]=\(\sum\)min(dp[v][0],dp[v][1]),如果u被fa[u]覆盖,则u不会被u自己,也就是dp[v][2]覆盖,就不取了

update:2021.11.12 应该是dp[u][2]=\(\sum\)dp[v][1],因为u只被父亲覆盖,如果取dp[v][0]则表明u还被儿子覆盖,则会出错.

3.< 重点 >dp[u][1]=min(f[v][0]+$\sum_{i=son[u]}^{i!=v} $min(f[i][0],f[i][1])),因为u要被至少一个子节点覆盖,所以可以暴力枚举每一个节点,则剩余的节点的状态与2相同。

也可以进行数学方法的优化,详见题解

P3174

有意思的一道题

定义:

f[n]表示当前节点的子树里最长的链中包含的点(主干与分叉)

ans[n]表示当前节点的答案

因为存在拐弯的情况,所以最优解不一定存在于根节点,因此转移方程:

    f[u]=max1+num;
    if(p==0) {//p是u的父亲,根节点无父亲
        ans[u]=max1+max2+num-1;
    } else ans[u]=max1+max2+num;

最终答案为max(ans[i])

P1352

模板题。 * 要用儿子更新父亲而不用父亲更新儿子的原因:会产生后效性:儿子的答案不仅与父亲有关,还与爷爷有关,爷爷不选,儿子也不能选。

树形背包类

P2015

树形背包模板题,注意因为是二叉树,所以转移方程有两种:

1.dp[u][t]=max(dp[u][t],dp[ln][t-k]+dp[rn][k]),ln rn 表示左儿子右儿子

2.dp[u][t]=max(dp[u][t],dp[v][k]+dp[u][t-k])

2的好处为可以应对多叉树。

原因是:

因为要先枚举t,所以就像背包一样,当前儿子的贡献能够加在之前的状态里,而之前的状态已经包含在深搜它之前的兄弟,以此类推。这样就可以进行累加。

v相当于石头种类;t相当于背包容量

注意要逆序枚举t,不然会对之前已经得到的状态进行二次累加

P1270

树形背包。

1.同样是dp[u][t]=max(dp[u][t],dp[v][k]+dp[u][t-k]);

2.dp[u][t]=max(dp[u][t],dp[u][t-j*5]+j) 在深搜完所有儿子后做这个

同样注意逆序枚举t

  • 结论:不一定要按照时间线性流逝设置状态,也可以将时间累加表示,即累加贡献。

P3360

同上,加入了价值。

P4362

每个树枝只有两种情况会产生难受值:

1) 两端点都被大头吃掉 2) 两端点都不被大头吃掉并且小头只有一个(m==2)

因为如果小头多余一个,那么任意一条满足情况2的树枝,它的两端都可以被不同的小头吃,就不会有贡献。

因此转移方程: \(\(f[u][i][0]=min(f[u][i][0],min(tmp[i-j][0]+f[v][j][0]+(m==2)* w,tmp[i-j][0]+f[v][j][1]));\\ f[u][i][1]=min(f[u][i][1],min(tmp[i-j][1]+f[v][j][1]+w,tmp[i-j][1]+f[v][j][0]));\)\)

注意m==2,就算没有难受值也要把子树的贡献加到背包里。

并且每次dp之前要将f[u]数组备份下来,这样就可以防止f[u]不断变小.

例如:

hydra

无tmp:

v1后:f[u][1][0]=0;

v2后:因为f[u][1][0]=0,所以f[u][1][0]=min(0,0+0+5)=0;

有tmp:

v1后:f[u][1][0]=0;

v2前:

f[u][1][0]=INF; tmp[1][0]=0;

v2后:f[u][1][0]=min(INF,0+0+5)=5;

因为f[u][1][0]在不同时刻表示的是不同状态,如果不备份,就会导致f[u]越来越小,即f[u][1][0]只能与当前时刻的新贡献比较,不能与tmp中之前的状态做min.tmp只能用来更新出当前j不同的状态.

换种说法,f[v2][1][0]表示当前状态,f[v1][1][0]表示之前的状态,则v2以及它之后的状态都会因为f[v1][1][0]=0而越变越小,产生了后效性。

为了不产生后效性,就不能让v2访问到f[v1][1][0],所以需要备份tmp。

#include<iostream>
#include<cstdio>
#include<cstring>

using namespace std;
const int N=530;
int f[N][N][2],head[N],sz[N],tmp[N][2];
int n,m,k,cnt;
struct node {
    int v,w,nxt;
} e[N<<1];
void add(int u,int v,int w) {
    e[++cnt].v=v;
    e[cnt].w=w;
    e[cnt].nxt=head[u];
    head[u]=cnt;
}
void dp(int u,int p) {
    f[u][0][0]=0;
    f[u][1][1]=0;
    sz[u]=1;
    for(int o=head[u]; ~o; o=e[o].nxt) {
        int v=e[o].v,w=e[o].w;
        if(v==p)continue;
        dp(v,u);
        sz[u]+=sz[v];
    }
    for(int o=head[u]; ~o; o=e[o].nxt) {
        int v=e[o].v,w=e[o].w;
        if(v==p)continue;
        memcpy(tmp,f[u],sizeof tmp);
        memset(f[u],63,sizeof f[u]);
        for(int i=min(k,sz[u]); i>=0; --i) {//注意这里取min(k,sz[u])可以快不少
            for(int j=min(i,sz[v]); j>=0; --j) {
                f[u][i][0]=min(f[u][i][0],min(tmp[i-j][0]+f[v][j][0]+(m==2)*w,tmp[i-j][0]+f[v][j][1]));

                f[u][i][1]=min(f[u][i][1],min(tmp[i-j][1]+f[v][j][1]+w,tmp[i-j][1]+f[v][j][0]));
            }
        }
    }
}
int main() {
    memset(head,-1,sizeof head);
    cnt=-1;
    scanf("%d%d%d",&n,&m,&k);
    if(n-k<m-1) {
        printf("-1");
        return 0;
    }
    for(int i=1,u,v,w; i<n; ++i) {
        scanf("%d%d%d",&u,&v,&w);
        add(u,v,w);
        add(v,u,w);
    }
    //init();
    memset(f,63,sizeof f);
    dp(1,1);
    printf("%d",f[1][k][1]);
    return 0;
}
* 树形背包的复杂度分析:

$$ 对于每个点,枚举的次数为sz[u] sz[v_1]+sz[u] sz[v_2]+...+sz[u]* sz[v_m]=O(sz[u]^2)\

(对于枚举到0,其实相差个常数,可以忽略。)\

所以整体的复杂度就是\sum_{u=1}^n sz[u]^2=O(n^3* k),k<1.\

事实上,最极限的情况就是一条链,此时k=\frac12

所以k<\frac12\

所以可以说树形背包的复杂度严格小于O(0.5n^3) $$

P3177

好题。

设f[u][i]表示每个节点的子树内所有的边产生的贡献。

注意状态的含义时贡献,不是每个子树内所有点的答案,因为子树内的点会和子树外的点产生额外的贡献,并且无法计算。

考虑每条边的贡献,实际上就是两边的同色点的乘积。

则转移方程:

    int tot=j*(k-j)+(sz[v]-j)*(n-sz[v]-k+j);
    f[u][i]=max(f[u][i],f[u][i-j]+f[v][j]+tot*w);
  • 注意: 1) 这个树形背包要清空非法状态(-1),不然会出错。

2) 第二层的枚举如果是倒序,就一定要先转移f[v][0],即j==0得情况,不然也会出错。原因就是,每次都会用f[u][j]更新一次:f[u][j]+f[v][0],这就不满足不用得出的状态更新其他状态了。

  • 结论: 当设计成子树内答案行不通,换句话说子树内答案受子树外节点影响时,可以设计成子树内边的贡献。
#include<iostream>
#include<cstdio>
#include<cstring>
#define int long long 
using namespace std;
const int N=2200;
int n,k,cnt;
int head[N],sz[N];
int f[N][N];
struct node{
    int v,nxt,w;
}e[N<<1];
void add(int u,int v,int w){
    e[++cnt].v=v;
    e[cnt].w=w;
    e[cnt].nxt=head[u];
    head[u]=cnt;
}
void dfs(int u,int p){
    sz[u]=1;
    for(int o=head[u];~o;o=e[o].nxt){
        int v=e[o].v;
        if(v==p)continue;
        dfs(v,u);
        sz[u]+=sz[v];
    }
}
void dp(int u,int p){
    f[u][1]=f[u][0]=0;
    for(int o=head[u];~o;o=e[o].nxt){
        int v=e[o].v,w=e[o].w;
        if(v==p)continue;
        dp(v,u);
        for(int i=min(sz[u],k);i>=0;--i){
            for(int j=0;j<=min(sz[v],i);++j){
                if(f[u][i-j]==-1)continue;
                int tot=j*(k-j)+(sz[v]-j)*(n-sz[v]-k+j);
                f[u][i]=max(f[u][i],f[u][i-j]+f[v][j]+tot*w);
            }
        } 
    }
}
int read1(){
    int x=0;
    char ch=getchar();
    while(ch>'9' || ch<'0'){
        ch=getchar();
    }
    while(ch<='9' && ch>='0'){
        x=(x<<1)+(x<<3)+ch-'0';
        ch=getchar();
    }
    return x;
}
void write1(int x){
    if(x>9)write1(x/10);
    putchar(x%10+'0');
}
signed main(){
    memset(head,-1,sizeof head);
    cnt=-1;
    n=read1(),k=read1();
    if(n-k<k)k=n-k;
    for(int i=1,u,v,w;i<n;++i){
        u=read1(),v=read1(),w=read1();
        add(u,v,w);
        add(v,u,w);
    }
    memset(f,-1,sizeof f);
    dfs(1,0);
    dp(1,0);
    write1(f[1][k]);
    return 0;
}

换根dp

P6419

\(f[u]\)表示子树内贡献,\(g[u]\)表示子树外贡献,\(dis[u][0/1]\)表示子树内最/次长链,\(up[u]\)表示子树外最长链。

注意更新\(up[u]\)\(g[u]\)时,如果所有关键点都包含在\(v\)子树内,说明\(u\)上面没有最长链,那么\(up[v]\)\(g[v]\)都不应该转移。(都为0)

#include<iostream>
#include<cstdio>
#include<cstring>
#define int long long 
using namespace std;
const int N=5e5+10;
int dis[N][2],up[N],head[N],f[N],g[N],sz[N],vis[N],ans[N];
int n,k,cnt,tot;
struct node {
    int v,w,nxt;
} e[N<<1];
void add(int u,int v,int w) {
    e[++cnt].v=v;
    e[cnt].w=w;
    e[cnt].nxt=head[u];
    head[u]=cnt;
}
void dfs(int u,int p) {
    sz[u]=vis[u];
    for(int i=head[u]; ~i; i=e[i].nxt) {
        int v=e[i].v,w=e[i].w;
        if(v==p)continue;

        dfs(v,u);
        sz[u]+=sz[v];

        if(!sz[v])continue;
        f[u]+=f[v]+w+w;

        if(dis[v][0]+w>dis[u][0]) {
            dis[u][1]=dis[u][0];
            dis[u][0]=dis[v][0]+w;
        } else if(dis[v][0]+w>dis[u][1]) {
            dis[u][1]=dis[v][0]+w;
        }

    }
}
void dp(int u,int p) {
    for(int i=head[u]; ~i; i=e[i].nxt) {
        int v=e[i].v,w=e[i].w;
        if(v==p)continue;
        if((k-sz[v])) {
            g[v]=g[u]+f[u]-f[v]+((sz[v]==0)?(w+w):0);
            if(dis[v][0]+w==dis[u][0]) {
                up[v]=max(up[u],dis[u][1])+w;
            } else {
                up[v]=max(up[u],dis[u][0])+w;
            }
        }
        dp(v,u);
    }
}
signed main() {
    memset(head,-1,sizeof head);
    cnt=-1;
    scanf("%lld%lld",&n,&k);
    for(int i=1,u,v,w; i<n; ++i) {
        scanf("%lld%lld%lld",&u,&v,&w);
        add(u,v,w);
        add(v,u,w);
    }
    for(int i=1,tmp; i<=k; ++i) {
        scanf("%lld",&tmp);
        vis[tmp]=1;
    }
    dfs(1,0);
    dp(1,0);
    for(int i=1; i<=n; ++i) {
        printf("%lld\n",f[i]+g[i]-max(dis[i][0],up[i]));
    }

    return 0;
}
  • 结论:

1) 换根dp的套路就是处理子树内的信息,子树外的信息由子树内信息相减得到。

2) 处理子树内信息时,顺序为\(v\to u\),即从叶子向上合并;处理子树外信息时,顺序为\(u \to v\),即从根向下推。

P3237

~~看题就看了30min~~

然而看懂了题也不知道怎么做。

题意简化:

给一棵树,每个点有一个权值,要求修改一些点的权值,使得:

1) 同一个父亲的儿子权值必须相同

2) 父亲的取值必须是所有儿子权值之和

一个重要结论:只要有一个点被确定了,那么整棵树的最终形态就随之确定了。

也就是说,一个点被确定后,只有与他"等价"的点不用改变,其他的都要变。

设最多的等价点个数总共有\(ans\)个,那么最终要改变\(n-ans\)个点。

下面要求出所有等价点。我们发现,两个节点等价需要满足:

\(\prod_{i=1}^{d_x-1}son[v[i]]* a[x]=\prod_{i=1}^{d_y-1}son[v[i]]* a[y]\)

\(f[i]\)表示上述式子,则:\(f[x]=f[y]\)

这样,我们就可以~~快乐地~~树形\(dp\)了.

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm> 
#define int long long 
using namespace std;
typedef double db;
const int N=5e5+10;
const db eps=1e-8,INF=1e8;
int n,u,v,cnt,tmp,ans;
int a[N],head[N];
db f[N];
struct node{
    int v,nxt;
}e[N<<1];
bool cmp(db a,db b){return a<b;}
void add(int u,int v){e[++cnt].v=v,e[cnt].nxt=head[u],head[u]=cnt;}
void dfs(int u,int p,db t){
    f[u]+=t;
    int tot=0;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(v==p)continue;
        ++tot;
    }
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(v==p)continue;
        dfs(v,u,t+log((db)(tot)));
    }
}
signed main(){
    memset(head,-1,sizeof head);cnt=-1;f[0]=-INF;
    scanf("%lld",&n);
    for(int i=1;i<=n;++i) scanf("%lld",&a[i]),f[i]=log((db)(a[i]));
    for(int i=1;i<n;++i) scanf("%lld%lld",&u,&v),add(u,v),add(v,u);
    dfs(1,0,0);
    sort(f+1,f+n+1,cmp);
    tmp=0,ans=0;
    for(int i=2;i<=n;++i){
        if(f[i]-f[i-1]<eps) ++tmp;
        else ans=max(ans,tmp),tmp=1;
    }
    ans=max(ans,tmp);
    printf("%lld",n-ans);
    return 0;
} 

P3523

问题可以转化为有用超过\(m\)个点来覆盖所有关键点,最大距离最小为\(ans\).

如果使用贪心+二分,则需要保证每个新节点管辖的范围最大并且与其他节点管辖的节点重叠最小。这样能保证用的点最少。

所以可以设\(f[i]\)表示在\(i\)的子树里距离\(i\)最远的关键节点;\(g[i]\)表示在\(i\)的子树里距离\(i\)最近的已选中的节点。

初始值\(f[u]=-\infty,g[u]=\infty\)

转移:

\[ f[u]=\max_{v\in son(u)}(f[v]+1),g[u]=\min_{v\in son(u)}(g[v]+1) \]

还有三个特判:

(\(k\)为当前二分的距离)

1) \(f[u]+g[u]<=k\),说明当前的子树可以通过这个选中节点被完全覆盖,所以\(f[u]=-\infty\),表示不对父节点的\(f\)产生影响

2) \(g[u]>k\),说明当前的子树上端不会被\(u\)的子孙完全覆盖,那么可以更新一下\(f[u]\)交给\(u\)的父亲处理,这样保证重叠最小。\(f[u]=\max(f[u],0)\),表示如果没有未覆盖关键节点在子树里就只算它自己。

3) \(f[u]==k\),说明当前节点恰好能用最小重叠覆盖整颗子树,并且因为这个更新是在回溯过程中的,要往上走,如果不选这个节点,后面就无法覆盖所有节点,所以强制选\(u\),\(f[u]=-\infty,g[u]=0,++tot\)

最后特判一下根节点。

#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=3e5+10,INF=1e9+10;
int f[N],g[N],head[N],a[N];
int cnt,n,m,u,v,tot;
struct node{
    int v,nxt;
}e[N<<1];
void add(int u,int v){
    e[++cnt].v=v,e[cnt].nxt=head[u],head[u]=cnt;
}
void dfs(int u,int p,int k){
    f[u]=-INF,g[u]=INF;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].v;
        if(v==p)continue;
        dfs(v,u,k);
        f[u]=max(f[u],f[v]+1);
        g[u]=min(g[u],g[v]+1);
    }
    if(f[u]+g[u]<=k) f[u]=-INF;
    if(g[u]>k && a[u]) f[u]=max(f[u],0);
    if(f[u]==k) f[u]=-INF,g[u]=0,++tot;
}
bool check(int k){
    tot=0;
    dfs(1,0,k);
    if(f[1]>=0) ++tot;
    return tot<=m;
}
int main(){
    memset(head,-1,sizeof head);cnt=-1;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;++i)scanf("%d",&a[i]);
    for(int i=1;i<n;++i)scanf("%d%d",&u,&v),add(u,v),add(v,u);
    int l=0,r=n,ans=INF;
    while(l<=r){
        int mid=l+r>>1;
        if(check(mid)) r=mid-1,ans=mid;     
        else l=mid+1;
    }
    printf("%d",ans);
    return 0;
}