题解 P4516 【[JSOI2018]潜入行动】

· · 题解

这是一个并不简单的背包类树形dp……

很自然地想到状态定义:f[u][k][0/1][0/1]表示以u为根的子树中,总共选择k个结点,其中除了u以外的所有结点均被监听到,u结点选或不选,u结点是否被覆盖的情况下,一共有多少种方案。

状态转移看似十分麻烦。每个结点u都有许多子结点,很难统计出每个子结点的所有情况(似乎在组合数学的范畴)。但是我们可以用十分巧妙的树形背包来进行状态转移。树上背包的转移套路是:

f[u][i+j]=combine(f[u][i],f[v][j])

相当于每递归访问完一个子结点,就把子节点上的状态与当前已经处理的状态一一配对,保证不重不漏且兼顾效率。具体的转移方程为:

f[u][i+j][0][0]=\sum f[u][i][0][0]*f[v][j][0][1] f[u][i+j][1][0]=\sum f[u][i][1][0]*(f[v][j][0][0]+f[v][j][0][1]) f[u][i+j][0][1]=\sum (f[u][i][0][1]\*(f[v][j][0][1]+f[v][j][1][1])+f[u][i][0][0]\*f[v][j][1][1] f[u][i+j][1][1]=\sum (f[u][i][1][0]\*(f[v][j][1][0]+f[v][j][1][1])+f[u][i][1][1]\*(f[v][j][0][0]+f[v][j][0][1]+f[v][j][1][0]+f[v][j][1][1]))

具体实现时还应注意:因为阶段(即扫描子结点个数)的划分,在每次转移前都要先记录原始的u结点上的数据,否则会导致混乱。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

const int MAXN=100005;
const int mod=1000000007;

int N,K;
int f[MAXN][105][2][2];
int g[105][2][2];
int size[MAXN];

vector<int> G[MAXN];

inline int Mod(ll x,ll y){
    x%=mod,y%=mod;
    return (int)(x+y)%mod;
}

void dp(int u,int fa){
    size[u]=1;
    f[u][0][0][0]=f[u][1][1][0]=1;
    for(vector<int>::iterator it=G[u].begin();it!=G[u].end();it++){
        int v=*it;
        if(v==fa) continue;
        dp(v,u);
        for(register int i=0;i<=min(size[u],K);++i){
            g[i][0][0]=f[u][i][0][0],f[u][i][0][0]=0;
            g[i][0][1]=f[u][i][0][1],f[u][i][0][1]=0;
            g[i][1][0]=f[u][i][1][0],f[u][i][1][0]=0;
            g[i][1][1]=f[u][i][1][1],f[u][i][1][1]=0;
        }
        for(register int i=0;i<=min(size[u],K);++i){
            for(register int j=0;j<=min(size[v],K-i);++j){
                f[u][i+j][0][0]=Mod((ll)f[u][i+j][0][0],(ll)g[i][0][0]*(ll)f[v][j][0][1]);
                f[u][i+j][0][1]=Mod((ll)f[u][i+j][0][1],(ll)g[i][0][0]*(ll)f[v][j][1][1]+(ll)g[i][0][1]*((ll)f[v][j][1][1]+(ll)f[v][j][0][1]));
                f[u][i+j][1][0]=Mod((ll)f[u][i+j][1][0],(ll)g[i][1][0]*((ll)f[v][j][0][0]+(ll)f[v][j][0][1]));
                f[u][i+j][1][1]=Mod((ll)f[u][i+j][1][1],(ll)g[i][1][0]*((ll)f[v][j][1][0]+(ll)f[v][j][1][1])+(ll)g[i][1][1]*((ll)f[v][j][0][0]+(ll)f[v][j][0][1]+(ll)f[v][j][1][0]+(ll)f[v][j][1][1]));
            }
        }
        size[u]+=size[v];
    }
}

int main(){
    scanf("%d%d",&N,&K);
    for(register int i=1;i<N;++i){
        int u,v;
        scanf("%d%d",&u,&v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dp(1,0);
    printf("%d",(int)(f[1][K][0][1]+f[1][K][1][1])%mod);
    return 0;
}