在本项目中,通过深度 Q 学习(Deep Q-Learning, DQN)训练一个自动玩贪吃蛇游戏的模型。训练过程融合了现代深度学习与强化学习的多种关键技术,如全连接神经网络、循环神经网络(RNN)、长短期记忆网络(LSTM)、经验回放、ε-greedy 策略等。
一. 环境建模:SnakeEnv
1. 代码 env_snake.py
-
说明(和模型相关)
-
_get_state()方法是生成神经网络需要的特征。这里生成一个10 维向量。每个向量的意义代码注释 -
step()方法本身是在判断蛇下一步的状态(是否撞墙,吃到食物,计算分数等);step()方法中reward是在给模型的操作进行打分(吃到食物加分,靠近食物加减分,存活没吃到食物的加减分)
-
import random
import pygame
import numpy as np
class SnakeEnv:
def __init__(self, width=800, height=600, cell_size=20):
pygame.init() # 初始化pygame模块
self.width = width # 游戏窗口宽度
self.height = height # 游戏窗口高度
self.cell_size = cell_size # 每个单元格的像素大小
self.map_width = width // cell_size # 地图宽度(格子数)
self.map_height = height // cell_size # 地图高度(格子数)
self.screen = pygame.display.set_mode((self.width, self.height)) # 创建窗口
pygame.display.set_caption('RL Snake') # 设置窗口标题
self.clock = pygame.time.Clock() # 控制帧率
self.snake_speed = 15 # 游戏速度(帧率)
# 定义方向常量
self.UP = 0
self.DOWN = 1
self.LEFT = 2
self.RIGHT = 3
self.HEAD = 0 # 蛇头在列表中的索引
self.reset() # 初始化游戏状态
def reset(self):
# 初始化蛇的位置
startx = random.randint(3, self.map_width - 8)
starty = random.randint(3, self.map_height - 8)
self.snake = [{'x': startx, 'y': starty},
{'x': startx - 1, 'y': starty},
{'x': startx - 2, 'y': starty}]
self.direction = self.RIGHT # 初始方向向右
self.food = self._random_food() # 随机放置食物
self.score = 0 # 初始得分
self.done = False # 游戏是否结束
return self._get_state() # 返回初始状态向量
def _random_food(self):
# 在地图随机生成一个不与蛇身体重叠的食物
while True:
pos = {'x': random.randint(0, self.map_width - 1),
'y': random.randint(0, self.map_height - 1)}
if pos not in self.snake:
return pos
def _get_state(self):
# 构建10维状态向量
head = self.snake[self.HEAD]
# 当前方向与坐标偏移映射
direction_map = {
self.UP: (0, -1), # 向上走,x不变,y减1
self.DOWN: (0, 1), # 向下走,x不变,y加1
self.LEFT: (-1, 0), # 向左走,x减1,y不变
self.RIGHT: (1, 0) # 向右走,x加1,y不变
}
# 当前方向对应的偏移
# 图像坐标系的原点在左上角。
dx, dy = direction_map[self.direction]
# 左边方向的偏移 = 逆时针旋转 90°
left_dx, left_dy = -dy, dx
# 右边方向的偏移 = 顺时针旋转 90°
right_dx, right_dy = dy, -dx
# 判断某个位置是否危险(撞墙或撞自己)
def is_danger(x, y):
if x < 0 or x >= self.map_width or y < 0 or y >= self.map_height:
return 1.0
for segment in self.snake[1:]:
if segment['x'] == x and segment['y'] == y:
return 1.0
return 0.0
# 判断前、左、右是否危险
danger_front = is_danger(head['x'] + dx, head['y'] + dy)
danger_left = is_danger(head['x'] + left_dx, head['y'] + left_dy)
danger_right = is_danger(head['x'] + right_dx, head['y'] + right_dy)
# 判断食物方向是否在前、左、右
food_dx = self.food['x'] - head['x']
food_dy = self.food['y'] - head['y']
def food_dir_is(dx1, dy1):
return 1.0 if (np.sign(food_dx) == dx1 and np.sign(food_dy) == dy1) else 0.0
food_front = food_dir_is(dx, dy)
food_left = food_dir_is(left_dx, left_dy)
food_right = food_dir_is(right_dx, right_dy)
# 当前移动方向的 one-hot 向量(上、下、左、右)
dir_onehot = [0.0, 0.0, 0.0, 0.0]
dir_onehot[self.direction] = 1.0
# 最终状态向量(10维)
state = [
danger_left, danger_front, danger_right,
food_left, food_front, food_right,
*dir_onehot
]
return np.array(state, dtype=np.float32)
def step(self, action):
if self.done:
return self._get_state(), 0, True, {}
# 更新方向(避免直接反向)
if (action == self.UP and self.direction != self.DOWN) or \
(action == self.DOWN and self.direction != self.UP) or \
(action == self.LEFT and self.direction != self.RIGHT) or \
(action == self.RIGHT and self.direction != self.LEFT):
self.direction = action
# 记录移动前蛇头位置(用于奖励计算)
head = self.snake[self.HEAD]
prev_dist = abs(head['x'] - self.food['x']) + abs(head['y'] - self.food['y'])
# 移动蛇
self._move()
reward = 0
# 判断游戏是否结束
is_alive, crash = self._is_alive()
if not is_alive:
self.done = True
if crash == 'wall':
reward -= 10
else:
reward -= 20
else:
# 计算新位置距离
new_head = self.snake[self.HEAD]
new_dist = abs(new_head['x'] - self.food['x']) + abs(new_head['y'] - self.food['y'])
# 奖励蛇“靠近”食物
reward += (prev_dist - new_dist) * 0.5
# 距离越近奖励越高(即使没吃到)
reward += 1.0 / (new_dist + 1)
if self._is_eat_food(): # 吃到食物奖励
self.score += 1
reward += 10
self.food = self._random_food()
else:
self.snake.pop() # 没吃到就删尾(移动)
return self._get_state(), reward, self.done, {}
def _move(self):
# 根据方向添加新的蛇头位置
head = self.snake[self.HEAD]
if self.direction == self.UP:
new_head = {'x': head['x'], 'y': head['y'] - 1}
elif self.direction == self.DOWN:
new_head = {'x': head['x'], 'y': head['y'] + 1}
elif self.direction == self.LEFT:
new_head = {'x': head['x'] - 1, 'y': head['y']}
elif self.direction == self.RIGHT:
new_head = {'x': head['x'] + 1, 'y': head['y']}
self.snake.insert(0, new_head)
def _is_alive(self):
# 判断是否撞墙或撞自己
head = self.snake[self.HEAD]
if head['x'] < 0 or head['x'] >= self.map_width or \
head['y'] < 0 or head['y'] >= self.map_height:
return False, 'wall'
for segment in self.snake[1:]:
if segment['x'] == head['x'] and segment['y'] == head['y']:
return False, 'snake'
return True, ''
def _is_eat_food(self):
head = self.snake[self.HEAD]
return head['x'] == self.food['x'] and head['y'] == self.food['y']
def render(self):
for event in pygame.event.get(): # 防止窗口卡死
if event.type == pygame.QUIT:
self.close()
self.screen.fill((0, 0, 0)) # 背景黑色
# 画食物(红色)
fx, fy = self.food['x'] * self.cell_size, self.food['y'] * self.cell_size
pygame.draw.rect(self.screen, (255, 0, 0), (fx, fy, self.cell_size, self.cell_size))
# 画蛇身体(深青色)
for seg in self.snake:
x, y = seg['x'] * self.cell_size, seg['y'] * self.cell_size
pygame.draw.rect(self.screen, (0, 139, 139), (x, y, self.cell_size, self.cell_size))
pygame.display.flip()
self.clock.tick(self.snake_speed)
def close(self):
pygame.quit()
二、全连接神经网络
将池化后提取到的特征信息映射为最终的输出类别或回归结果,是神经网络的决策层。
1. 代码 q_network.py
这里是采用的三层全连层。一般采用层数如下:
| 层数(隐藏层) | 特点 | 适用情况 |
|---|---|---|
| 1 层(最浅) | 只能学习“非常简单”的函数 | 小数据量,小任务(线性或近似线性) |
| 2~3 层 | 能处理常见复杂映射,训练速度快 | 中小规模任务(如强化学习) |
| 4~6 层 | 表达能力强,能学到更复杂结构 | 图像、文本等特征模式复杂任务 |
| 10 层以上 | 深度网络,可能出现梯度消失或过拟合问题 | 用残差连接、BN 等技术辅助才好训练 |
import torch
import torch.nn as nn
import torch.nn.functional as F
# 这是一个简单的多层感知机,用于输入 10 维状态,输出 4 个动作的 Q 值(上、下、左、右):
# 输入是 [danger_left, danger_front, danger_right,food_dir_left, food_dir_front, food_dir_right,current_direction (one-hot)]
# 输出是 [Q(上), Q(下), Q(左), Q(右)]
class QNetwork(nn.Module):
def __init__(self, state_dim=10, action_dim=4, hidden_dim=128):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, hidden_dim)
self.out = nn.Linear(hidden_dim, action_dim)
def forward(self, x):
x = F.relu(self.fc1(x)) # 第1隐藏层
x = F.relu(self.fc2(x)) # 第2隐藏层
x = F.relu(self.fc3(x)) # 第3隐藏层
return self.out(x) # 输出每个动作的Q值
if __name__ == '__main__':
# 初始化模型
model = QNetwork()
# 构造一个假的 state 输入(比如一组10维状态)
dummy_state = torch.randn(1, 10) # shape = [batch=1, state_dim=10]
# 执行前向传播
q_values = model(dummy_state)
# 打印结果
print("Q-values for 4 actions:", q_values)
三、Q-Learning 与 DQN 简介
1. 简介
1.1. Q-Learning
Q-Learning 是一种无模型的强化学习方法,用来学习 "状态-动作" 价值函数 Q(s, a),从而制定最优策略。
1.2. DQN
DQN 是 Q-Learning 的深度学习版本,用神经网络近似 Q 函数:
在训练中,我们使用目标 Q 值:(Q 值表示:在状态 s 下做动作 a,未来可能获得的“总奖励”有多少)
通过最小化预测 Q 值和目标 Q 值之间的均方误差(MSE)进行学习。
DQN 使用两个网络:
- 主网络:用于预测当前动作的 Q 值并更新
- 目标网络:延迟更新,用于计算目标 Q 值
每隔固定步数将主网络参数拷贝到目标网络,以减少震荡。
2. 代码 dqn_aget.py
DQNAgent 是 DQN 强化学习算法中的智能体类(Agent),它负责从环境中学习、预测动作、优化模型,是整个智能决策系统的“大脑”。
2.1. 核心职责
| 职责 | 作用 |
|---|---|
选择动作 select_action()
|
根据当前状态,预测每个动作的 Q 值,并根据 ε-greedy 策略选出最终动作。 关于 ε-greedy 看最后面的补充 |
学习优化 update()
|
从经验回放中采样,更新 Q 网络参数。 |
同步目标网络 update_target()
|
定期将主 Q 网络的参数复制给目标网络。 |
| 维护 ε-greedy 策略 | 控制探索与利用之间的平衡(ε 随时间衰减)。 |
import torch
import torch.nn.functional as F
import numpy as np
from q_network import QNetwork
class DQNAgent:
# state_dim: 输入维度[
# danger_left, danger_front, danger_right, # 蛇前方是否有危险(墙或身体)
# food_dir_left, food_dir_front, food_dir_right, # 食物在蛇前、左、右的方向
# current_direction (one-hot) # 当前移动方向(如 [0,0,1,0] 表示左)
# ]
# action_dim: 输出维度: [Q(上), Q(下), Q(左), Q(右)]
#
# gamma: 折扣因; 说明:总奖励 = 当前奖励 + 将来奖励的折扣和
#
# ε-greedy 策略参数
# epsilon_start: 初始探索率,训练刚开始时有多少概率“瞎试”动作(通常设为 1.0)
# epsilon_end: 最小探索率,训练后期保留多少“探索”(一般在 0.01 ~ 0.1 之间)
# 目的:即使训练后期,也保留一部分随机性以避免“陷入局部最优”。
# epsilon_decay: 衰减速度,控制探索率下降的“步数尺度”. 从 epsilon_start 到 epsilon_end 的步数
# 越小 → 衰减越快 → 越快开始利用(训练速度快,但可能不稳)
# 越大 → 衰减越慢 → 更长时间探索(训练稳但慢)
def __init__(self, state_dim, action_dim, device,
gamma=0.99, lr=1e-3, batch_size=64,
epsilon_start=1.0, epsilon_end=0.05, epsilon_decay=2000): # 状态维度(输入维度)
self.state_dim = state_dim
# 动作数量(输出维度)
self.action_dim = action_dim
# 运行在哪个设备(CPU或GPU)
self.device = device
# 初始化主网络 Q(s, a)
self.q_net = QNetwork(state_dim, action_dim).to(device)
# 初始化目标网络 Q_target(s, a)
self.target_q_net = QNetwork(state_dim, action_dim).to(device)
# 目标网络参数拷贝自主网络
self.target_q_net.load_state_dict(self.q_net.state_dict())
self.target_q_net.eval() # 不参与训练,只用于预测
# 优化器:Adam
self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)
# 折扣因子 gamma
self.gamma = gamma
# 每次训练使用的样本数量
self.batch_size = batch_size
# ε-greedy 策略参数
# 当前的探索率
# 在 ε-greedy 策略中,每次选择动作时,有 epsilon 的概率随机选一个动作,而不是选 Q 值最大的动作。
self.epsilon_start = epsilon_start # 保存初始值
self.epsilon = epsilon_start # 当前 epsilon
# 探索率的最小值,即训练进行到后期后,不会继续减少探索了,保持一个最低限度的随机性。
# 保持少量随机性可以避免模型陷入局部最优
self.epsilon_end = epsilon_end
# 探索率衰减的“时间尺度”——训练进行多少步后 epsilon 会下降到接近 epsilon_end。
self.epsilon_decay = epsilon_decay
# 用于控制探索率 ε(epsilon)的衰减速度
# 训练过程中递增的步数计数器,每调用一次 select_action(),它就加 1
self.learn_step = 0
# 选择动作(带ε-greedy策略)
def select_action(self, state, force_random=False):
# 强制只用随机动作,不使用模型预测的 Q 值动作。时不记录动作
if force_random:
return np.random.randint(self.action_dim)
# 每次调用就增加 learn 步数
# 记录了“模型已经学习(或者决策)了多少步”
self.learn_step += 1
# 衰减后的 epsilon(指数下降)
# 它的作用是让 ε(探索概率)随着训练步数 learn_step 逐步 从初始值 ε_start 衰减到 ε_end
self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
np.exp(-1. * self.learn_step / self.epsilon_decay)
if np.random.rand() < self.epsilon:
# 随机选择动作(探索)
return np.random.randint(self.action_dim)
else:
# 模型预测 Q 值,选择最大 Q 的动作(利用)
state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
# 不计算梯度
# 这里是 select_action 选择动作,这里是在做推理
with torch.no_grad():
q_values = self.q_net(state_tensor)
# print(f"Q-values: {q_values.cpu().numpy().squeeze()}")
return q_values.argmax().item()
# 从经验池中采样并更新网络
def update(self, replay_buffer, writer, global_step=None):
if len(replay_buffer) < self.batch_size:
return # 如果经验不够,就跳过更新
# 从经验池中随机采样一个 batch = 64
states, actions, rewards, next_states, dones = replay_buffer.sample(self.batch_size)
# 将数据转移到对应设备
states = states.to(self.device)
actions = actions.to(self.device)
rewards = rewards.to(self.device)
next_states = next_states.to(self.device)
dones = dones.to(self.device)
# DQN 核心的训练更新逻辑
# 用当前网络预测的 Q 值和目标 Q 值计算差距(损失),并反向传播优化模型。
#
# q_values当前网络预测 Q 值:Q(s, a)
# self.q_net(states) → 得到每个状态对应的所有动作的 Q 值(shape: [batch_size, action_dim]);
# actions.unsqueeze(1) → 把每个 action 的索引变成 shape [batch_size, 1],以便用于 gather;
# .gather(1, ...) → 提取当前经验中实际采取的动作 a 的 Q 值 Q(s,a);
# .squeeze(1) → 去掉维度 [batch_size, 1] → [batch_size],得到一维向量。
q_values = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
# target_q 目标 Q 值:
with torch.no_grad():
# self.target_q_net(next_states):预测下一步状态 s' 所有动作的 Q 值。
# .max(1)[0]:对每个状态,选出最大 Q 值 max_a Q(s', a') → 目标值的关键
max_next_q = self.target_q_net(next_states).max(1)[0]
# DQN 的 目标 Q 值方程: r + γ * max(Q_target(s', a'))
target_q = rewards + self.gamma * max_next_q * (1 - dones)
max_q = q_values.max(0)[0].mean().item() # 每个样本的最大Q,再求平均
mean_q = q_values.mean().item() # 所有 Q 值的平均
# 均方误差损失函数
loss = F.mse_loss(q_values, target_q)
# 记录曲线图
if writer is not None and global_step is not None:
writer.add_scalar("q_value/max", max_q, global_step)
writer.add_scalar("q_value/mean", mean_q, global_step)
writer.add_scalar("loss", loss.item(), global_step)
# 反向传播更新 Q 网络参数
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# 用于每隔固定步数同步目标网络
def update_target(self):
self.target_q_net.load_state_dict(self.q_net.state_dict())
四、Replay Buffer(经验回放机制)
1. 说明
1.1. 作用
Replay Buffer 是一个存放与模型交互经验的“记忆库”,训练时从中随机采样小批量经验来打破数据之间的相关性,提升学习效率和稳定性。
1.2. 什么是 Replay Buffe
是一个 环形队列(FIFO),每次环境交互都会存一条“经验”进去可以理解为一个 Python 列表或队列结构,最多存 N 条最新经验
Replay Buffer 是一个存放与模型交互经验的“记忆库”,训练时从中随机采样小批量经验来打破数据之间的相关性,提升学习效率和稳定性。没有它,DQN 很容易训练不稳定甚至崩溃
1.3. 为什么需要
如果模型智不断地与环境交互获得数据,如果我们直接拿“刚刚经历的那条数据”来训练,会有两个大问题
- 数据高度相关:连续的状态和动作之间高度相关(比如连续帧图像)
- 数据分布不断变化:模型边训练边收集数据,会导致训练目标也在变
2. 代码 replay_buffer.py
#! coding=utf-8
from collections import deque
import random
import numpy as np
import torch
class ReplayBuffer:
def __init__(self, capacity):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
# state 是list,[
# danger_left, danger_front, danger_right, # 蛇前方是否有危险(墙或身体)
# food_dir_left, food_dir_front, food_dir_right, # 食物在蛇前、左、右的方向
# current_direction (one-hot) # 当前移动方向(如 [0,0,1,0] 表示左)
# ]
# action 是随机动作
# reward 奖励(如吃到食物 +1,死亡 -1)
# next_state 下一个状态
state = np.array(state, dtype=np.float32)
next_state = np.array(next_state, dtype=np.float32)
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
return (
torch.tensor(states, dtype=torch.float32),
torch.tensor(actions, dtype=torch.long),
torch.tensor(rewards, dtype=torch.float32),
torch.tensor(next_states, dtype=torch.float32),
torch.tensor(dones, dtype=torch.float32),
)
def __len__(self):
return len(self.buffer)
五、训练
1. 代码 train_dqn.py
import torch
import os
import numpy as np
from env_snake import SnakeEnv
from q_network import QNetwork
from replay_buffer import ReplayBuffer
from dqn_agent import DQNAgent
from torch.utils.tensorboard import SummaryWriter
from collections import deque
writer = SummaryWriter("runs/snake_dqn")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 环境和参数
env = SnakeEnv()
state_dim = 10 # [danger_left, danger_front, danger_right, head_x, head_y, food_x, food_y, direction]
action_dim = 4 # 上下左右的 Q 值
max_episodes = 1000 # 总训练回合数
max_steps = 2000 # 每个回合最多步数
target_update_freq = 20 # 每 N 轮同步一次目标网络
buffer_capacity = 10000 # 缓存数据量
batch_size = 64 # 每个批次数量
# 初始化缓存和 DQNAgent
replay_buffer = ReplayBuffer(buffer_capacity)
agent = DQNAgent(state_dim, action_dim, device, batch_size=batch_size)
# 加载模型参数(续训用)
model_path = 'dqn_snake.pth'
if os.path.exists(model_path):
agent.q_net.load_state_dict(torch.load(model_path)) # 恢复主网络参数
agent.target_q_net.load_state_dict(agent.q_net.state_dict()) # 同步目标网络
print(f"模型已从 {model_path} 加载,继续训练...")
# 每回合蛇在移动吃到食物等,的奖励记录
episode_rewards = []
# 游戏内得分
score_window = []
# 滑动平均窗口:计算最近 100 回合(episode)蛇的平均奖励(reward)
reward_window = deque(maxlen=100)
# 训练过程中蛇从头到尾走过的“总步数”,它用于记录训练进度,并用于 TensorBoard 可视化等用途
global_step = 0
for episode in range(1, max_episodes + 1):
state = env.reset()
total_reward = 0
# 每回合蛇走 max_steps 多步
for step in range(max_steps):
# 前 200 回合(episode 1 ~ 200)内,强制只用随机动作,不使用模型预测的 Q 值动作,即完全靠探索(exploration),不做“利用”(exploitation)。
force_random = episode <= 200
action = agent.select_action(state, force_random=force_random)
next_state, reward, done, _ = env.step(action)
# env.render() # 显示游戏
# 存入经验池
replay_buffer.push(state, action, reward, next_state, done)
state = next_state
total_reward += reward
# 更新模型
if episode > 200:
agent.update(replay_buffer, writer, global_step)
global_step += 1
if done:
break
episode_rewards.append(total_reward)
# 每隔 N 回合更新目标网络
if episode % target_update_freq == 0:
agent.update_target()
# 收到衰减 ε-greedy
# 每个 episode 结束时衰减 epsilon
# 对 ε-greedy 策略中的 ε(探索率)进行指数衰减
# 随着训练进行,逐渐减少“随机探索”的概率,增加“根据模型选择动作”的概率。
# agent.epsilon = max(agent.epsilon_end, agent.epsilon * 0.995)
# 输出进度
print(f"Episode {episode} | Total Reward: {total_reward:.2f} | Epsilon: {agent.epsilon:.4f}")
# ===== TensorBoard 日志记录 =====
writer.add_scalar("epsilon", agent.epsilon, episode)
writer.add_scalar("reward/total", total_reward, episode)
reward_window.append(total_reward)
writer.add_scalar("reward/mean_100", np.mean(reward_window), episode)
score_window.append(env.score)
writer.add_scalar("score/mean_100", np.mean(score_window), episode)
# 保存模型参数
torch.save(agent.q_net.state_dict(), model_path)
print("训练完成,模型已保存!")
# 关闭环境
env.close()
writer.close()
六、测试
1. 代码
import torch
import time
from env_snake import SnakeEnv
from q_network import QNetwork
model_path = './dqn_snake.pth'
state_dim = 10
action_dim = 4
device = 'cpu'
env = SnakeEnv()
model = QNetwork(state_dim, action_dim).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
state = env.reset()
total_reward = 0
while True:
# 转为 tensor
state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
# 预测 Q 值并选择最优动作
with torch.no_grad():
q_values = model(state_tensor)
action = q_values.argmax().item()
# 执行动作
next_state, reward, done, _ = env.step(action)
total_reward += reward
env.render() # 显示画面
time.sleep(0.05) # 控制游戏速度(可调)
state = next_state
if done:
print(f"游戏结束,总得分: {total_reward:.2f}")
break
env.close()
七、 补充
1. one-hot 向量
1.1. 定义:
One-hot 是一种将类别编号转化为向量的方式,向量长度等于类别总数,其中:
- 某一类的位置为 1
- 其他所有位置为 0
1.2. 举例:
假设有 4 个类别: [UP, DOWN, LEFT, RIGHT](上、下、左、右)
它们的 one-hot 表示如下:
| 方向 | One-hot 向量 |
|---|---|
| UP | [1, 0, 0, 0] |
| DOWN | [0, 1, 0, 0] |
| LEFT | [0, 0, 1, 0] |
| RIGHT | [0, 0, 0, 1] |
2. ε-greedy(epsilon-greedy)策略
2.1. 探索率 ε(epsilon)
探索率 ε(epsilon)是强化学习中最常用的一种探索(exploration)+ 利用(exploitation)策略,用来决定“下一步该做什么动作”
2.2. 规则
在每一步选择动作时:
- 以 概率 ε(epsilon):随机选择一个动作(探索新可能)
- 以 概率 1−ε:选择当前 Q 网络预测的最优动作(利用已有知识)
2.3. 示例
当前 Q 网络输出:
| 动作 | Q值 |
|---|---|
| 0 | 0.2 |
| 1 | 0.4 |
| 2 | 0.1 |
| 3 | 0.9 ← Q 值最大,理论上是最优 |
ε = 0.1 的策略含义:
-
1−ε = 90%的概率 → 选 Q 值最大的动作(动作3) -
ε = 10%的概率 → 随机选一个动作(动作0、1、2、3)
3. tensorboard 记录并实时查看训练过程中数据的模块
运行
tensorboard --logdir=runs