[OI笔记]利用拉格朗日乘数法求函数的最值

s_r_f

2020-02-28 19:05:50

算法·理论

about

为什么写这篇Blog...

拉格朗日乘数法在今天训练的一道题上用到了,当场wyj/pcf/csl都正确的推出了式子.

但我却只会暴力DP.虽然也过了题但是多用了2k-3k的代码量.

但是赛后一看他们的1k左右的代码,人都傻了.

去网上搜了一下这种做法,自己推这题的时候偏导还求错了/kk,最后在pcf的提示下才发现/kk.

所以今天来补一补数学知识T _ T

正文

拉格朗日乘数法,

求一个多元函数z = F(x,y)在满足φ(x,y) = 0的条件极值.

($ 由于笔者比较菜$,$而且不等式约束在$OI$中基本不会用拉格朗日乘数法来解决 $($ 因为可以线性规划$,$但那就和本文的主题无关了 $)$ $,$所以本文中不讲不等式约束的解法$.$ $)

可以转化为,函数F(x,y,λ) = F(x,y) + λφ(x,y)的无条件极值.

无条件极值的求法?

F(x,y,λ)每个变量(x,y,λ)都求一次偏导,然后令求出来的式子=0即可.

至于正确性……

正确性 : λ求偏导就可以得到约束条件,然后λ的值不会影响函数值

具体的证明我不会,大概是隐函数什么的?

这里有两个链接以供参考: Link1 Link2

实际上这种方法只能求出极值,并不能知道求出来的是极小值还是极大值,只有算出来之后才会知道

理论的内容就这么多吧.

因为我不太会打\LaTeX的式子,所以就不放一些奇怪的很难打的式子,而用文字表述了.

批注:OI里的函数基本上都很平滑,所以乱搞也没关系的(

An Easy Example

求反比例函数y = \frac{1}{x}上距离原点最近的点.

即最小化F(x,y) = \sqrt{x^2 + y^2},满足约束φ(x,y) = xy-1 = 0

为了简化问题,我们把F(x,y) 变成 F^2(x,y),

这样问题要最小化的东西还是没变,而且F(x,y) = x^2+y^2方便求偏导.

F(x,y,λ) = F(x,y) + λφ(x,y)$ $ = x^2+y^2+λ(xy-1)

x求偏导得到: 2x + λy = 0

y求偏导得到: 2y + λx = 0

λ求偏导得到: xy - 1 = 0

最后解得 \begin{cases}x=1\\y=1\\λ=2\end{cases}\begin{cases}x=-1\\y=-1\\λ=2\end{cases}

A harder case: 一道OI

给你n(n \leq 8)个点,你可以任意放置这些点,但需要保证每个点离原点的距离为r_i.

求这些点组成的点集中,凸包的最大面积.

solution :

枚举凸包上的点集和点的排列方式.

现在我们令凸包上的点数为k,且这些点到原点的距离分别为r_1,r_2,...r_k

θ_i表示两条相邻的边r_ir_{i\mod k + 1}之间的夹角

由于三角形的面积S = \frac{1}{2}absin(θ),并且所有θ的和一定是2π,所以

答案即目标函数F = r_1r_2sin(θ_1)+r_2r_3sin(θ_2)+...+ r_kr_1sin(θ_k).

限制条件φ = 2π - \sum\limits_{i=1}^{k} θ_i.

F - λφ = \sum\limits_{i=1}^{k}r_ir_{i\mod k+1}sin(θ_i)$ $+$ $λ(2π -\sum\limits_{i=1}^{k}θ_i)

θ_i求偏导得: r_ir_{i\mod k+1}cos(θ_i)=λ

λ求偏导得: 2π -\sum\limits_{i=1}^{k}θ_i = 0

这个并不能让我们直接解出所有的θ_iλ.

注意到θ_i = arccos(\frac{λ}{r_ir_{i\mod k+1}}),所以在r_i已经确定的情况下, θ_i是随λ单调递增的.

所以我们可以二分λ并算出所有的θ_i,

进而求出\sum\limits_{i=1}^{k}θ_i=2π时候所有的θ_i的值,并计算答案.

那么这道题就做完了.

时间复杂度O(n!\times T),其中T为每次求最值时的二分次数.

Some Problem(s)

留作课后练习

[NOI2012] 骑行川藏

还是推式子

先留个坑,下次再补

```cpp #include <bits/stdc++.h> #define db long double using namespace std; const int N = 10050; int n; db E,s[N],k[N],v[N],x[N],ans = 1e6; inline bool check(db w){ db L,R,Mid,eps,sumE = 0,sum = 0; for (int i = 1; i <= n; ++i){ if (v[i]<0) L = 0; else L = v[i]; x[i] = -1,R = 1e18,eps = 1e-10; while (R - L >= eps){ Mid = (L+R)/2; if (k[i] * Mid * Mid * (Mid - v[i]) <= w) x[i] = Mid,L = Mid; else R = Mid; } sumE += s[i] * k[i] * (x[i] - v[i]) * (x[i] - v[i]); } if (sumE > E) return 0; for (int i = 1; i <= n; ++i) sum += s[i] / x[i]; if (sum < ans) ans = sum; return 1; } int main(){ int i; ios::sync_with_stdio(0); cin >> n >> E; for (i = 1; i <= n; ++i) cin >> s[i] >> k[i] >> v[i]; db L = 0,R = 1e18,Mid,eps = 1e-10; while (R - L >= eps){ Mid = (L+R)/2; if (check(Mid)) L = Mid; else R = Mid; } cout << fixed << setprecision(10) << ans << endl; return 0; } ``` ### [CF1344D Résumé Review](https://www.luogu.com.cn/problem/CF1344D) $F = \sum b_i(a_i-b_i^2) + λ(\sum b_i-k)

b_i求偏导得 λ=a_i-3b_i^2

然后考虑二分λ,但是因为有b_i\leq a_i 的限制,所以我们算出来的b_i可能不合法.

不过不合法的部分排序后显然是连续的一段,在外层再套一层二分即可.

然后求得了实数解,求整数解就是先令所有b_i=\lfloor b_i \rfloor 然后再排序贪心.

O(nlognlog(A_i)) code:
#include <bits/stdc++.h>
#define LL long long
#define LD long double
using namespace std;
template <typename T> void read(T &x){
    static char ch; x = 0,ch = getchar();
    while (!isdigit(ch)) ch = getchar();
    while (isdigit(ch)) x = x * 10 + ch - '0',ch = getchar();
}
inline void write(int x){if (x > 9) write(x/10); putchar(x%10+'0'); }
const int N = 100005;
LL n,k;
LL id[N],a[N]; LD b[N];
LL ans[N];
struct node{ int id; LL v; inline bool operator < (const node w) const{ return v < w.v; } }c[N]; 
int cnto;
inline int check(int p){
    LD L = -1e20,R = 1e20,Mid,kk = k,tot; int i; bool ok = 0;
    for (i = 1; i <= p; ++i) kk -= a[i];
    for (i = p+1; i <= n; ++i) L = max(L,(LD)a[i] - 3 * a[i] * a[i]),R = min(R,(LD)a[i]);
    while (R-L>1){
        Mid = (L+R)/2;
        tot = 0;
        for (i = p+1; i <= n; ++i) b[i] = sqrt((a[i]-Mid)/3),tot += b[i];
        if (tot >= kk) ok = 1,L = Mid; else R = Mid;
    }
    for (i = p+1; i <= n; ++i) b[i] = sqrt((a[i]-L)/3);
    return ok;
}
int main(){
    int i;
    read(n),read(k);
    for (i = 1; i <= n; ++i) read(c[i].v),c[i].id = i;
    sort(c+1,c+n+1);
    for (i = 1; i <= n; ++i) id[i] = c[i].id,a[i] = c[i].v;
    int L = 0,R = n,Mid,Ans = 0;
    while (L <= R){
        Mid = L+R>>1;
        for (i = 1; i <= Mid; ++i) b[i] = a[i];
        if (check(Mid)) Ans = Mid,R = Mid - 1;
        else L = Mid + 1;
    }
    for (i = 1; i <= Ans; ++i) b[i] = a[i];
    check(Ans);
    for (i = 1; i <= n; ++i) b[i] = floor(b[i]),k -= b[i];
    for (i = 1; i <= n; ++i) if (b[i] < a[i]){
        ++cnto;
        c[cnto].id = i,c[cnto].v = a[i] - 3 * b[i] * b[i] - 3 * b[i];
    }
    sort(c+1,c+cnto+1);
    reverse(c+1,c+cnto+1);
    for (i = 1; i <= k; ++i) b[c[i].id] += 1;
    for (i = 1; i <= n; ++i) ans[id[i]] = b[i];
    for (i = 1; i <= n; ++i) write(ans[i]),putchar(i<n?' ':'\n');
    return 0;
}