题解 P4331 【[BOI2004]Sequence 数字序列】

· · 题解

标签:贪心+数据结构

求不下降序列$b$使得$\sum_{i=1}^n|a_i-b_i|$最小 我们可以感性的得到两个结论: ------------- 若: $$1.a_1\le a_2\le a_3...\le a_n $$ 则 $b_i=a_i 2.a_1\ge a_2\ge a_3...\ge a_n

b_i=a的中位数

根据这两个结论,我们可以得到一个比较感性的做法:

将原序列分成单调不升的m段,则每一段我们都取中位数。

然而这个做法不完全正确,比如我们取完中位数后是这样的序列:

3~3~3~2~2

我们发现这仍然不单调不降,所以我们需要合并这两个区间并重新取个中位数。

也就是说我们需要合并区间以求中位数。

Step2: 如何求中位数?

我们考虑维护一个堆,将整个区间的元素都存入堆内,每次从里面弹出最大元素,直到堆内只剩一半的元素为止

你可能会认为这个做法的正确性有问题,在我们合并区间的时候,新中位数是不是已经被删除了?

比如合并的区间是这样的:4~5~6~| ~7~8

本来的中位数是5,我们删除了6,但合并后中位数变成了6,在上述流程中是不能做到的,也就是错了?吗?

并不是,我们发现合并区间的前提条件是某个区间的中位数比它后面的区间中位数大,所以我们能保证,中位数一定大于前面这个区间的中位数而小于后面区间的中位数,换而言之,这样合并是正确的。

现在我们考虑实现,我们发现我们需要兹娃的操作只有弹出堆内最大值和合并堆,用左偏树可以做到O(n\log n)

End: 那么原题怎么做呢?

假设我们求出了答案,那么给每对a_i,b_i减去一个i答案并不会受到影响,而此时,我们只需要b数组单调不降即可,转换成了刚刚解决的问题。

代码:

#include<bits/stdc++.h>
using namespace std;
#define rep( i, s, t ) for( register int i = s; i <= t; ++ i )
#define re register
#define ls(x) ch[0][x]
#define rs(x) ch[1][x]
#define int long long
int read() {
    char cc = getchar(); int cn = 0, flus = 1;
    while(cc < '0' || cc > '9') {  if( cc == '-' ) flus = -flus;  cc = getchar();  }
    while(cc >= '0' && cc <= '9')  cn = cn * 10 + cc - '0', cc = getchar();
    return cn * flus;
}
const int N = 1e6 + 5 ; 
int n, top, Ans ;  
int a[N], b[N], ch[2][N], dis[N] ;
struct node {
    int rt, l, r, sz, val ; 
} s[N];
int merge( int x, int y ) {
    if( !x || !y ) return x + y ; 
    if( a[x] < a[y] ) swap( x, y ) ; 
    rs(x) = merge( rs(x), y ) ; 
    if( dis[rs(x)] > dis[ls(x)] ) swap( rs(x), ls(x) ) ;
    dis[x] = dis[rs(x)] + 1 ; return x ; 
}
int Del( int x ) { return merge( ls(x), rs(x) ) ; }
signed main()
{
    n = read() ; dis[0] = -1 ; 
    rep( i, 1, n ) a[i] = read() - i;
    rep( i, 1, n ) {
        s[++ top] = (node){ i, i, i, 1, a[i] } ; 
        while( top != 1 && s[top - 1].val > s[top].val ) {
            -- top ; s[top].rt = merge( s[top].rt, s[top + 1].rt ) ; 
            s[top].sz += s[top + 1].sz ; s[top].r = s[top + 1].r ; 
            while( s[top].sz > ( s[top].r - s[top].l + 2 ) / 2 ) { //此处可行性
            //合并的前提是左平均数小于右平均数,把左全部丢入右,平均数变大 
                -- s[top].sz, s[top].rt = Del( s[top].rt ) ; 
            }
            s[top].val = a[s[top].rt] ; 
        }
    }
    int cnt = 1 ;
    rep( i, 1, n ) {
        if( s[cnt].r < i ) ++ cnt ; 
        Ans += abs( s[cnt].val - a[i] ) ; 
    }
    printf("%lld\n", Ans ) ; cnt = 1 ; 
    rep( i, 1, n ) {
        if( s[cnt].r < i ) ++ cnt ;  
        printf("%lld ", s[cnt].val + i ) ; 
    }
    return 0;
}