P4253 [SCOI2015]小凸玩密室
Captain_Paul · · 题解
神奇的树形dp系列
先看一下题目给出的条件:
1、完全二叉树
2、任意时刻点亮的灯泡必须连通
3、点亮一个灯泡后必须先点亮其子树
那么点亮灯泡的过程就类似这样:
点亮灯泡k,点亮它的一个子树,再点亮它另外的子树,
然后回到k的父节点,点亮fa之后再点亮fa的其他子树……
所以对于一个节点u,有这样两种情况:
u还没有被点亮,则下一个被点亮的一定是它的儿子
u是叶子节点,在下一个被点亮的一定是它的某一级祖先,或者是它某一级祖先的儿子
我们定义dp数组f和g
f[i][j]表示点亮i之后回到i的第j个祖先的最小花费
g[i][j]表示点亮i之后回到i的第j个祖先的另一个儿子的最小花费
然后从下到上,由儿子的状态转移到父亲的状态
注意讨论当前节点的儿子个数。
最后统计答案时,根据点亮的过程累加即可。
ps:由于这是一棵完全二叉树,所以可以不用递归的方式dfs
直接预处理出每个节点的儿子和它到各级祖先的距离,用循环转移即可。
此题数据贼大,所以最终答案ans的初值一定要开大!
丑陋的代码:
#include<cstdio>
#include<cstring>
#include<cctype>
#include<algorithm>
#define reg register
using namespace std;
typedef long long ll;
const int N=2e5+5;
int n,num,w[N],fa[N],ls[N],rs[N];
ll f[N][20];//f表示i是亮的,回到i的第j个祖先的最小花费
ll g[N][20];//g表示i是亮的,回到i的第j个祖先的另一个儿子的最小花费
ll dis[N][20];//dis表示从i到i的第j个祖先的距离
ll ans=1e17;
inline int read()
{
int x=0,w=1;
char c=getchar();
while (!isdigit(c)&&c!='-') c=getchar();
if (c=='-') c=getchar(),w=-1;
while (isdigit(c))
{
x=(x<<1)+(x<<3)+c-'0';
c=getchar();
}
return x*w;
}
inline int brother(int k,int x)//k的第x个祖先的另一个儿子
{
return (k>>(x-1))^1;
}
inline ll getans()
{
for (reg int k=n;k>=1;k--)
{
if (!ls[k])//k为叶子节点
for (reg int i=1;k>>(i-1);i++)
g[k][i]=(dis[k][i]+dis[brother(k,i)][1])*w[brother(k,i)];
else if (!rs[k])//k只有左儿子
for (reg int i=1;k>>(i-1);i++)
g[k][i]=dis[ls[k]][1]*w[ls[k]]+g[ls[k]][i+1];
else//k有两个儿子
for (reg int i=1;k>>(i-1);i++)
g[k][i]=min(dis[ls[k]][1]*w[ls[k]]+g[ls[k]][1]+g[rs[k]][i+1],dis[rs[k]][1]*w[rs[k]]+g[rs[k]][1]+g[ls[k]][i+1]);
}
for (reg int k=n;k>=1;k--)
{
if(!ls[k])
for (reg int i=1;k>>(i-1);i++)
f[k][i]=dis[k][i]*w[k>>i];
else if (!rs[k])
for (reg int i=1;k>>(i-1);i++)
f[k][i]=f[ls[k]][i+1]+dis[ls[k]][1]*w[ls[k]];
else
for (reg int i=1;k>>(i-1);i++)
f[k][i]=min(dis[ls[k]][1]*w[ls[k]]+g[ls[k]][1]+f[rs[k]][i+1],dis[rs[k]][1]*w[rs[k]]+g[rs[k]][1]+f[ls[k]][i+1]);
}
for (reg int k=1;k<=n;k++)
{
reg ll sum=f[k][1];
for (reg int i=1,fa=k>>1;fa;++i,fa>>=1)
{
reg int bro=brother(k,i);
if (bro>n) sum+=dis[fa][1]*w[fa>>1];
else sum+=dis[bro][1]*w[bro]+f[bro][2];
}
ans=min(ans,sum);
}
return ans;
}
int main()
{
n=read();
for (reg int i=1;i<=n;w[i++]=read());
for (reg int i=2;i<=n;i++) dis[i][1]=(ll)read();
for (reg int i=1;i<=(n>>1)+1;i++)//完全二叉树
{
if ((i<<1)<=n) ls[i]=(i<<1);
else break;
if ((i<<1|1)<=n) rs[i]=(i<<1|1);
}
for (reg int i=2;i<=18;i++)
for (reg int k=n;k>>i;k--)
dis[k][i]=dis[k][i-1]+dis[k>>(i-1)][1];
printf("%lld\n",getans());
return 0;
}