PyTorch 全连接神经网络 + DQN 训练贪吃蛇

在本项目中,通过深度 Q 学习(Deep Q-Learning, DQN)训练一个自动玩贪吃蛇游戏的模型。训练过程融合了现代深度学习与强化学习的多种关键技术,如全连接神经网络、循环神经网络(RNN)、长短期记忆网络(LSTM)、经验回放、ε-greedy 策略等。

一. 环境建模:SnakeEnv

1. 代码 env_snake.py

  • 说明(和模型相关)
    • _get_state() 方法是生成神经网络需要的特征。这里生成一个10 维向量。每个向量的意义代码注释
    • step() 方法本身是在判断蛇下一步的状态(是否撞墙,吃到食物,计算分数等);step() 方法中 reward 是在给模型的操作进行打分(吃到食物加分,靠近食物加减分,存活没吃到食物的加减分)
import random
import pygame
import numpy as np

class SnakeEnv:
    def __init__(self, width=800, height=600, cell_size=20):
        pygame.init() # 初始化pygame模块
        self.width = width # 游戏窗口宽度
        self.height = height # 游戏窗口高度
        self.cell_size = cell_size # 每个单元格的像素大小
        self.map_width = width // cell_size # 地图宽度(格子数)
        self.map_height = height // cell_size # 地图高度(格子数)

        self.screen = pygame.display.set_mode((self.width, self.height))  # 创建窗口
        pygame.display.set_caption('RL Snake')  # 设置窗口标题
        self.clock = pygame.time.Clock()  # 控制帧率
        self.snake_speed = 15  # 游戏速度(帧率)

        # 定义方向常量
        self.UP = 0
        self.DOWN = 1
        self.LEFT = 2
        self.RIGHT = 3
        self.HEAD = 0  # 蛇头在列表中的索引

        self.reset()  # 初始化游戏状态

    def reset(self):
        # 初始化蛇的位置
        startx = random.randint(3, self.map_width - 8)
        starty = random.randint(3, self.map_height - 8)
        self.snake = [{'x': startx, 'y': starty},
                      {'x': startx - 1, 'y': starty},
                      {'x': startx - 2, 'y': starty}]
        self.direction = self.RIGHT  # 初始方向向右
        self.food = self._random_food()  # 随机放置食物
        self.score = 0  # 初始得分
        self.done = False  # 游戏是否结束
        return self._get_state()  # 返回初始状态向量

    def _random_food(self):
        # 在地图随机生成一个不与蛇身体重叠的食物
        while True:
            pos = {'x': random.randint(0, self.map_width - 1),
                   'y': random.randint(0, self.map_height - 1)}
            if pos not in self.snake:
                return pos

    def _get_state(self):
        # 构建10维状态向量
        head = self.snake[self.HEAD]

        # 当前方向与坐标偏移映射
        direction_map = {
            self.UP: (0, -1),     # 向上走,x不变,y减1
            self.DOWN: (0, 1),    # 向下走,x不变,y加1
            self.LEFT: (-1, 0),   # 向左走,x减1,y不变
            self.RIGHT: (1, 0)    # 向右走,x加1,y不变
        }

        # 当前方向对应的偏移
        # 图像坐标系的原点在左上角。
        dx, dy = direction_map[self.direction]
        # 左边方向的偏移 = 逆时针旋转 90°
        left_dx, left_dy = -dy, dx
        # 右边方向的偏移 = 顺时针旋转 90°
        right_dx, right_dy = dy, -dx

        # 判断某个位置是否危险(撞墙或撞自己)
        def is_danger(x, y):
            if x < 0 or x >= self.map_width or y < 0 or y >= self.map_height:
                return 1.0
            for segment in self.snake[1:]:
                if segment['x'] == x and segment['y'] == y:
                    return 1.0
            return 0.0

        # 判断前、左、右是否危险
        danger_front = is_danger(head['x'] + dx, head['y'] + dy)
        danger_left = is_danger(head['x'] + left_dx, head['y'] + left_dy)
        danger_right = is_danger(head['x'] + right_dx, head['y'] + right_dy)

        # 判断食物方向是否在前、左、右
        food_dx = self.food['x'] - head['x']
        food_dy = self.food['y'] - head['y']

        def food_dir_is(dx1, dy1):
            return 1.0 if (np.sign(food_dx) == dx1 and np.sign(food_dy) == dy1) else 0.0

        food_front = food_dir_is(dx, dy)
        food_left = food_dir_is(left_dx, left_dy)
        food_right = food_dir_is(right_dx, right_dy)

        # 当前移动方向的 one-hot 向量(上、下、左、右)
        dir_onehot = [0.0, 0.0, 0.0, 0.0]
        dir_onehot[self.direction] = 1.0

        # 最终状态向量(10维)
        state = [
            danger_left, danger_front, danger_right,
            food_left, food_front, food_right,
            *dir_onehot
        ]
        return np.array(state, dtype=np.float32)

    def step(self, action):
        if self.done:
            return self._get_state(), 0, True, {}

        # 更新方向(避免直接反向)
        if (action == self.UP and self.direction != self.DOWN) or \
           (action == self.DOWN and self.direction != self.UP) or \
           (action == self.LEFT and self.direction != self.RIGHT) or \
           (action == self.RIGHT and self.direction != self.LEFT):
            self.direction = action

        # 记录移动前蛇头位置(用于奖励计算)
        head = self.snake[self.HEAD]
        prev_dist = abs(head['x'] - self.food['x']) + abs(head['y'] - self.food['y'])

        # 移动蛇
        self._move()

        reward = 0

        # 判断游戏是否结束
        is_alive, crash = self._is_alive()
        if not is_alive:
            self.done = True
            if crash == 'wall':
                reward -= 10
            else:
                reward -= 20
        else:

            # 计算新位置距离
            new_head = self.snake[self.HEAD]
            new_dist = abs(new_head['x'] - self.food['x']) + abs(new_head['y'] - self.food['y'])

            # 奖励蛇“靠近”食物
            reward += (prev_dist - new_dist) * 0.5

            # 距离越近奖励越高(即使没吃到)
            reward += 1.0 / (new_dist + 1)

            if self._is_eat_food():  # 吃到食物奖励
                self.score += 1
                reward += 10
                self.food = self._random_food()
            else:
                self.snake.pop()  # 没吃到就删尾(移动)

        return self._get_state(), reward, self.done, {}

    def _move(self):
        # 根据方向添加新的蛇头位置
        head = self.snake[self.HEAD]
        if self.direction == self.UP:
            new_head = {'x': head['x'], 'y': head['y'] - 1}
        elif self.direction == self.DOWN:
            new_head = {'x': head['x'], 'y': head['y'] + 1}
        elif self.direction == self.LEFT:
            new_head = {'x': head['x'] - 1, 'y': head['y']}
        elif self.direction == self.RIGHT:
            new_head = {'x': head['x'] + 1, 'y': head['y']}
        self.snake.insert(0, new_head)

    def _is_alive(self):
        # 判断是否撞墙或撞自己
        head = self.snake[self.HEAD]
        if head['x'] < 0 or head['x'] >= self.map_width or \
           head['y'] < 0 or head['y'] >= self.map_height:
            return False, 'wall'
        for segment in self.snake[1:]:
            if segment['x'] == head['x'] and segment['y'] == head['y']:
                return False, 'snake'
        return True, ''

    def _is_eat_food(self):
        head = self.snake[self.HEAD]
        return head['x'] == self.food['x'] and head['y'] == self.food['y']

    def render(self):

        for event in pygame.event.get():  # 防止窗口卡死
            if event.type == pygame.QUIT:
                self.close()

        self.screen.fill((0, 0, 0))  # 背景黑色

        # 画食物(红色)
        fx, fy = self.food['x'] * self.cell_size, self.food['y'] * self.cell_size
        pygame.draw.rect(self.screen, (255, 0, 0), (fx, fy, self.cell_size, self.cell_size))

        # 画蛇身体(深青色)
        for seg in self.snake:
            x, y = seg['x'] * self.cell_size, seg['y'] * self.cell_size
            pygame.draw.rect(self.screen, (0, 139, 139), (x, y, self.cell_size, self.cell_size))

        pygame.display.flip()
        self.clock.tick(self.snake_speed)

    def close(self):
        pygame.quit()



二、全连接神经网络

将池化后提取到的特征信息映射为最终的输出类别或回归结果,是神经网络的决策层

1. 代码 q_network.py

这里是采用的三层全连层。一般采用层数如下:

层数(隐藏层) 特点 适用情况
1 层(最浅) 只能学习“非常简单”的函数 小数据量,小任务(线性或近似线性)
2~3 层 能处理常见复杂映射,训练速度快 中小规模任务(如强化学习)
4~6 层 表达能力强,能学到更复杂结构 图像、文本等特征模式复杂任务
10 层以上 深度网络,可能出现梯度消失或过拟合问题 用残差连接、BN 等技术辅助才好训练
import torch
import torch.nn as nn
import torch.nn.functional as F


# 这是一个简单的多层感知机,用于输入 10 维状态,输出 4 个动作的 Q 值(上、下、左、右):
# 输入是 [danger_left, danger_front, danger_right,food_dir_left, food_dir_front, food_dir_right,current_direction (one-hot)]
# 输出是 [Q(上), Q(下), Q(左), Q(右)]
class QNetwork(nn.Module):
    def __init__(self, state_dim=10, action_dim=4, hidden_dim=128):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))       # 第1隐藏层
        x = F.relu(self.fc2(x))       # 第2隐藏层
        x = F.relu(self.fc3(x))       # 第3隐藏层
        return self.out(x)            # 输出每个动作的Q值


if __name__ == '__main__':
    # 初始化模型
    model = QNetwork()

    # 构造一个假的 state 输入(比如一组10维状态)
    dummy_state = torch.randn(1, 10)  # shape = [batch=1, state_dim=10]

    # 执行前向传播
    q_values = model(dummy_state)

    # 打印结果
    print("Q-values for 4 actions:", q_values)

三、Q-Learning 与 DQN 简介

1. 简介

1.1. Q-Learning

Q-Learning 是一种无模型的强化学习方法,用来学习 "状态-动作" 价值函数 Q(s, a),从而制定最优策略。

1.2. DQN

DQN 是 Q-Learning 的深度学习版本,用神经网络近似 Q 函数:

Q(s, a) \approx \text{NeuralNetwork}(s)[a]

在训练中,我们使用目标 Q 值:(Q 值表示:在状态 s 下做动作 a,未来可能获得的“总奖励”有多少)

Q_{\text{target}} = r + \gamma \cdot \max_{a'} Q'(s', a')

通过最小化预测 Q 值和目标 Q 值之间的均方误差(MSE)进行学习。

DQN 使用两个网络:

  • 主网络:用于预测当前动作的 Q 值并更新
  • 目标网络:延迟更新,用于计算目标 Q 值

每隔固定步数将主网络参数拷贝到目标网络,以减少震荡。

2. 代码 dqn_aget.py

DQNAgent 是 DQN 强化学习算法中的智能体类(Agent),它负责从环境中学习、预测动作、优化模型,是整个智能决策系统的“大脑”。

2.1. 核心职责
职责 作用
选择动作 select_action() 根据当前状态,预测每个动作的 Q 值,并根据 ε-greedy 策略选出最终动作。 关于 ε-greedy 看最后面的补充
学习优化 update() 从经验回放中采样,更新 Q 网络参数。
同步目标网络 update_target() 定期将主 Q 网络的参数复制给目标网络。
维护 ε-greedy 策略 控制探索与利用之间的平衡(ε 随时间衰减)。
import torch
import torch.nn.functional as F
import numpy as np
from q_network import QNetwork


class DQNAgent:
    # state_dim: 输入维度[
    #                          danger_left, danger_front, danger_right,     # 蛇前方是否有危险(墙或身体)
    #                          food_dir_left, food_dir_front, food_dir_right,  # 食物在蛇前、左、右的方向
    #                          current_direction (one-hot)                   # 当前移动方向(如 [0,0,1,0] 表示左)
    #                    ]
    # action_dim: 输出维度: [Q(上), Q(下), Q(左), Q(右)]
    #
    # gamma: 折扣因; 说明:总奖励 = 当前奖励 + 将来奖励的折扣和
    #
    # ε-greedy 策略参数
    # epsilon_start: 初始探索率,训练刚开始时有多少概率“瞎试”动作(通常设为 1.0)
    # epsilon_end: 最小探索率,训练后期保留多少“探索”(一般在 0.01 ~ 0.1 之间)
    #              目的:即使训练后期,也保留一部分随机性以避免“陷入局部最优”。
    # epsilon_decay: 衰减速度,控制探索率下降的“步数尺度”. 从 epsilon_start 到 epsilon_end 的步数
    #                越小 → 衰减越快 → 越快开始利用(训练速度快,但可能不稳)
    #                越大 → 衰减越慢 → 更长时间探索(训练稳但慢)
    def __init__(self, state_dim, action_dim, device,
                 gamma=0.99, lr=1e-3, batch_size=64,
                 epsilon_start=1.0, epsilon_end=0.05, epsilon_decay=2000):        # 状态维度(输入维度)
        self.state_dim = state_dim
        # 动作数量(输出维度)
        self.action_dim = action_dim
        # 运行在哪个设备(CPU或GPU)
        self.device = device

        # 初始化主网络 Q(s, a)
        self.q_net = QNetwork(state_dim, action_dim).to(device)
        # 初始化目标网络 Q_target(s, a)
        self.target_q_net = QNetwork(state_dim, action_dim).to(device)
        # 目标网络参数拷贝自主网络
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        self.target_q_net.eval()  # 不参与训练,只用于预测

        # 优化器:Adam
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        # 折扣因子 gamma
        self.gamma = gamma
        # 每次训练使用的样本数量
        self.batch_size = batch_size

        # ε-greedy 策略参数
        # 当前的探索率
        # 在 ε-greedy 策略中,每次选择动作时,有 epsilon 的概率随机选一个动作,而不是选 Q 值最大的动作。
        self.epsilon_start = epsilon_start  # 保存初始值
        self.epsilon = epsilon_start       # 当前 epsilon

        # 探索率的最小值,即训练进行到后期后,不会继续减少探索了,保持一个最低限度的随机性。
        # 保持少量随机性可以避免模型陷入局部最优
        self.epsilon_end = epsilon_end

        # 探索率衰减的“时间尺度”——训练进行多少步后 epsilon 会下降到接近 epsilon_end。
        self.epsilon_decay = epsilon_decay

        # 用于控制探索率 ε(epsilon)的衰减速度
        # 训练过程中递增的步数计数器,每调用一次 select_action(),它就加 1
        self.learn_step = 0

    # 选择动作(带ε-greedy策略)
    def select_action(self, state, force_random=False):

        # 强制只用随机动作,不使用模型预测的 Q 值动作。时不记录动作
        if force_random:
            return np.random.randint(self.action_dim)

        # 每次调用就增加 learn 步数
        # 记录了“模型已经学习(或者决策)了多少步”
        self.learn_step += 1

        # 衰减后的 epsilon(指数下降)
        # 它的作用是让 ε(探索概率)随着训练步数 learn_step 逐步 从初始值 ε_start 衰减到 ε_end
        self.epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * \
                       np.exp(-1. * self.learn_step / self.epsilon_decay)

        if np.random.rand() < self.epsilon:
            # 随机选择动作(探索)
            return np.random.randint(self.action_dim)
        else:
            # 模型预测 Q 值,选择最大 Q 的动作(利用)
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)

            # 不计算梯度
            # 这里是 select_action 选择动作,这里是在做推理
            with torch.no_grad():
                q_values = self.q_net(state_tensor)
            #  print(f"Q-values: {q_values.cpu().numpy().squeeze()}")
            return q_values.argmax().item()

    # 从经验池中采样并更新网络
    def update(self, replay_buffer, writer, global_step=None):
        if len(replay_buffer) < self.batch_size:
            return  # 如果经验不够,就跳过更新

        # 从经验池中随机采样一个 batch = 64
        states, actions, rewards, next_states, dones = replay_buffer.sample(self.batch_size)

        # 将数据转移到对应设备
        states = states.to(self.device)
        actions = actions.to(self.device)
        rewards = rewards.to(self.device)
        next_states = next_states.to(self.device)
        dones = dones.to(self.device)

        # DQN 核心的训练更新逻辑
        # 用当前网络预测的 Q 值和目标 Q 值计算差距(损失),并反向传播优化模型。
        #
        # q_values当前网络预测 Q 值:Q(s, a)
        # self.q_net(states) → 得到每个状态对应的所有动作的 Q 值(shape: [batch_size, action_dim]);
        # actions.unsqueeze(1) → 把每个 action 的索引变成 shape [batch_size, 1],以便用于 gather;
        # .gather(1, ...) → 提取当前经验中实际采取的动作 a 的 Q 值 Q(s,a);
        # .squeeze(1) → 去掉维度 [batch_size, 1] → [batch_size],得到一维向量。
        q_values = self.q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # target_q 目标 Q 值:
        with torch.no_grad():
            # self.target_q_net(next_states):预测下一步状态 s' 所有动作的 Q 值。
            # .max(1)[0]:对每个状态,选出最大 Q 值 max_a Q(s', a') → 目标值的关键
            max_next_q = self.target_q_net(next_states).max(1)[0]
            # DQN 的 目标 Q 值方程: r + γ * max(Q_target(s', a'))
            target_q = rewards + self.gamma * max_next_q * (1 - dones)

            max_q = q_values.max(0)[0].mean().item()   # 每个样本的最大Q,再求平均
            mean_q = q_values.mean().item()            # 所有 Q 值的平均

        # 均方误差损失函数
        loss = F.mse_loss(q_values, target_q)

        # 记录曲线图
        if writer is not None and global_step is not None:
            writer.add_scalar("q_value/max", max_q, global_step)
            writer.add_scalar("q_value/mean", mean_q, global_step)
            writer.add_scalar("loss", loss.item(), global_step)

        # 反向传播更新 Q 网络参数
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    # 用于每隔固定步数同步目标网络
    def update_target(self):
        self.target_q_net.load_state_dict(self.q_net.state_dict())

四、Replay Buffer(经验回放机制)

1. 说明

1.1. 作用

Replay Buffer 是一个存放与模型交互经验的“记忆库”,训练时从中随机采样小批量经验来打破数据之间的相关性,提升学习效率和稳定性。

1.2. 什么是 Replay Buffe

是一个 环形队列(FIFO),每次环境交互都会存一条“经验”进去可以理解为一个 Python 列表或队列结构,最多存 N 条最新经验

Replay Buffer 是一个存放与模型交互经验的“记忆库”,训练时从中随机采样小批量经验来打破数据之间的相关性,提升学习效率和稳定性。没有它,DQN 很容易训练不稳定甚至崩溃

1.3. 为什么需要

如果模型智不断地与环境交互获得数据,如果我们直接拿“刚刚经历的那条数据”来训练,会有两个大问题

  • 数据高度相关:连续的状态和动作之间高度相关(比如连续帧图像)
  • 数据分布不断变化:模型边训练边收集数据,会导致训练目标也在变

2. 代码 replay_buffer.py

#! coding=utf-8

from collections import deque
import random
import numpy as np
import torch


class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        # state 是list,[
        #                          danger_left, danger_front, danger_right,     # 蛇前方是否有危险(墙或身体)
        #                          food_dir_left, food_dir_front, food_dir_right,  # 食物在蛇前、左、右的方向
        #                          current_direction (one-hot)                   # 当前移动方向(如 [0,0,1,0] 表示左)
        #                    ]
        # action 是随机动作
        # reward 奖励(如吃到食物 +1,死亡 -1)
        # next_state 下一个状态
        state = np.array(state, dtype=np.float32)
        next_state = np.array(next_state, dtype=np.float32)
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
        return (
            torch.tensor(states, dtype=torch.float32),
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float32),
            torch.tensor(next_states, dtype=torch.float32),
            torch.tensor(dones, dtype=torch.float32),
        )

    def __len__(self):
        return len(self.buffer)

五、训练

1. 代码 train_dqn.py

import torch
import os
import numpy as np
from env_snake import SnakeEnv
from q_network import QNetwork
from replay_buffer import ReplayBuffer
from dqn_agent import DQNAgent
from torch.utils.tensorboard import SummaryWriter
from collections import deque


writer = SummaryWriter("runs/snake_dqn")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 环境和参数
env = SnakeEnv()
state_dim = 10           # [danger_left, danger_front, danger_right, head_x, head_y, food_x, food_y, direction]
action_dim = 4           # 上下左右的 Q 值
max_episodes = 1000      # 总训练回合数
max_steps = 2000         # 每个回合最多步数
target_update_freq = 20  # 每 N 轮同步一次目标网络
buffer_capacity = 10000  # 缓存数据量
batch_size = 64          # 每个批次数量

# 初始化缓存和 DQNAgent
replay_buffer = ReplayBuffer(buffer_capacity)
agent = DQNAgent(state_dim, action_dim, device, batch_size=batch_size)

# 加载模型参数(续训用)
model_path = 'dqn_snake.pth'
if os.path.exists(model_path):
    agent.q_net.load_state_dict(torch.load(model_path))        # 恢复主网络参数
    agent.target_q_net.load_state_dict(agent.q_net.state_dict())  # 同步目标网络
    print(f"模型已从 {model_path} 加载,继续训练...")

# 每回合蛇在移动吃到食物等,的奖励记录
episode_rewards = []
# 游戏内得分
score_window = []

# 滑动平均窗口:计算最近 100 回合(episode)蛇的平均奖励(reward)
reward_window = deque(maxlen=100)

# 训练过程中蛇从头到尾走过的“总步数”,它用于记录训练进度,并用于 TensorBoard 可视化等用途
global_step = 0

for episode in range(1, max_episodes + 1):
    state = env.reset()
    total_reward = 0

    # 每回合蛇走 max_steps 多步
    for step in range(max_steps):
        # 前 200 回合(episode 1 ~ 200)内,强制只用随机动作,不使用模型预测的 Q 值动作,即完全靠探索(exploration),不做“利用”(exploitation)。
        force_random = episode <= 200
        action = agent.select_action(state, force_random=force_random)
        next_state, reward, done, _ = env.step(action)
        #  env.render()  # 显示游戏
        # 存入经验池
        replay_buffer.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward

        # 更新模型
        if episode > 200:
            agent.update(replay_buffer, writer, global_step)
        global_step += 1
        if done:
            break

    episode_rewards.append(total_reward)

    # 每隔 N 回合更新目标网络
    if episode % target_update_freq == 0:
        agent.update_target()

    # 收到衰减 ε-greedy
    # 每个 episode 结束时衰减 epsilon
    # 对 ε-greedy 策略中的 ε(探索率)进行指数衰减
    # 随着训练进行,逐渐减少“随机探索”的概率,增加“根据模型选择动作”的概率。
    #  agent.epsilon = max(agent.epsilon_end, agent.epsilon * 0.995)

    # 输出进度
    print(f"Episode {episode} | Total Reward: {total_reward:.2f} | Epsilon: {agent.epsilon:.4f}")

    # ===== TensorBoard 日志记录 =====
    writer.add_scalar("epsilon", agent.epsilon, episode)
    writer.add_scalar("reward/total", total_reward, episode)

    reward_window.append(total_reward)
    writer.add_scalar("reward/mean_100", np.mean(reward_window), episode)

    score_window.append(env.score)
    writer.add_scalar("score/mean_100", np.mean(score_window), episode)

# 保存模型参数
torch.save(agent.q_net.state_dict(), model_path)
print("训练完成,模型已保存!")

# 关闭环境
env.close()
writer.close()

六、测试

1. 代码

import torch
import time
from env_snake import SnakeEnv
from q_network import QNetwork


model_path = './dqn_snake.pth'
state_dim = 10
action_dim = 4
device = 'cpu'

env = SnakeEnv()
model = QNetwork(state_dim, action_dim).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()


state = env.reset()
total_reward = 0

while True:
    # 转为 tensor
    state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)

    # 预测 Q 值并选择最优动作
    with torch.no_grad():
        q_values = model(state_tensor)
        action = q_values.argmax().item()

    # 执行动作
    next_state, reward, done, _ = env.step(action)
    total_reward += reward

    env.render()  # 显示画面
    time.sleep(0.05)  # 控制游戏速度(可调)

    state = next_state
    if done:
        print(f"游戏结束,总得分: {total_reward:.2f}")
        break

env.close()

七、 补充

1. one-hot 向量

1.1. 定义

One-hot 是一种将类别编号转化为向量的方式,向量长度等于类别总数,其中:

  • 某一类的位置为 1
  • 其他所有位置为 0
1.2. 举例

假设有 4 个类别: [UP, DOWN, LEFT, RIGHT](上、下、左、右)
它们的 one-hot 表示如下:

方向 One-hot 向量
UP [1, 0, 0, 0]
DOWN [0, 1, 0, 0]
LEFT [0, 0, 1, 0]
RIGHT [0, 0, 0, 1]

2. ε-greedy(epsilon-greedy)策略

2.1. 探索率 ε(epsilon)

探索率 ε(epsilon)是强化学习中最常用的一种探索(exploration)+ 利用(exploitation)策略,用来决定“下一步该做什么动作”

2.2. 规则

在每一步选择动作时:

  • 概率 ε(epsilon):随机选择一个动作(探索新可能)
  • 概率 1−ε:选择当前 Q 网络预测的最优动作(利用已有知识)
2.3. 示例

当前 Q 网络输出:

动作 Q值
0 0.2
1 0.4
2 0.1
3 0.9 ← Q 值最大,理论上是最优

ε = 0.1 的策略含义:

  • 1−ε = 90% 的概率 → 选 Q 值最大的动作(动作3)
  • ε = 10% 的概率 → 随机选一个动作(动作0、1、2、3)

3. tensorboard 记录并实时查看训练过程中数据的模块

运行
tensorboard --logdir=runs
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容