题解 P6773【[NOI2020] 命运】

· · 题解

P6773 [NOI2020] 命运解题报告:

更好的阅读体验

题意

给定一个n个点的树,再给定一个路径集合Q满足路径端点存在祖先关系,定义一个关键边集合F是合法的当且进当路径集合Q中所有的路径都覆盖了至少一条关键边,求有多少个关键边集合F合法。(1\leqslant n\leqslant 5\times 10^5

分析

NOI2020D1T2,线段树合并优化dp经典题。

up_x表示x子树中的路径向上覆盖到最的深度(如果没有任何路径则up_x=0)。

这个“至少”容易让人联想到容斥,但纯容斥似乎无法推出多项式复杂度的做法,因此考虑dp。

考虑定义一个既包含当前节点也包含当前路径集合覆盖状态的的树形dp,那么设f_{x,y}表示x的子树中未覆盖关键边的路径覆盖到最的深度为y的方案数量。

y更新x时,我们可以写出状态转移方程(上面为选择(x,y)的情况,不难发现y为下端点的路径最深只能覆盖到dep_x;下面为不选择(x,y),讨论这个最深是x贡献还是y贡献):

f_{x,i}\leftarrow f_{x,i}\times\sum_{j=0}^{dep_x}f_{y,j}\\f_{x,i}\leftarrow f_{x,i}\times\sum_{j=0}^i f_{y,j}+f_{y,i}\times\sum_{j=0}^{i-1}f_{x,j}

很容易用前缀和优化来帮助转移,可以做到O(n^2),甚至可以用虚树来做到O(n\times \min\{m,\max\{dep_x\}\})

f_{x,i}\leftarrow f_{x,i}\times sumf_{y,dep_x}+f_{x,i}\times sumf_{y,i}+f_{y,i}\times sumf_{x,i-1}

看到这种前缀/后缀的转移方程,以及发现有很多状态是没有值的(叶子节点仅会有一个状态有值),那么我们可以想到线段树合并(类型相似的题:P5298 [PKUWC2018]Minimax)。

具体地,我们对于每个点开线段树维护dp状态,首先把f_{x,up_x}设为1,然后将x的所有儿子代表的线段树与x的线段树合并,合并的时候维护一下x的dp值前缀和与y的dp值前缀和就好了。

时间复杂度/空间复杂度:O(n\log n)(单点修改仅会有O(n)次)。

代码

#include<stdio.h>
#include<vector>
using namespace std;
const int maxn=500005,maxk=maxn*32,mod=998244353;
int n,m;
int up[maxn],dep[maxn];
int cnt;
int lc[maxk],rc[maxk],lazy[maxk],val[maxk],rt[maxn];
vector<int>g[maxn];
inline void pushup(int now){
    val[now]=(val[lc[now]]+val[rc[now]])%mod;
}
inline void getlazy(int now,int v){
    val[now]=1ll*val[now]*v%mod,lazy[now]=1ll*lazy[now]*v%mod;
}
inline void pushdown(int now){
    if(lazy[now]==1)
        return ;
    getlazy(lc[now],lazy[now]),getlazy(rc[now],lazy[now]);
    lazy[now]=1;
}
void update(int l,int r,int &now,int pos,int v){
    if(now==0)
        now=++cnt,lazy[now]=1;
    if(l==r){
        lazy[now]=1,val[now]=v;
        return ;
    }
    int mid=(l+r)>>1;
    pushdown(now);
    if(pos<=mid)
        update(l,mid,lc[now],pos,v);
    else update(mid+1,r,rc[now],pos,v);
    pushup(now);
}
int query(int l,int r,int now,int L,int R){
    if(L<=l&&r<=R)
        return val[now];
    int mid=(l+r)>>1,res=0;
    pushdown(now);
    if(L<=mid)
        res=(res+query(l,mid,lc[now],L,R))%mod;
    if(mid<R)
        res=(res+query(mid+1,r,rc[now],L,R))%mod;
    return res;
}
int merge(int l,int r,int a,int b,int &tag1,int &tag2){
    if(a==0&&b==0)
        return 0;
    if(a==0){
        tag2=(tag2+val[b])%mod,getlazy(b,tag1);
        return b;
    }
    if(b==0){
        tag1=(tag1+val[a])%mod,getlazy(a,tag2);
        return a;
    }
    if(l==r){
        int rec1=val[a],rec2=val[b];
        tag2=(tag2+rec2)%mod;
        val[a]=(1ll*val[a]*tag2%mod+1ll*val[b]*tag1%mod)%mod;
        tag1=(tag1+rec1)%mod;
        return a;
    }
    int mid=(l+r)>>1;
    pushdown(a),pushdown(b);
    lc[a]=merge(l,mid,lc[a],lc[b],tag1,tag2);
    rc[a]=merge(mid+1,r,rc[a],rc[b],tag1,tag2);
    pushup(a);
    return a;
}
void dfs(int x,int last){
    dep[x]=dep[last]+1;
    for(int i=0;i<g[x].size();i++){
        int y=g[x][i];
        if(y==last)
            continue;
        dfs(y,x);
    }
}
void getans(int x,int last){
    update(0,n,rt[x],up[x],1);
    for(int i=0;i<g[x].size();i++){
        int y=g[x][i];
        if(y==last)
            continue;
        getans(y,x);
        int tag1=0,tag2=query(0,n,rt[y],0,dep[x]);
        rt[x]=merge(0,n,rt[x],rt[y],tag1,tag2);
    }
}
int main(){
    scanf("%d",&n);
    for(int i=1;i<n;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        g[x].push_back(y),g[y].push_back(x);
    }
    dfs(1,0);
    scanf("%d",&m);
    for(int i=1;i<=m;i++){
        int x,y;
        scanf("%d%d",&x,&y);
        up[y]=max(up[y],dep[x]);
    }
    getans(1,0);
    printf("%d\n",query(0,n,rt[1],0,0));
    return 0;
}