AT3957 01 on Tree 题解

· · 题解

很典型的一个贪心类型,这种题目通常会用一个叫做 Exchange Argument 的方法来构造正确的贪心策略。

atcoder 题目传送门 & 洛谷题目传送门

Description

有一个 n 个节点的树,每个节点有权值 v_i=0\lor v_i=1 ,求这颗树上节点的一个拓扑序(保证所有节点在它祖先之后),使得所有 v_i 组成的序列中逆序对数最小,输出最小逆序对个数。

Analysis

考虑没有拓扑序限制的情况下,最优解一定形如:0,0,...,1,1,1,...,所有 1 放在最后,逆序对数为 0。

我们希望答案的拓扑序尽可能的接近这个序列,所以有一个初步贪心的想法:

一旦可选节点中有 v_i=0,先选这个节点一定不会更劣。所以对于 v_i=0 的节点,一旦父亲被选了,就立即选它。

基于这个想法,我们可以将这个节点与父亲节点“合并”,因为它们一定是序列中相邻的。

接下来考虑合并后的节点如何维护他的权值。这里要用到 Exchange Argument 的套路。设两个节点(单个节点或多个合并的节点)分别为 ab,设 a 节点包含的所有小节点中有 a_0 个 0 和 a_1 个 1。同样地,b 包含的所有小节点中有 b_0 个 0 和 b_1 个 1。分别考虑先选 a 与先选 b 造成的逆序对个数谁更优,当且仅当:a_1\times b_0<b_1\times a_0 时,先选 a 更优。即:

\dfrac{a_1}{a_0}<\dfrac{b_1}{b_0}

所以新合并节点的权值就是 \dfrac{cnt_1}{cnt_0}

Solution

由于我们已经证明了这个贪心的正确性,接下来我们要做的,就是维护一个优先队列,每次取出权值最小的节点,将其与父节点合并。每次合并时统计对答案的贡献,直到所有节点合并为 1 个为止。在过程中就已经将最优答案求了出来。

具体实现上,用并查集去维护一下合并节点中的信息。注意权值\dfrac{cnt_1}{cnt_0} 中当 cnt_0=0 要取 +\infty

Code

#include <bits/stdc++.h>
#define ll long long
#define pb push_back
#define pdi pair<double,int>
#define mp make_pair
#define F first
#define S second
using namespace std;
const double INF=1e18;
int n,a[200005],fa[200005],p[200005];
bool vis[200005];
ll ans,cnt1[200005],cnt0[200005];
int find(int x){return (fa[x]==x?x:find(fa[x]));}
void merge(int x,int y)//并查集合并操作
{
    x=find(x),y=find(y);
    ans+=cnt1[y]*cnt0[x];//记录对逆序对总数的贡献
    cnt1[y]+=cnt1[x];//合并后加上0 1个数
    cnt0[y]+=cnt0[x];
    fa[x]=y;
}
int main()
{
    ios::sync_with_stdio(false),cin.tie(nullptr);
    cin>>n;
    p[0]=-1;
    for(int i=1;i<n;i++)
    {
        cin>>p[i];
        --p[i];
    }
    priority_queue<pdi,vector<pdi>,greater<pdi> > pq;//优先队列按权值升序存节点
    for(int i=0;i<n;i++)
    {
        fa[i]=i;
        cin>>a[i];
        if (a[i]) ++cnt1[i];
        else ++cnt0[i];
        pq.push(mp((cnt0[i]==0?INF:1.0*cnt1[i]/(1.0*cnt0[i])),i));
    }
    while(!pq.empty())
    {
        int u=pq.top().S;
        pq.pop();
        if (vis[u])
            continue;
        vis[u]=1;
        if (p[u]!=-1)//与父亲所在节点合并
        {
            int x=find(p[u]);
            merge(u,x);
            pq.push(mp((cnt0[x]==0?INF:1.0*cnt1[x]/(1.0*cnt0[x])),x));//将新节点加入优先队列
        }
    }
    cout<<ans<<endl;
    return 0;
}

希望题解对你有帮助!