深度强化学习(13)Double DQN - Deep Q Learning (3)

本文主要内容来源于 Berkeley CS285 Deep Reinforcement Learning


在上一章中, 我们介绍了 DQN, 它解决了 Q-Learing 中的2个的问题, 使得Traning 更加稳定。 但是, DQN 本事也存在Overestimate 的问题, 这章,我们会介绍 Double DQN 来解决 DQN 中的问题。

Overestimation

DQN 会过高的估计收益(Q Value), 下图是 DQN (棕色)在4款游戏上的表现。 曲线是训练时估计的 Q Value, 水平线(棕色)是实际值。 从图中可以看出, 训练时 DQN 估计的 Q 值是大大高于实际值的。

Overestimation
原因

问题出在我们利用 Target Network 计算 Target Q Value的时候:我们利用Target Network 选择了 a^{\prime}, 同时,我们又用选择出来的 a^{\prime} 在 Target Network 中计算 Q value。 这意味着, 我们在每一步, 都会选择最大的 Q Value。 但实际上, 我们不太可能在每一步都选择最大的 Q value, Reinforcement Learing 是要最求最大总回报, 而不是每一步都是最大的回报。每一步都要求最大回报,往往是不可能的。

【RL道】在生活中, 我们也有体验,不同方法往往互有利弊,我们只能选择最符合我们当下(State)的方法, 而不可能处处占巧。

【RL道】有时候,我们会遇到一些人鼓吹一些方法论,产品很有效。其实,再他们的宣称中,他们的假设是在每一步都用了最优值,但是, 这种情况不可能发生。一旦应用到现实生活中,收益就会大打折扣。 比如熬夜看书,也许在某些日子你的收益很高,但是你不可能每天都睡4个小时,最后的收益也许还不如正常的作息。

Overestimation Cause

Double DQN

为了解决上面提到的问题, 我们可以把选择 Action 和估计 Q Value 分开使用2个网络。 在DQN 中本来就算有 Target Network 和 Current Netwrok, 所以自然而然的, 我们可以选择用 Target Network 选择 Action, 用 Current Netwrok 计算 Q value。

  • Current Network : Q_{\phi}
  • Target Network : Q_{\phi^{\prime}}
Double DQN
# gradient_steps: 在N步以后, 更新Target Network 的参数
for _ in range(gradient_steps):

    with torch.no_grad():
    
        # Step1 生成一些新的 Transation,加入Replay Buffer
        #       这部分不重要, 先跳过
        pass 
            
        # Step2  从 Replay Buffer 中取一些 Sample
        state, action, state_prime, reward = replay_buffer(sample_size)
        
        # Step3  Compute lable by Target Network 
        # 基于 State Prime, 找出所有可能 Action 的 Q value,
        # 假设 ACTION_SPACE 已知, 类型为 list
        Q_of_next_action = []
        for next_action in ACTION_SPACE:
            Q = target_net(state_prime, next_action)
            Q_of_next_action.append(Q)   
            
        # 找出 Q 值最大的 Action , 作为 Action Prime
        new_action_index = argmax(Q_of_next_action)
        action_prime = ACTION_SPACE[action_prime]
        
        # GAMMA is a hyper-prarameter, normally 0.99
        target_q_values = reward + GAMMA * current_net(state_prime, action_prime)
    
    # Step 4 更新 Current Network 参数
    current_q_values = current_net(state)
    loss = smooth_l1_loss(current_q_values, target_q_values)
    
    # Optimize the policy (Current Network)
    policy.optimizer.zero_grad() 
    loss.backward()

# 在 N 步以后,更新 target_net 参数
target_net.param = current_net.param.copy()
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容