题解 P4383 【[八省联考2018]林克卡特树lct】

· · 题解

题意

给定一棵 n 个点的树,边权有正有负,要求在树上选出 k+1 条链,使得其权值之和最大。

看起来很不好做,那我们先考虑一下,对于60分的部分分如何处理。

我们可以这样设计状态:
f[i][j][0/1/2] 表示在 i 节点的子树内,已经有 j完整的链,当前 i 节点的度数为 0/1/2 的最大价值。度数为 0 时,这个点没有连边。度数为 1 时,这个点拖着一条未完成的链,而这条链不计入 j 。度数为 2 时,这个点被一条连接两个不同子树的链穿过。

可以分类讨论出如下转移:

首先,我们约定在每个节点的全部转移结束时,进行一次更新:

f[i][j][0]=\max\{f[i][j][0],f[i][j-1][1],f[i][j][2]\}

这样,我们就把 i 节点的全部最优解统计了出来。对于度数为 0/2 的情况可以直接合并,而对于度数为 1 的情况,要先在 i 节点处把当前 i 节点拖着的一条链断掉,然后将这条链计入总数,合并。

f[i][j][0]=\max_{l=1}^{j-1} \{ f[i][j][0],f[i][l][0]+f[x][j-l][0]\}

显然,如果当前节点的度数为 0 ,肯定不能连边,只能取子节点的最优解。

f[i][j][1]=\max_{l=1}^{j-1} \{ f[i][j][1],f[i][l][1]+f[x][j-l][0],f[i][l][0]+f[x][j-l][1]+w(i,x)\}

如果要求度数为 1 ,可以继承之前的结果,不连边。同样,可以连上这条边,继承子节点拖着的一条未完成的链,同时取这条边的权值。

f[i][j][2]=\max_{l=1}^{j-1} \{ f[i][j][2],f[i][l][2]+f[x][j-l][0],f[i][l][1]+f[x][j-l-1][1]+w(i,x)\}

要求度数为 2 ,同样可以继承之前的结果,取子节点的最优解。也可以将之前 i 节点挂着的链与子节点挂着的链连接起来,就新完成了一条链。

f[i][j][1]=\max\{f[i][j][1],f[i][j][0]+f[x][0][1]+w(i,x)\} f[i][0][1]=\max\{f[i][0][1],f[i][0][0]+f[x][0][1]+w(i,x)\}

处理从子节点出发的链。

这样我们可以得到一个 O(nk^2) 的dp做法,可以拿到45(35)分。

考虑如何优化。我们注意到,每次增加一条链,得到的收益是单调不增的。
每次增加一条链有两种做法,一种是新选一条链,一种是将一条链分成两条。每种操作的收益是一定的,因此每次都会选择最大收益的操作。而由于边权不会改变,后续不可能有操作会有更大的收益。因此,我们可以口胡出这是一个上凸的函数。

f[1][x][0]F(x) ,我们知道 F(x) 是一个上凸函数。并且,恒有 F(0)=F(n)=0 。因此,我们可以考虑用一条斜率一定的直线去切这个凸壳。设这条直线为 l:y=ax+b ,根据上凸函数的特性,斜率一定时,切线的截距一定大于割线。

我们令切点为 (t,F(t)) ,则切线满足 a*t+b=F(t) ,即 b=F(t) - a*t 。可以发现, b 的表达式就相当于给每个物品赋上额外的权值 -t 时,dp所得的最优解。这样,我们只需要一次朴素的dp就可以求出截距。

而如何获取斜率呢?再一次,由于函数上凸,我们可以发现当斜率不断增大时,对应的切点横坐标也在不断左移。这样,我们就可以二分求相应的斜率了。每次check的时候顺便记录一下最优解选取了多少个,从而判断应该如何调整二分区间。

还有一个问题,有可能出现凸壳上多个点共线的情况。为了处理这种情况,我们可以限定在dp过程中,权值相同时优先取选择次数更小的转移。这样,我们求出来的次数就是当前切点的左边界,二分时判断一下就可以了。

回到这题,套用60(45)分做法的dp方式,去掉有关次数的限制,每次dp的复杂度就变成 O(n) 。总复杂度为 O(n \log k)

45分代码

#include<bits/stdc++.h>
#define reg register
typedef long long ll;
using namespace std;
const int MN=3e5+5;
int to[MN<<1],nxt[MN<<1],c[MN<<1],h[MN],cnt;
inline void ins(int s,int t,int w){
    to[++cnt]=t;nxt[cnt]=h[s];c[cnt]=w;h[s]=cnt;
    to[++cnt]=s;nxt[cnt]=h[t];c[cnt]=w;h[t]=cnt;
}
#define chkmax(a,b) ((a)<(b)?(a)=(b):0)
int n,K;
int f[MN][105][3];
void dfs(int st,int fa=0){
    f[st][0][0]=f[st][0][1]=f[st][1][2]=0;
    for(reg int i=h[st];i;i=nxt[i]){
        if(to[i]==fa)continue;
        dfs(to[i],st);
        for(reg int j=K;j;j--){
            chkmax(f[st][j][1],f[st][j][0]+f[to[i]][0][1]+c[i]);
            for(reg int k=j-1;~k;k--){
                chkmax(f[st][j][0],f[st][k][0]+f[to[i]][j-k][0]);
                chkmax(f[st][j][1],max(f[st][k][1]+f[to[i]][j-k][0],f[st][k][0]+f[to[i]][j-k][1]+c[i]));
                chkmax(f[st][j][2],max(f[st][k][2]+f[to[i]][j-k][0],f[st][k][1]+f[to[i]][j-k-1][1]+c[i]));
            }
        }
        chkmax(f[st][0][1],f[to[i]][0][1]+c[i]);
    }
    for(reg int i=1;i<=K;i++)
        chkmax(f[st][i][0],max(f[st][i-1][1],f[st][i][2]));
}
int main(){
    scanf("%d%d",&n,&K);K++;
    for(reg int i=1,s,t,v;i<n;i++)
        scanf("%d%d%d",&s,&t,&v),ins(s,t,v);
    memset(f,~0x3f,sizeof(f));dfs(1);
    printf("%d\n",f[1][K][0]);
    return 0;
}

100分代码

#include<bits/stdc++.h>
#define reg register
typedef long long ll;
using namespace std;
const int MN=3e5+5;
int to[MN<<1],nxt[MN<<1],c[MN<<1],h[MN],cnt;
inline void ins(int s,int t,int w){
    to[++cnt]=t;nxt[cnt]=h[s];c[cnt]=w;h[s]=cnt;
    to[++cnt]=s;nxt[cnt]=h[t];c[cnt]=w;h[t]=cnt;
}
#define chkmax(a,b) ((a)<(b)?(a)=(b):0)
int n,k;
ll l,r,mid;
struct data{
    ll val;int pos;
    data(ll x=0,int y=0):val(x),pos(y){}
    friend bool operator<(data a,data b){
        return a.val==b.val?a.pos>b.pos:a.val<b.val;
    }
    friend data operator+(data a,data b){
        return data(a.val+b.val,a.pos+b.pos);
    }
    friend data operator+(data a,ll b){
        return data(a.val+b,a.pos);
    }
}f[MN][3],tmp;
int fa[MN];
void getf(int st){
    for(reg int i=h[st];i;i=nxt[i])
        if(to[i]!=fa[st])fa[to[i]]=st,getf(to[i]);
}
void dfs(int st){
    f[st][0]=f[st][1]=f[st][2]=data();
    chkmax(f[st][2],tmp);
    for(reg int i=h[st];i;i=nxt[i]){
        if(to[i]==fa[st])continue;dfs(to[i]);
        chkmax(f[st][2],max(f[st][2]+f[to[i]][0],f[st][1]+f[to[i]][1]+c[i]+tmp));
        chkmax(f[st][1],max(f[st][1]+f[to[i]][0],f[st][0]+f[to[i]][1]+c[i]));
        chkmax(f[st][0],f[st][0]+f[to[i]][0]);
    }
    chkmax(f[st][0],max(f[st][1]+tmp,f[st][2]));
}
int main(){
    scanf("%d%d",&n,&k);k++;
    for(reg int i=1,s,t,w;i<n;i++)
        scanf("%d%d%d",&s,&t,&w),ins(s,t,w);
    l=-1e12;r=1e12;getf(1);
    while(l<r){
        mid=(double)(l+r)/2-0.5;
        tmp=data(-mid,1);dfs(1);
        if(f[1][0].pos==k){
            printf("%lld\n",f[1][0].val+mid*k);
            return 0;
        }
        if(f[1][0].pos>k)l=mid+1;
        else r=mid;
    }
    mid=l;tmp=data(-mid,1);dfs(1);
    printf("%lld\n",f[1][0].val+mid*k);
    return 0;
}