用Python玩转强化学习:手把手教你用Numpy和Matplotlib求解赌徒问题(附完整代码)

发布时间:2026/5/30 3:33:04

用Python玩转强化学习:手把手教你用Numpy和Matplotlib求解赌徒问题(附完整代码) 用Python玩转强化学习手把手教你用Numpy和Matplotlib求解赌徒问题附完整代码强化学习作为机器学习的重要分支其核心思想是通过与环境的交互来学习最优策略。马尔科夫决策过程MDP是强化学习的理论基础而赌徒问题则是一个经典的MDP示例。本文将带你从零开始用Python实现策略迭代和值迭代算法并通过可视化手段直观展示学习过程。1. 环境准备与问题建模在开始编码之前我们需要明确赌徒问题的具体规则和MDP建模方式。假设一个赌徒初始有一定金额的赌资每次可以选择下注1到min(s,100-s)的金额s为当前金额。硬币正面朝上的概率为ph此时赌徒赢得下注金额反面朝上则输掉下注金额。游戏在赌徒达到100美元或破产时结束。首先安装必要的Python库pip install numpy matplotlib seaborn然后导入所需的库import numpy as np import matplotlib.pyplot as plt import seaborn as sns定义赌徒问题的MDP参数GOAL 100 # 目标金额 PH 0.4 # 硬币正面朝上的概率 THETA 1e-9 # 收敛阈值 GAMMA 1 # 折扣因子2. 策略迭代实现策略迭代算法由两个交替进行的步骤组成策略评估和策略改进。下面我们逐步实现这两个步骤。2.1 策略评估策略评估的目的是计算当前策略下的状态值函数。我们使用迭代法来逼近真实值函数def policy_evaluation(policy, state_values, phPH, gammaGAMMA, thetaTHETA): while True: delta 0 for s in range(1, GOAL): a policy[s] # 可能的下一状态和奖励 s_win s a s_lose s - a # 计算期望回报 new_value ph * (gamma * state_values[s_win]) \ (1 - ph) * (gamma * state_values[s_lose]) delta max(delta, abs(state_values[s] - new_value)) state_values[s] new_value if delta theta: break return state_values2.2 策略改进在策略改进步骤中我们基于当前值函数寻找更优策略def policy_improvement(policy, state_values, phPH, gammaGAMMA): policy_stable True for s in range(1, GOAL): old_action policy[s] possible_actions range(1, min(s, GOAL - s) 1) # 计算每个动作的期望回报 action_returns [] for a in possible_actions: s_win s a s_lose s - a ret ph * (gamma * state_values[s_win]) \ (1 - ph) * (gamma * state_values[s_lose]) action_returns.append(ret) # 选择最优动作 best_action possible_actions[np.argmax(action_returns)] policy[s] best_action if old_action ! best_action: policy_stable False return policy, policy_stable2.3 完整策略迭代流程将上述两个步骤组合起来形成完整的策略迭代算法def policy_iteration(phPH, gammaGAMMA, thetaTHETA): # 初始化 state_values np.zeros(GOAL 1) state_values[GOAL] 1.0 # 达到目标时的奖励 policy np.ones(GOAL 1, dtypeint) # 初始策略总是下注1 iteration 0 while True: iteration 1 # 策略评估 state_values policy_evaluation(policy, state_values, ph, gamma, theta) # 策略改进 policy, policy_stable policy_improvement(policy, state_values, ph, gamma) if policy_stable: break return policy, state_values, iteration3. 值迭代实现值迭代算法将策略评估和策略改进合并为一个步骤直接寻找最优值函数。3.1 值迭代核心算法def value_iteration(phPH, gammaGAMMA, thetaTHETA): state_values np.zeros(GOAL 1) state_values[GOAL] 1.0 # 达到目标时的奖励 policy np.zeros(GOAL 1, dtypeint) iteration 0 while True: iteration 1 delta 0 for s in range(1, GOAL): old_value state_values[s] possible_actions range(1, min(s, GOAL - s) 1) # 计算每个动作的期望回报 action_returns [] for a in possible_actions: s_win s a s_lose s - a ret ph * (gamma * state_values[s_win]) \ (1 - ph) * (gamma * state_values[s_lose]) action_returns.append(ret) # 更新值函数 new_value max(action_returns) delta max(delta, abs(old_value - new_value)) state_values[s] new_value if delta theta: break # 提取最优策略 for s in range(1, GOAL): possible_actions range(1, min(s, GOAL - s) 1) action_returns [] for a in possible_actions: s_win s a s_lose s - a ret ph * (gamma * state_values[s_win]) \ (1 - ph) * (gamma * state_values[s_lose]) action_returns.append(ret) best_action possible_actions[np.argmax(action_returns)] policy[s] best_action return policy, state_values, iteration4. 结果可视化与分析4.1 值函数可视化我们可以使用Matplotlib绘制两种算法得到的值函数def plot_value_functions(pi_values, vi_values): plt.figure(figsize(12, 6)) plt.plot(pi_values, labelPolicy Iteration) plt.plot(vi_values, labelValue Iteration) plt.xlabel(Capital) plt.ylabel(Value) plt.title(Value Functions Comparison) plt.legend() plt.grid(True) plt.show()4.2 最优策略可视化最优策略的展示可以帮助我们理解在不同资本下应该采取的最佳行动def plot_policies(pi_policy, vi_policy): plt.figure(figsize(12, 6)) plt.plot(pi_policy, labelPolicy Iteration) plt.plot(vi_policy, labelValue Iteration) plt.xlabel(Capital) plt.ylabel(Optimal Stake) plt.title(Optimal Policies Comparison) plt.legend() plt.grid(True) plt.show()4.3 运行与结果分析让我们运行两种算法并比较结果# 运行策略迭代 pi_policy, pi_values, pi_iter policy_iteration() # 运行值迭代 vi_policy, vi_values, vi_iter value_iteration() # 可视化结果 plot_value_functions(pi_values, vi_values) plot_policies(pi_policy, vi_policy) print(fPolicy Iteration converged in {pi_iter} iterations) print(fValue Iteration converged in {vi_iter} iterations)从结果中可以观察到几个有趣的现象当资本较少时最优策略倾向于激进的下注方式随着资本增加策略会变得更加保守两种算法得到的值函数非常接近但最优策略在某些资本区间存在差异值迭代通常比策略迭代收敛得更快5. 参数影响与扩展实验5.1 硬币概率的影响硬币正面朝上的概率ph对最优策略有显著影响。我们可以比较不同ph值下的策略def compare_probabilities(): ph_values [0.4, 0.45, 0.5, 0.55] plt.figure(figsize(12, 8)) for ph in ph_values: _, policy, _ value_iteration(phph) plt.plot(policy, labelfph{ph}) plt.xlabel(Capital) plt.ylabel(Optimal Stake) plt.title(Optimal Policy for Different Probabilities) plt.legend() plt.grid(True) plt.show() compare_probabilities()5.2 收敛阈值的影响收敛阈值theta决定了算法的精度和运行时间def compare_thetas(): theta_values [1e-3, 1e-6, 1e-9, 1e-12] iterations [] for theta in theta_values: _, _, iter_count value_iteration(thetatheta) iterations.append(iter_count) plt.figure(figsize(10, 5)) plt.plot(theta_values, iterations, o-) plt.xscale(log) plt.xlabel(Theta (log scale)) plt.ylabel(Iterations to Converge) plt.title(Convergence Speed vs. Precision) plt.grid(True) plt.show() compare_thetas()5.3 策略迭代与值迭代的比较我们可以从多个维度比较两种算法特性策略迭代值迭代收敛速度通常较慢通常较快每次迭代复杂度较高较低中间结果包含完整策略只有值函数适用场景策略变化不大的问题通用实现难度较复杂较简单在实际项目中值迭代通常更受欢迎因为它实现简单且收敛速度快。但在某些策略变化缓慢的问题中策略迭代可能更高效。6. 实用技巧与常见问题6.1 调试技巧检查值函数更新在每次迭代后打印几个关键状态的值确保它们按预期变化验证边界条件特别注意资本为0和目标金额时的处理可视化中间结果绘制每次迭代后的值函数观察收敛过程def debug_value_iteration(): state_values np.zeros(GOAL 1) state_values[GOAL] 1.0 history [] for i in range(10): # 只运行10次迭代用于调试 delta 0 for s in range(1, GOAL): old_value state_values[s] possible_actions range(1, min(s, GOAL - s) 1) action_returns [] for a in possible_actions: s_win s a s_lose s - a ret PH * (GAMMA * state_values[s_win]) \ (1 - PH) * (GAMMA * state_values[s_lose]) action_returns.append(ret) new_value max(action_returns) delta max(delta, abs(old_value - new_value)) state_values[s] new_value history.append(state_values.copy()) print(fIteration {i1}, delta: {delta}) # 绘制前几次迭代的值函数变化 plt.figure(figsize(12, 6)) for i, values in enumerate(history[:5]): plt.plot(values, labelfIter {i1}) plt.xlabel(Capital) plt.ylabel(Value) plt.title(Value Function Evolution (First 5 Iterations)) plt.legend() plt.grid(True) plt.show() # debug_value_iteration()6.2 性能优化向量化操作利用NumPy的向量化运算替代循环提前终止当delta小于theta时立即终止迭代并行计算对于大型MDP可以考虑并行化状态更新def vectorized_value_iteration(phPH, gammaGAMMA, thetaTHETA): state_values np.zeros(GOAL 1) state_values[GOAL] 1.0 policy np.zeros(GOAL 1, dtypeint) while True: old_values state_values.copy() for s in range(1, GOAL): max_a min(s, GOAL - s) # 向量化计算所有可能动作的回报 actions np.arange(1, max_a 1) s_win s actions s_lose s - actions returns ph * gamma * state_values[s_win] \ (1 - ph) * gamma * state_values[s_lose] state_values[s] np.max(returns) if np.max(np.abs(state_values - old_values)) theta: break # 向量化策略提取 for s in range(1, GOAL): max_a min(s, GOAL - s) actions np.arange(1, max_a 1) s_win s actions s_lose s - actions returns ph * gamma * state_values[s_win] \ (1 - ph) * gamma * state_values[s_lose] policy[s] actions[np.argmax(returns)] return policy, state_values6.3 常见问题解答Q: 为什么我的值函数不收敛A: 检查折扣因子gamma是否设置合理确保theta值不过小验证奖励函数是否正确。Q: 最优策略看起来不合理怎么办A: 确认硬币概率ph设置正确检查动作空间定义是否准确特别是边界条件。Q: 如何处理更大的状态空间A: 考虑使用函数逼近方法替代表格法或采用分层强化学习技术。Q: 两种算法结果不一致正常吗A: 在赌徒问题中由于存在多个等价最优策略结果可能有差异但值函数应该相近。7. 实际应用与扩展思考虽然我们以赌徒问题为例但这些技术可以应用于许多实际场景投资组合优化确定在不同市场条件下的最佳投资比例库存管理决定最优的库存补充策略机器人路径规划在不确定环境中寻找最优路径扩展思考方向非确定性转移概率的处理部分可观察MDP(POMDP)的扩展使用深度学习处理连续状态和动作空间多智能体强化学习中的策略交互# 示例将赌徒问题扩展到可变下注比例 def fractional_betting(ph0.5, gamma1, theta1e-6): state_values np.zeros(GOAL 1) state_values[GOAL] 1.0 policy np.zeros(GOAL 1) # 现在存储下注比例 while True: delta 0 for s in range(1, GOAL): old_value state_values[s] # 考虑下注比例为0%到100% fractions np.linspace(0, 1, 101) returns [] for f in fractions: stake int(f * min(s, GOAL - s)) if stake 0: returns.append(state_values[s]) continue s_win s stake s_lose s - stake ret ph * gamma * state_values[s_win] \ (1 - ph) * gamma * state_values[s_lose] returns.append(ret) new_value max(returns) delta max(delta, abs(old_value - new_value)) state_values[s] new_value policy[s] fractions[np.argmax(returns)] if delta theta: break return policy, state_values # frac_policy, frac_values fractional_betting() # plt.plot(frac_policy[1:GOAL]) # plt.title(Optimal Betting Fraction) # plt.show()在实现这些扩展时你会发现强化学习的强大之处在于其框架的通用性。同样的算法结构只需调整状态、动作和奖励的定义就能应用于完全不同的问题领域。

相关新闻