题解 P4383 【[八省联考2018]林克卡特树lct】
题意
给定一棵
看起来很不好做,那我们先考虑一下,对于60分的部分分如何处理。
我们可以这样设计状态:
令
可以分类讨论出如下转移:
首先,我们约定在每个节点的全部转移结束时,进行一次更新:
这样,我们就把
显然,如果当前节点的度数为
如果要求度数为
要求度数为
处理从子节点出发的链。
这样我们可以得到一个
考虑如何优化。我们注意到,每次增加一条链,得到的收益是单调不增的。
每次增加一条链有两种做法,一种是新选一条链,一种是将一条链分成两条。每种操作的收益是一定的,因此每次都会选择最大收益的操作。而由于边权不会改变,后续不可能有操作会有更大的收益。因此,我们可以口胡出这是一个上凸的函数。
记
我们令切点为
而如何获取斜率呢?再一次,由于函数上凸,我们可以发现当斜率不断增大时,对应的切点横坐标也在不断左移。这样,我们就可以二分求相应的斜率了。每次check的时候顺便记录一下最优解选取了多少个,从而判断应该如何调整二分区间。
还有一个问题,有可能出现凸壳上多个点共线的情况。为了处理这种情况,我们可以限定在dp过程中,权值相同时优先取选择次数更小的转移。这样,我们求出来的次数就是当前切点的左边界,二分时判断一下就可以了。
回到这题,套用60(45)分做法的dp方式,去掉有关次数的限制,每次dp的复杂度就变成
45分代码
#include<bits/stdc++.h>
#define reg register
typedef long long ll;
using namespace std;
const int MN=3e5+5;
int to[MN<<1],nxt[MN<<1],c[MN<<1],h[MN],cnt;
inline void ins(int s,int t,int w){
to[++cnt]=t;nxt[cnt]=h[s];c[cnt]=w;h[s]=cnt;
to[++cnt]=s;nxt[cnt]=h[t];c[cnt]=w;h[t]=cnt;
}
#define chkmax(a,b) ((a)<(b)?(a)=(b):0)
int n,K;
int f[MN][105][3];
void dfs(int st,int fa=0){
f[st][0][0]=f[st][0][1]=f[st][1][2]=0;
for(reg int i=h[st];i;i=nxt[i]){
if(to[i]==fa)continue;
dfs(to[i],st);
for(reg int j=K;j;j--){
chkmax(f[st][j][1],f[st][j][0]+f[to[i]][0][1]+c[i]);
for(reg int k=j-1;~k;k--){
chkmax(f[st][j][0],f[st][k][0]+f[to[i]][j-k][0]);
chkmax(f[st][j][1],max(f[st][k][1]+f[to[i]][j-k][0],f[st][k][0]+f[to[i]][j-k][1]+c[i]));
chkmax(f[st][j][2],max(f[st][k][2]+f[to[i]][j-k][0],f[st][k][1]+f[to[i]][j-k-1][1]+c[i]));
}
}
chkmax(f[st][0][1],f[to[i]][0][1]+c[i]);
}
for(reg int i=1;i<=K;i++)
chkmax(f[st][i][0],max(f[st][i-1][1],f[st][i][2]));
}
int main(){
scanf("%d%d",&n,&K);K++;
for(reg int i=1,s,t,v;i<n;i++)
scanf("%d%d%d",&s,&t,&v),ins(s,t,v);
memset(f,~0x3f,sizeof(f));dfs(1);
printf("%d\n",f[1][K][0]);
return 0;
}
100分代码
#include<bits/stdc++.h>
#define reg register
typedef long long ll;
using namespace std;
const int MN=3e5+5;
int to[MN<<1],nxt[MN<<1],c[MN<<1],h[MN],cnt;
inline void ins(int s,int t,int w){
to[++cnt]=t;nxt[cnt]=h[s];c[cnt]=w;h[s]=cnt;
to[++cnt]=s;nxt[cnt]=h[t];c[cnt]=w;h[t]=cnt;
}
#define chkmax(a,b) ((a)<(b)?(a)=(b):0)
int n,k;
ll l,r,mid;
struct data{
ll val;int pos;
data(ll x=0,int y=0):val(x),pos(y){}
friend bool operator<(data a,data b){
return a.val==b.val?a.pos>b.pos:a.val<b.val;
}
friend data operator+(data a,data b){
return data(a.val+b.val,a.pos+b.pos);
}
friend data operator+(data a,ll b){
return data(a.val+b,a.pos);
}
}f[MN][3],tmp;
int fa[MN];
void getf(int st){
for(reg int i=h[st];i;i=nxt[i])
if(to[i]!=fa[st])fa[to[i]]=st,getf(to[i]);
}
void dfs(int st){
f[st][0]=f[st][1]=f[st][2]=data();
chkmax(f[st][2],tmp);
for(reg int i=h[st];i;i=nxt[i]){
if(to[i]==fa[st])continue;dfs(to[i]);
chkmax(f[st][2],max(f[st][2]+f[to[i]][0],f[st][1]+f[to[i]][1]+c[i]+tmp));
chkmax(f[st][1],max(f[st][1]+f[to[i]][0],f[st][0]+f[to[i]][1]+c[i]));
chkmax(f[st][0],f[st][0]+f[to[i]][0]);
}
chkmax(f[st][0],max(f[st][1]+tmp,f[st][2]));
}
int main(){
scanf("%d%d",&n,&k);k++;
for(reg int i=1,s,t,w;i<n;i++)
scanf("%d%d%d",&s,&t,&w),ins(s,t,w);
l=-1e12;r=1e12;getf(1);
while(l<r){
mid=(double)(l+r)/2-0.5;
tmp=data(-mid,1);dfs(1);
if(f[1][0].pos==k){
printf("%lld\n",f[1][0].val+mid*k);
return 0;
}
if(f[1][0].pos>k)l=mid+1;
else r=mid;
}
mid=l;tmp=data(-mid,1);dfs(1);
printf("%lld\n",f[1][0].val+mid*k);
return 0;
}