有一种神奇的O(n)贪心做法,在CF上能过,然鹅是错的,随便卡卡就WA了。
以下算法比标算的O(n\sqrt n\log n)要优,证明来自YuHaoxiang。
有一个O(n^2)的dp:设f_{i,j}表示前i个数选j个数的最大值。
f_{i,j}=\max(f_{i-1,j},f_{i-1,j-1}+j\times a_i)
有一个结论:对于a_i,一定存在一个1\leqslant k\leqslant i使得任意0\leqslant j < k,有f_{i,j}=f_{i-1,j};任意k\leqslant j\leqslant i,有f_{i,j}=f_{i-1,j-1}+j\times a_i。
即存在一个分界点,使得左边的都不取a_i,右边的都取了a_i。
下面来证明这个结论。
当i=0,1时显然成立。
设F_{i,j}表示前i个数选j个数,最大方案选的下标集合,V(F_{i,j})表示取这个下标集合的值。
假设存在一个j,使得i\in F_{i,j}且i\notin F_{i,j+1}。
设F_{i-1,j}=F_{i-1,j-1}\cup\{x\},F_{i-1,j+1}=F_{i-1,j}\cup\{z\}=F_{i-1,j-1}\cup\{x,z\},F_{i,j}=F_{i-1,j-1}\cup\{i\}。
由假设得F_{i,j+1}=F_{i-1,j+1}=F_{i-1,j}\cup\{z\}。
令每种方案的值各不相同(相同可以通过扰动来避免),则f_{i-1,j}< f_{i,j}。
由此可得V(F_{i-1,j-1}\cup\{z\})< V(F_{i-1,j-1}\cup\{x\})。
- 若x<z,则V(F_{i-1,j-1}\cup\{z\})< V(F_{i-1,j-1}\cup\{x\})< V(F_{i-1,j-1}\cup\{i\})。由于x<z且x<i,所以V(F_{i-1,j-1}\cup\{x,z\})<V(F_{i-1,j-1}\cup\{x,i\})。
- 若x>z,则V(F_{i-1,j-1}\cup\{x\}))< V(F_{i-1,j-1}\cup\{i\})。由于z<x且z<i,所以V(F_{i-1,j-1}\cup\{z,x\})< V(F_{i-1,j-1}\cup\{z,i\})。又有V(F_{i-1,j-1}\cup\{z\})< V(F_{i-1,j-1}\cup\{x\}),由于z< i且x< i,所以V(F_{i-1,j-1}\cup\{z,i\})< V(F_{i-1,j-1}\cup\{x,i\})。所以V(F_{i-1,j-1}\cup\{z,x\})< V(F_{i-1,j-1}\cup\{x,i\})。
综上,V(F_{i-1,j-1}\cup\{z,x\})< V(F_{i-1,j-1}\cup\{x,i\})。
所以f_{i-1,j+1}< f_{i-1,j}+(j+1)\times a_i,即i\in F_{i,j+1},与假设矛盾。
由数学归纳法可知结论成立。
然后我们就可以二分出开始转移的位置,更新后面的值即可。
这样的时间复杂度还是$O(n^2)$的。
观察到每次转移,相当于在某处插入一个新值,然后给后面一部分加上一个等差数列。
所以用平衡树来实现即可。
时间复杂度$O(n\log^2 n)$。
## Code:
```cpp
#include<cstdio>
#include<algorithm>
typedef long long LL;
const int N=1e5+5;
int fa[N],ch[N][2],sz[N];LL A[N],B[N],val[N];
int n,a,rt,cnt;
inline int ckr(int x,const int p=1){return ch[fa[x]][p]==x;}
inline void update(int x){const int L=ch[x][0],R=ch[x][1];sz[x]=sz[L]+sz[R]+1;}
inline void modify(int x,LL a,LL b){val[x]+=a*(sz[ch[x][0]]+1)+b,A[x]+=a,B[x]+=b;}
inline void pushdown(int x){
if(A[x]||B[x]){
if(ch[x][0])modify(ch[x][0],A[x],B[x]);
if(ch[x][1])modify(ch[x][1],A[x],B[x]+A[x]*(sz[ch[x][0]]+1));
A[x]=B[x]=0;
}
}
inline void rotate(int x){
int y=fa[x],z=fa[y],k=ckr(x);
if(z)ch[z][ckr(y)]=x;
fa[x]=z,fa[y]=x,fa[ch[x][!k]]=y,ch[y][k]=ch[x][!k],ch[x][!k]=y,update(y);
}
inline void splay(int x){
static int sta[N],top;sta[top=1]=x;
for(int y=x;fa[y];sta[++top]=y=fa[y]);
while(top)pushdown(sta[top--]);
for(;fa[x];rotate(x))if(fa[fa[x]])rotate((ckr(x)^ckr(fa[x]))?x:fa[x]);update(rt=x);
}
inline LL getv(int k){
int x=rt;
while(1){
if(sz[ch[x][0]]+1==k){
splay(x);return val[x];
}
if(sz[ch[x][0]]>=k)x=ch[x][0];else k-=sz[ch[x][0]]+1,x=ch[x][1];
}
}
LL chkmax(int x){
if(!x)return -1e18;
pushdown(x);
return std::max(val[x],std::max(chkmax(ch[x][0]),chkmax(ch[x][1])));
}
int main(){
scanf("%d",&n);
sz[1]=rt=cnt=1;
for(int i=1,a;i<=n;++i){
scanf("%d",&a);
int l=0,r=i-2,ans=i-1;
while(l<=r){
const int mid=l+r>>1;
if(getv(mid+1)+(mid+1LL)*a>getv(mid+2))r=(ans=mid)-1;else l=mid+1;
}
getv(ans+1);
fa[++cnt]=rt,fa[ch[rt][1]]=cnt,ch[cnt][1]=ch[rt][1],ch[rt][1]=cnt,val[cnt]=val[rt];
modify(cnt,a,(LL)a*ans);
}
printf("%lld\n",chkmax(rt));
return 0;
}
```