Python实战:用SARSA算法训练AI玩迷宫游戏(附完整代码)

发布时间:2026/6/7 11:09:53

Python实战:用SARSA算法训练AI玩迷宫游戏(附完整代码)

最近在GitHub上看到一个有趣的迷宫游戏项目,让我想起刚接触强化学习时用SARSA算法训练AI玩迷宫的经历。那种看着智能体从随机碰撞到最终找到出口的成就感,至今难忘。本文将带你从零实现这个经典案例,不仅包含可直接运行的代码,还会分享调试过程中的实用技巧。

## 1. 环境搭建与问题定义

我们先来构建迷宫游戏环境。这个5x5的网格世界中,`S`代表起点,`G`代表目标,`#`是墙壁,`.`是可通行区域:

```python
import numpy as np

class MazeEnv:
    def __init__(self):
        self.grid = np.array([
            ['S', '.', '.', '.', '.'],
            ['.', '#', '.', '#', '.'],
            ['.', '#', '.', '#', '.'],
            ['.', '#', '.', '#', 'G'],
            ['.', '.', '.', '.', '.']
        ])
        self.state = (0, 0)
        self.actions = ['up', 'down', 'left', 'right']

    def reset(self):
        self.state = (0, 0)
        return self.state

    def step(self, action):
        x, y = self.state
        if action == 'up':
            x = max(0, x-1)
        elif action == 'down':
            x = min(4, x+1)
        elif action == 'left':
            y = max(0, y-1)
        elif action == 'right':
            y = min(4, y+1)

        if self.grid[x][y] == '#':  # 撞墙
            return self.state, -1, False
        elif self.grid[x][y] == 'G':  # 到达终点
            return (x, y), 10, True
        else:
            self.state = (x, y)
            return self.state, -0.1, False
```

注意:奖励设计是强化学习的关键。这里采用-1惩罚撞墙,-0.1鼓励快速到达终点,+10奖励成功。

## 2. SARSA算法核心实现

让我们分解SARSA算法的实现步骤:

- 初始化Q表:状态-动作对的预期回报
- ε-greedy策略:平衡探索与利用
- SARSA更新规则:$Q(s,a) \leftarrow Q(s,a) + \alpha\,[\,r + \gamma\,Q(s',a') - Q(s,a)\,]$

完整实现如下:

```python
class SARSAgent:
    def __init__(self, env, alpha=0.1, gamma=0.9, epsilon=0.1):
        self.env = env
        self.alpha = alpha      # 学习率
        self.gamma = gamma      # 折扣因子
        self.epsilon = epsilon  # 探索率
        self.Q = {}
        self._init_q_table()

    def _init_q_table(self):
        for i in range(5):
            for j in range(5):
                for a in self.env.actions:
                    self.Q[(i,j), a] = 0  # 初始化为0

    def choose_action(self, state):
        if np.random.random() < self.epsilon:  # 探索
            return np.random.choice(self.env.actions)
        else:  # 利用
            q_values = [self.Q[state, a] for a in self.env.actions]
            return self.env.actions[np.argmax(q_values)]

    def learn(self, state, action, reward, next_state, next_action):
        current_q = self.Q[state, action]
        next_q = self.Q[next_state, next_action]
        self.Q[state, action] = current_q + self.alpha * (
            reward + self.gamma * next_q - current_q
        )

    def train(self, episodes=1000):
        rewards = []
        for ep in range(episodes):
            state = self.env.reset()
            action = self.choose_action(state)
            total_reward = 0
            done = False

            while not done:
                next_state, reward, done = self.env.step(action)
                next_action = self.choose_action(next_state)
                self.learn(state, action, reward, next_state, next_action)
                state, action = next_state, next_action
                total_reward += reward

            rewards.append(total_reward)
            if ep % 100 == 0:
                print(f"Episode {ep}, Reward: {total_reward}")
        return rewards
```

## 3. 训练过程与可视化

训练1000次后,我们可以观察学习曲线:

```python
import matplotlib.pyplot as plt

env = MazeEnv()
agent = SARSAgent(env)
rewards = agent.train()

plt.plot(rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('SARSA Learning Curve')
plt.show()
```

典型训练过程会经历三个阶段:

- 随机探索期(0-200轮):智能体频繁撞墙,奖励波动大
- 策略形成期(200-600轮):开始找到可行路径,奖励稳步上升
- 稳定优化期(600轮后):微调路径,奖励趋于稳定

为了直观展示学习成果,可以可视化最终策略:

```python
def visualize_policy(agent):
    arrows = {'up': '↑', 'down': '↓', 'left': '←', 'right': '→'}
    for i in range(5):
        row = []
        for j in range(5):
            if env.grid[i][j] in ['#', 'S', 'G']:
                row.append(env.grid[i][j])
            else:
                action = agent.choose_action((i,j))
                row.append(arrows[action])
        print(' '.join(row))
```

输出示例:

```
S → → → →
↓ # → # →
↓ # → # →
↓ # → # G
↓ ← ← ← ←
```

## 4. 关键参数调优指南

SARSA算法性能受三个核心参数影响:

| 参数 | 典型范围 | 影响 | 调整建议 |
| --- | --- | --- | --- |
| α (alpha) | 0.01-0.5 | 学习速度 vs 稳定性 | 从0.1开始,观察收敛情况 |
| γ (gamma) | 0.8-0.99 | 未来奖励的重要性 | 长期任务选高值(0.9) |
| ε (epsilon) | 0.01-0.3 | 探索 vs 利用的平衡 | 训练后期可逐渐降低 |

调试时常见的现象与解决方案:

问题1:奖励长期波动不收敛

- 可能原因:学习率α过高
- 解决方案:尝试α=0.05,或实现动态衰减 `alpha = 0.1 / (1 + episode*0.001)`

问题2:智能体陷入局部最优

- 可能原因:ε值下降过快
- 解决方案:采用ε衰减策略 `epsilon = max(0.01, 0.1 * (1 - episode/800))`

问题3:训练后期性能突然下降

- 可能原因:过度探索破坏了已学策略
- 解决方案:设置ε最小值(如0.01),或改用Boltzmann探索策略

## 5. 进阶优化技巧

当基本实现运行稳定后,可以考虑以下优化:

优先经验回放:

```python
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity=1000):
        self.buffer = deque(maxlen=capacity)

    def add(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

# 在训练循环中使用
buffer = ReplayBuffer()
...
buffer.add((state, action, reward, next_state, next_action))
if len(buffer) > 32:
    batch = buffer.sample(32)
    for exp in batch:
        agent.learn(*exp)
```

动态ε-greedy策略:

```python
def get_epsilon(episode, min_epsilon=0.01, decay_rate=0.995):
    return max(min_epsilon, 0.1 * (decay_rate ** episode))
```

状态泛化:对于更大迷宫,可以用网格特征代替坐标:

```python
def extract_features(state):
    x, y = state
    return [
        x/4, y/4,  # 归一化坐标
        int(env.grid[x][y] == 'G'),  # 是否看到目标
        min(abs(x-gx) + abs(y-gy) for gx,gy in goal_positions)  # 曼哈顿距离
    ]
```

## 6. 与其他算法的对比实验

为了展示SARSA特性,我们与Q-learning进行对比:

```python
class QLearningAgent(SARSAgent):
    def learn(self, state, action, reward, next_state):
        current_q = self.Q[state, action]
        max_next_q = max([self.Q[next_state, a] for a in self.env.actions])
        self.Q[state, action] = current_q + self.alpha * (
            reward + self.gamma * max_next_q - current_q
        )

# 训练比较
q_agent = QLearningAgent(env)
q_rewards = q_agent.train()

plt.plot(rewards, label='SARSA')
plt.plot(q_rewards, label='Q-learning')
plt.legend()
```

典型对比结果:

- Q-learning:前期学习更快,但稳定性较差
- SARSA:收敛较慢,但最终策略更鲁棒

实际项目中,如果环境存在危险状态(如悬崖行走问题),SARSA的保守策略通常更安全。

## 7. 工程实践建议

在真实项目部署时,还需要考虑:

性能监控:

```python
def test_performance(agent, runs=100):
    success = 0
    steps = []
    for _ in range(runs):
        state = env.reset()
        step = 0
        done = False
        while not done and step < 100:
            action = agent.choose_action(state)
            state, _, done = env.step(action)
            step += 1
        if done:
            success += 1
            steps.append(step)
    return success/runs, np.mean(steps)
```

模型保存与加载:

```python
import pickle

def save_model(agent, path):
    with open(path, 'wb') as f:
        pickle.dump(agent.Q, f)

def load_model(agent, path):
    with open(path, 'rb') as f:
        agent.Q = pickle.load(f)
```

超参数搜索:

```python
from itertools import product

param_grid = {
    'alpha': [0.01, 0.05, 0.1],
    'gamma': [0.8, 0.9, 0.95],
    'epsilon': [0.05, 0.1, 0.2]
}

best_reward = -float('inf')
for params in product(*param_grid.values()):
    agent = SARSAgent(env, **dict(zip(param_grid.keys(), params)))
    rewards = agent.train(episodes=300)
    if np.mean(rewards[-50:]) > best_reward:
        best_reward = np.mean(rewards[-50:])
        best_params = params
```

相关新闻