题解 CF868F 【Yet Another Minimization Problem】
zhongyuwei · · 题解
我们设
而这个
如果
否则,这两个式子里面至少有一个取了小于符号。不妨假设是这样的:
移项可以得到:
也就是
接下来,我们可以利用决策单调性,将每一次转移的枚举量尽量减少。如果最先计算整个序列中间的那个点的转移,那么计算左右两半边的枚举量都可以减半。这样分治下去,每一层计算的复杂度是与整个序列的长度线性相关的。尽管对于
//这里l,r表示要计算[l,r]的dp值,[lb,rb]是决策点可能存在的区间
void solve(int lb,int rb,int l,int r)
{
if(lb>rb||l>r) return; int mid=l+r>>1;
int d=0; ll res=1e18;
for(int i=lb;i<=rb;++i)
{
ll tmp=cal(i+1,mid);
if(res>dp[cur-1][i]+tmp) res=dp[cur-1][i]+tmp,d=i;
}
dp[cur][mid]=res;
solve(lb,d,l,mid-1),solve(d,rb,mid+1,r);
}
还有一个问题需要解决:如何计算
int buc[N],L,R,a[N];
ll ans;
void update(int c,int d){ans+=d*buc[c]*(ll)(buc[c]-1)/2;}
ll cal(int l,int r)
{
while(L<l) update(a[L],-1),buc[a[L]]--,update(a[L],1),L++;
while(L>l) L--,update(a[L],-1),buc[a[L]]++,update(a[L],1);
while(R<r) R++,update(a[R],-1),buc[a[R]]++,update(a[R],1);
while(R>r) update(a[R],-1),buc[a[R]]--,update(a[R],1),R--;
return ans;
}
实际上,左右端点的总移动距离是
至此,这道题就在
完整代码:
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#define ll long long
using namespace std;
template <class T>
inline void read(T &x)
{
x=0; char c=getchar(); int f=1;
while(!isdigit(c)){if(c=='-')f=-1; c=getchar();}
while(isdigit(c)) x=x*10-'0'+c,c=getchar(); x*=f;
}
const int N=1e5+10;
int buc[N],L,R,a[N];
ll ans;
void update(int c,int d){ans+=d*buc[c]*(ll)(buc[c]-1)/2;}
ll cal(int l,int r)
{
while(L<l) update(a[L],-1),buc[a[L]]--,update(a[L],1),L++;
while(L>l) L--,update(a[L],-1),buc[a[L]]++,update(a[L],1);
while(R<r) R++,update(a[R],-1),buc[a[R]]++,update(a[R],1);
while(R>r) update(a[R],-1),buc[a[R]]--,update(a[R],1),R--;
return ans;
}
ll dp[22][N]; int cur;
void solve(int lb,int rb,int l,int r)
{
if(lb>rb||l>r) return; int mid=l+r>>1;
int d=0; ll res=1e18;
for(int i=lb;i<=rb;++i)
{
ll tmp=cal(i+1,mid);
if(res>dp[cur-1][i]+tmp) res=dp[cur-1][i]+tmp,d=i;
}
dp[cur][mid]=res;
solve(lb,d,l,mid-1),solve(d,rb,mid+1,r);
}
int main()
{
memset(dp,0x3f,sizeof(dp)); dp[0][0]=0;
int n,m; read(n),read(m);
for(int i=1;i<=n;++i) read(a[i]); buc[a[1]]++,L=R=1;
for(cur=1;cur<=m;++cur) solve(0,n-1,1,n);
printf("%lld",dp[m][n]);
return 0;
}