Prioritized DQN

1.简介

Prioritized DQN 是为了解决当在memory中均匀采样时候学习效率低下的问题。原因主要有两个:

1.我们想让new transition立马用于更新,因为这样的new experience对于explore很重要。

2.我们想让large td-error的transition立马用于更新(比如有99次失败的经历和1次成功的经历,我们希望立马学习这个成功的经历)

显然uniform sampling无法做到这两点。
于是便有了伟大的Prioritized Experience Replay.

论文在这里。
代码在这里。
简单介绍在这里。

下面我将分享自己学习这篇论文的时候一些经验。请读完论文和简单介绍后,如有困惑,再阅读以下部分。

2.关键点

Prioritized DQN能够成功的主要原因有两个:sum tree这种数据结构带来的采样的O(log n)的高效率,和Weighted Importance sampling的正确估计。后者,我现在还没有完全搞明白原理。

我简单由谈下自己对于sum tree数据结构的理解。 sum tree存储的元素是样本的优先级,其思想是根据累积概率密度(因此叫sum)来抽取样本。从最左方开始,优先级累积逐渐增大,如果我们的段>左子孩子,(递归地)就在右子孩子中寻找(这时候要做减法,以便又是新的累积优先级)。

如果把累积优先级(离散地)画出来,我们就会发现,高优先级对应的直线段斜率最大,被抽取到的概率最大。(可以以下图为例,自己在每个段中取数字进行验证)。

sum tree.png

3.代码解读

原代码注释较少,我这里列出几个点,方便大家阅读代码。

  • 代码实现的是DQN, 而不是Double DQN。
  • 在插入new transition更新sum tree的时候, 是根据新样本与原来位置的样本的优先级差来更新。(详见SumTree.add)
  • 在memory中插入new transition的时候,给予new transition最大的优先级,因为我们想让new experience立马用于学习。(详见Memory.store)
  • 在memroy中抽取n个samples后,我们会根据nn计算出来的TD-error来更新那些抽取到的样本的优先级,这样的话new transition就不会一直被学习。(详见Memory.batch_update)。

大家最好照着源码自己敲一编(时间大概2~3小时),我这里给出自己在搬砖过程中写的一点注释(也可以自己下载,照着看)。

import numpy as np
import tensorflow as tf

np.random.seed(1)
tf.set_random_seed(1)


class SumTree(object):
    data_pointer = 0

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        # [------------ Parent nodes -------------][------ leaves to recode priority ----------]
        #            size: capacity - 1                  size: capacity
        self.data = np.zeros(capacity, dtype=object)
        # [------------ data frame ---------------]
        #            size: capacity

    # memory store_transition的时候使用
    def add(self, p, data):                     # p is the new priority, data is transition
 
    # memory batch_update的时候使用
    def update(self, tree_idx, p):

    # memory 分段采样的时候使用
    def get_leaf(self, v):
        parent_idx = 0
        while True:
            cl_idx = 2 * parent_idx + 1
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):   # 此时parent就是叶子结点
                leaf_idx = parent_idx
                break
            else:
                if v <= self.tree[cl_idx]:  # <= 左子孩子,就向左前进
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]  # > 右子孩子,需要重新当作一颗累积树,因此要减去左子孩子的值
                    parent_idx = cr_idx

        data_idx = leaf_idx - (self.capacity + 1)
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        return self.tree[0]   # the root


class Memory(object):
    epsilon = 0.01     # small amount to avoid zero priority
    alpha = 0.6         # [0, 1] convet the importance of TD error to priority
    beta = 0.4          # importance sampling, from intial value increasing to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1   # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    # 向sum tree的transitions 中加入 new transition
    def store(self, transition):

    # 从sum tree中采取n个样本
    def sample(self, n):

    # 更新采样过的样本的priority(基于abs_error)
    def batch_update(self, tree_idx, abs_error):
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 背景 一年多以前我在知乎上答了有关LeetCode的问题, 分享了一些自己做题目的经验。 张土汪:刷leetcod...
    土汪阅读 14,351评论 0 33
  • Android 自定义View的各种姿势1 Activity的显示之ViewRootImpl详解 Activity...
    passiontim阅读 175,588评论 25 709
  • 风不停地敲打着半掩的窗户,室内凉爽了许多,吃完药虽口苦但胃暖了许多,入眠也就更容易了。 半夜忽被重...
    雪舞冰封阅读 1,670评论 0 1
  • 十里春风透衣裳, 穿过风雨的忧伤, 卸下负累的行囊。 在芳菲深处与你邂逅流年暗香, 暖阳扫枯黄一场风花梦蝶的乐章,...
    陶韵阅读 910评论 0 0
  • 阳光正好 微风不燥 你刚好微笑 我恰好爱上
    锑锅盖盖儿阅读 2,271评论 0 0