题解:P10842 【MX-J2-T3】Piggy and Trees

· · 题解

这里提供一种与官方题解不同的思路。

简要题意

计算一棵树上所有路径到所有点的距离之和。

题目分析

这是一个树上统计问题,而树上统计问题的一个经典思想就是对于每个点,先计算子树中的答案,再根据题目性质将答案合并。

我们设 ans_i 表示点 i 的子树内所有路径到所有点距离之和。首先 i 的所有子节点的子树已经被计算,直接加上它们的答案即可。然后分为两种情况,即 i 为端点和 i 不为端点。

对于第一种情况,我们先考虑任意一条路径。设路径深度较大的端点为 a,深度较小的端点为 b。显然,这条路径上的点到这条路径距离为 0a 子树中点到这条路径距离为到 a 的距离,b 子树外的点到这条路径距离为到 b 的距离。对于其它点,若它属于路径中某个点的子树,则它到这条路径的距离就为到这个点的距离。

我们枚举 i 的每个子节点 j,考虑另一路径端点 kj 子树内的情况,显然一共新增 size_j 条路径。(size_jj 子树的大小)对于每条路径,由上文得 j 子树外的所有点的贡献均为到 i 的距离,即 alldis_i-dis_j-size_j。(alldis_i 为所有点到 i 的距离之和,dis_jj 子树内点到 j 距离之和,再减去 size_j 是因为 j 子树内点到 i 的距离都要大 1,加起来就是它)而 j 子树内点的贡献又分为 k 子树中的点的贡献 dis_k 和路径上其它点的贡献。显然路径上每个点的贡献为它的子树中排除 k 所在子树的点到它的距离和。

现在考虑 O(1) 得到以上信息:显然 j 子树内每个点只会被计算 1 次,即只会有 size_j 条新路径。于是 size_j \times (alldis_i-dis_j-size_j) 即为第一种贡献,预处理 predis_i(即每个点子树内 dis 之和)表示第二种贡献。对于第三种,发现对于每个点 l 只有以 l 子树内(不包括 l)的点为路径端点,且所求点为 k 其他子树中的点时才会产生贡献,且这个贡献值与 i 无关,于是预处理 add_l 为这个值,preadd_ll 子树的 add_l 的和。add_l 的计算方式为:枚举 l 的每个子节点 v,求出 \sum size_v \times (dis_l-dis_v-size_v)。原因可以参考上面的解释。

对于第二种情况,我们发现路径两端点必为 i 的两个不同子节点子树中的点。每个子节点子树中的点都能与其他子节点子树中的点自由组合,一共有 d=\sum size_v \times (size_i-size_v-1) / 2 种情况。每种情况的贡献又分为三种:第一种为 i 子树外的点的贡献,第二种为路径两端点子树中点的贡献,第三种为 i 子树内其它点的贡献。

对于这种情况的第一种贡献,每条路径是等价的,为 d \times (alldis_i-dis_i) 种。对于第二种贡献,发现实际上就是把第一种情况的第二和第三中贡献再与 i 其它子节点中子树 (size_i-size_v-1) 个点自由组合了,将这两种贡献乘上 (size_i-size_v-1) 的系数即可。对于第三种贡献,考虑路径两端点在 i 的某两个子节点 nm 的子树,设 i 的子节点数量为 son_i,则贡献为(注:为了方便表示,以下 nm 表示 i 的第 nm 个子节点):

\begin{aligned} \sum_{n=1}^{son_i} \sum_{m=n+1}^{son_i} size_n \times size_m \times (dis_i-dis_n-size_n-dis_m-size_m) \end{aligned}

上式最坏会卡到 O(n^2)。因此设 powsize_i=\sum size_j \times (dis_j+size_j)rest=dis_i-dis_j-size_jji 的子节点),则上式转化为:

\begin{aligned} (\sum_{n=1}^{son_i} size_n \times (rest\times(size_i-size_n-1)-powsize_i+size_n\times(dis_n+size_n)))/2\end{aligned}

将所有贡献全部加起来,即为答案。时间复杂度 O(n)。关于贡献除以二后不为整数,alldissizedis 怎么求等细节请看代码。

#include<bits/stdc++.h>
#define int long long
#define endl '\n'
using namespace std;

const int N = 3e5+10;
const int M = 1e9+7;
int n,m,a[N];
int size[N],dis[N],pred[N]; // pred:题解中的 predis 
int powsz[N]; // 题解中的 powsize 
int addd[N],preadd[N]; //题解中的 add,preadd 
int alld[N],ans[N]; //题解中的 alldis 
vector<int> vec[N];

inline void dfs(int x,int fa){
    size[x] = 1;
    dis[x] = 0;
    for(auto v:vec[x]){
        if(v==fa) continue;
        dfs(v,x); 
        dis[x] += dis[v]+size[v]; 
        pred[x] += pred[v];
        size[x] += size[v];
        powsz[x] += size[v]*(dis[v]+size[v]);
        dis[x] %= M;
        pred[x] %= M;
        powsz[x] %= M;
    } 
    pred[x] += dis[x];
    pred[x] %= M;
}
inline void dfs2(int x,int fa){
    for(auto v:vec[x]){
        if(v==fa) continue;
        alld[v] = alld[x]-size[v]+(n-size[v])+M; // 从父节点开始往下推,考虑当点从父节点移到子节点时,子节点子树所有点到它距离-1,其它点距离+1 
        alld[v] %= M;
        dfs2(v,x);
    }
}
inline void dfs3(int x,int fa){
    for(auto v:vec[x]){
        if(v==fa) continue;
        dfs3(v,x);
        addd[x] += (size[v])*(dis[x]-dis[v]-size[v]+M+M)%M;
        preadd[x] += preadd[v];
        addd[x] %= M;
        preadd[x] %= M;
    }
    preadd[x] += addd[x];
    preadd[x] %= M;
}
inline void dfs4(int x,int fa){
    int temp = 0;
    for(auto v:vec[x]){
        if(v==fa) continue;
        dfs4(v,x);
        ans[x] += ans[v];
        ans[x] += (alld[x]-dis[v]-size[v]+M+M)*size[v]%M;
        ans[x] += pred[v];
        ans[x] += preadd[v];
        ans[x] += (pred[v]*(size[x]-size[v]-1+M))%M;
        ans[x] += preadd[v]*(size[x]-size[v]-1+M)%M;
        temp += (alld[x]-dis[x]+M)*(size[v]*(size[x]-size[v]-1)%M)%M; // temp是真实贡献的2倍,防止出现除以二后不为整数等问题 
        int rest = (dis[x]-dis[v]-size[v]+M+M)%M;
        temp += size[v]*((rest*(size[x]-size[v]-1+M)%M-powsz[x]+size[v]*(dis[v]+size[v])%M+M)%M)%M;
        ans[x] %= M;
        temp %= M;
    }
    temp = (temp*500000004)%M; // 注意temp是真实贡献的2倍,因此乘以2在模1e9+7意义下的逆元 
    ans[x] += temp;
    ans[x] %= M;
}

signed main(){
    ios::sync_with_stdio(false);
    cin>>n;
    int q1,q2;
    for(int i=1;i<=n-1;i++){
        cin>>q1>>q2;
        vec[q1].push_back(q2);
        vec[q2].push_back(q1);
    }
    dfs(1,0);
    alld[1] = dis[1];
    dfs2(1,0);
    dfs3(1,0);
    dfs4(1,0);
    cout<<ans[1]; 

    return 0;
}