题解 P3642 【[APIO2016]烟火表演】

· · 题解

首先有一个很显然的 dp 方程

f_{u,x} 为以 u 为根的子树中结束时间统一为 x 的最小代价,有 f_{u,x}=\sum \min_{y\le x}\{f_{v,y}+|w-x+y|\}

于是有了一个 \mathcal O(n(\sum w)^2) 的方程

接下来,我们考虑把 f_u 的所有点值写成函数的形式 f_u(x)

考虑 f_v(y) 对父亲 f_u(x) 的贡献

首先,我们发现这个函数每一段的斜率一定是单调递增的,并且斜率均为整数,于是 f_v(y) 的最优取值就是斜率为 0 的一段,记为 [L,R],接下来,记 f_v 对于 f_u(x) 的贡献函数为 F_v(x),记 u\to v 的边权为 w

根据 F_v 的定义,我们可以知道 F_v(x)=\min_{y\le x}\{f_v(y)+|w-x+y|\}

接下来,我们可以分三类来讨论一下:

  1. x 很小的时候,最优情况一定是把 w 修改为 0(由于修改 w 一定比下面每一个都修改优)
    具体来说就是 x<L 时,F_v(x)=f_v(x)+w 一定最优,因为 f_v(x)x<L 时为单调降函数,且斜率小于等于 -1,于是 w 少缩短 k,则 f_v 的增加量一定 \ge k,于是在 w 修改为 0 时取到最优解

  2. x 很大的时候,和上种情况类似,最优情况一定要把 w 拉长
    具体来说就是 x\ge R+w 时,F_v(x)=f_v(R)+x-R-w 一定最优,因为 f_v(x)x>R 时为单调升函数,且斜率大于等于 1,于是 w 少拉长 kf_v 的增加量一定 \ge k,又 f_v[L,R] 内都能取到最优值,所以 w 拉长到 x-R 时取到最优解

  3. x 适中的时候,我们可以调整 w 缩小的值使 F_v 取到最优,我们可以继续分两类讨论
    1)这种情况 w 不用修改即可取到最优值,具体来说, L+w\le x\le R+w 时, F_v(x)=f_v(x-w) 一定最优,因为 w 不修改,那么修改 w 的代价取到了最小,同时由于 x-w\in[L,R]f_v(x-w) 也取到了最小,那么这就是最优值
    2)这种情况 w 需要修改,并且一定是缩小(拉伸属于 2. 情况),具体来说,L\le x<L+w 时,F_v(x)=f_v(L)+w-x+L,原因与 1. 情况相同,只是这里没有必要把 w 修改为 0(原因是 w 多修改了以后 f_v 并不会因此减小)

于是我们可以把这几种情况整理一下,得:

F_v(x)= \begin{cases} f_v(x)+w & (x<L)\\ f_v(L)+w-x+L & (L\le x<L+w)\\ f_v(x-w)=f_v(L) &(L+w\le x\le R+w)\\ f_v(R)+x-R-w &(x>R+w) \end{cases}

考虑 F_v 对于 f_v 的变化关系

又容易发现 F_v 是连续的,于是我们在存储这段函数的时候只需要存储 F_v 的各个拐点即可

于是,我们修改 F_v 只需要把 LL 之后的所有拐点全部删除,新加入 L+wR+w 两个拐点即可

f_u 合并时,我们只需要把这些拐点序列直接和 f_u 本身拥有的拐点序列合并,并且不去重,因为重复 k 次的拐点可以表示这个点左右两边的斜率差为 k

又拐点横坐标单调下降,我们很容易可以想到用可并的大根堆来维护这个函数

接下来还有两个问题

  1. 如何从可并堆中找到斜率为 0 的区间的两个拐点 L,R
    我们发现每次修改的 F 函数中,斜率为正的拐点有且仅有一个(R+w),而一个函数 f_v 会合并儿子个数 k_v 次,也就是说 f_v 中有 k_v 个斜率为正的拐点,于是弹出 k_v-1 个,剩下的第一个就是 R,下一个就是 L

  2. 如何计算答案
    我们知道了根节点(1 号节点)的函数拐点集合,又知道了 f_1(0) 时的取值(所有边的权值和),又知道了一个拐点和下一个拐点对应的函数斜率差为 1,于是我们就可以轻松地算出答案了

可并堆用左偏树实现即可,复杂度 \mathcal O((n+m)\log (n+m))

另外有一个细节就是左偏树的节点要开两倍大小,因为每次把 f 修改成 F 都要新建两个节点

code:

#include<bits/stdc++.h>
#define MAXN 300010
#define int long long
using namespace std;
int n,m,ans;
int fa[MAXN],w[MAXN],deg[MAXN];
int val[MAXN<<1],d[MAXN<<1],rt[MAXN],tot;
int lc[MAXN<<1],rc[MAXN<<1];
int create(int v){
    val[++tot]=v;d[tot]=0;
    return tot;
}
int merge(int p,int q){
    if(!p||!q)return p+q;
    if(val[p]<val[q])swap(p,q);
    rc[p]=merge(rc[p],q);
    if(d[lc[p]]<d[rc[p]])swap(lc[p],rc[p]);
    d[p]=d[rc[p]]+1;return p;
}
void pop(int &p){p=merge(lc[p],rc[p]);}
signed main(){
    scanf("%lld%lld",&n,&m);
    for(int i=2;i<=n+m;i++){
        scanf("%lld%lld",&fa[i],&w[i]);
        ans+=w[i];deg[fa[i]]++;
    }
    for(int u=n+m;u>=2;u--){
        if(u<=n)while(deg[u]-->1)pop(rt[u]);
        int R=val[rt[u]];pop(rt[u]);
        int L=val[rt[u]];pop(rt[u]);
        rt[u]=merge(rt[u],merge(create(L+w[u]),create(R+w[u])));
        rt[fa[u]]=merge(rt[fa[u]],rt[u]);
    }
    while(deg[1]--)pop(rt[1]);
    while(rt[1])ans-=val[rt[1]],pop(rt[1]);
    printf("%lld",ans);
    return 0;
}