别光背公式了!用Python和NumPy动手验证Jensen不等式(附代码)

发布时间:2026/6/7 1:47:00

别光背公式了!用Python和NumPy动手验证Jensen不等式(附代码) 别光背公式了用Python和NumPy动手验证Jensen不等式附代码数学公式如果只停留在纸面上往往会让人感到抽象难懂。Jensen不等式作为机器学习中频繁出现的重要数学工具很多同学在学习交叉熵、KL散度时都会遇到它但真正理解其内涵的人却不多。今天我们就用Python和NumPy通过代码实验和可视化带你直观感受这个不等式的威力。1. 准备工作理解Jensen不等式的核心Jensen不等式描述的是凸函数的一个基本性质对于凸函数f和任意一组点x₁,x₂,...,xₙ以及满足∑λᵢ1的非负权重λᵢ有f(∑λᵢxᵢ) ≤ ∑λᵢf(xᵢ)这个看似简单的式子在实际应用中却有着深远的意义。比如在机器学习中交叉熵损失函数的凸性保证了优化过程的稳定性EM算法中下界的构造依赖于Jensen不等式信息论中许多重要结论都建立在这个不等式基础上我们先来设置Python环境import numpy as np import matplotlib.pyplot as plt from scipy.stats import norm2. 验证基础案例二次函数让我们从一个最简单的凸函数开始f(x) x²。这是一个典型的凸函数非常适合用来验证Jensen不等式。实验设计随机生成一组x值随机生成对应的权重λ计算不等式两边的值比较结果def quadratic(x): return x**2 # 生成随机数据 np.random.seed(42) x_values np.random.uniform(-5, 5, 10) weights np.random.dirichlet(np.ones(10)) # 计算Jensen不等式两边 left_side quadratic(np.sum(weights * x_values)) right_side np.sum(weights * quadratic(x_values)) print(ff(∑λx): {left_side:.4f}) print(f∑λf(x): {right_side:.4f}) print(f验证结果{left_side right_side})运行结果示例f(∑λx): 1.2345 ∑λf(x): 3.4567 验证结果True为了更直观理解我们可以绘制函数图像和验证点x_range np.linspace(-5, 5, 100) plt.plot(x_range, quadratic(x_range), labelf(x)x²) plt.scatter(x_values, quadratic(x_values), colorred, label数据点) plt.scatter(np.sum(weights*x_values), left_side, colorgreen, labelf(∑λx), s100) plt.legend() plt.title(Jensen不等式验证二次函数) plt.show()3. 扩展到常见函数形式Jensen不等式不仅适用于简单的二次函数对于机器学习中常见的函数形式也同样适用。我们来看几个典型例子。3.1 指数函数指数函数f(x)eˣ是凸函数在概率模型中经常出现。def exponential(x): return np.exp(x) # 使用之前的数据验证 exp_left exponential(np.sum(weights * x_values)) exp_right np.sum(weights * exponential(x_values)) print(f指数函数验证{exp_left exp_right})3.2 对数函数对数函数在(0,∞)上是凹函数因此不等式方向会反转。def logarithmic(x): return np.log(x) # 生成正数数据 pos_x np.random.uniform(0.1, 10, 10) log_left logarithmic(np.sum(weights * pos_x)) log_right np.sum(weights * logarithmic(pos_x)) print(f对数函数验证{log_left log_right}) # 注意方向3.3 概率分布中的应用在概率论中Jensen不等式表现为E[f(X)] ≥ f(E[X])对于凸函数f。我们可以用正态分布来验证mu, sigma 0, 1 samples norm.rvs(mu, sigma, size1000) exp_samples np.exp(samples) print(fE[e^X]: {np.mean(exp_samples):.4f}) print(fe^E[X]: {np.exp(np.mean(samples)):.4f}) print(f验证{np.mean(exp_samples) np.exp(np.mean(samples))})4. 机器学习中的实际应用理解了Jensen不等式的基本形式后我们来看它在机器学习中的两个典型应用场景。4.1 交叉熵损失函数交叉熵损失H(p,q)-∑p(x)logq(x)的凸性保证了优化过程的稳定性。我们可以验证对于任意概率分布p和qdef cross_entropy(p, q): return -np.sum(p * np.log(q)) # 生成两个概率分布 p np.random.dirichlet(np.ones(5)) q np.random.dirichlet(np.ones(5)) # 验证凸性 lambda_ 0.3 q_mix lambda_*q (1-lambda_)*p ce_mix cross_entropy(p, q_mix) ce_avg lambda_*cross_entropy(p,q) (1-lambda_)*cross_entropy(p,p) print(f混合分布交叉熵{ce_mix:.4f}) print(f交叉熵的加权平均{ce_avg:.4f}) print(f验证{ce_mix ce_avg})4.2 KL散度的非负性KL散度KL(p||q)∑p(x)log(p(x)/q(x))的非负性也可以通过Jensen不等式证明def kl_divergence(p, q): return np.sum(p * np.log(p/q)) kl kl_divergence(p, q) print(fKL散度值{kl:.4f}) # 总是非负5. 可视化理解与进阶思考为了更深入理解Jensen不等式我们可以从几何角度进行可视化分析。5.1 弦与函数的比较对于凸函数任意两点间的弦总是位于函数图像上方def plot_chord(f, a, b): x np.linspace(a, b, 100) y f(x) # 弦的函数 def chord(x): return f(a) (f(b)-f(a))/(b-a) * (x-a) plt.plot(x, y, label函数曲线) plt.plot(x, chord(x), label弦) plt.scatter([a, b], [f(a), f(b)], colorred) plt.legend() plt.title(凸函数的弦总是在函数上方) plot_chord(quadratic, -2, 3) plt.show()5.2 多点的凸组合对于多个点的情况我们可以观察它们的凸组合如何满足不等式def plot_multi_points(f, points, weights): x np.linspace(min(points)-1, max(points)1, 100) plt.plot(x, f(x)) # 绘制数据点 plt.scatter(points, f(points), colorred) # 计算凸组合 convex_comb np.sum(weights * points) plt.scatter(convex_comb, f(convex_comb), colorgreen, s100) # 绘制加权平均点 weighted_avg np.sum(weights * f(points)) plt.scatter(convex_comb, weighted_avg, colorblue, s100) plt.legend([函数曲线, 数据点, f(∑λx), ∑λf(x)]) plot_multi_points(quadratic, np.array([-3, -1, 2, 4]), np.array([0.1, 0.4, 0.3, 0.2])) plt.show()6. 常见误区与验证技巧在实际应用中有几个常见的误区需要注意函数凸性的判断验证前务必确认函数的凸性。例如f(x)x³在全体实数上不是凸函数f(x)|x|是凸函数但不可导def cubic(x): return x**3 # 在负数区间不满足凸性 neg_x np.random.uniform(-5, 0, 10) cubic_left cubic(np.sum(weights * neg_x)) cubic_right np.sum(weights * cubic(neg_x)) print(f立方函数验证负数区间{cubic_left cubic_right}) # 可能不成立权重的要求权重λ必须满足∑λ1且λ≥0。如果违反这个条件不等式可能不成立invalid_weights np.random.uniform(0, 1, 10) # 未归一化 try: invalid_left quadratic(np.sum(invalid_weights * x_values)) invalid_right np.sum(invalid_weights * quadratic(x_values)) print(f无效权重验证{invalid_left invalid_right}) # 可能不成立 except: print(权重不符合要求)数值稳定性在实际计算中特别是涉及指数和对数运算时需要注意数值稳定性问题def stable_log_sum(x): max_val np.max(x) return max_val np.log(np.sum(np.exp(x - max_val))) large_x np.array([1000, 1001, 1002]) print(f直接计算{np.log(np.sum(np.exp(large_x)))}) # 可能溢出 print(f稳定计算{stable_log_sum(large_x)}) # 数值稳定

相关新闻