强化学习之路2——Sarsa算法

Sarsa算法

最近在学强化学习,看了不少的教程,还是觉得莫烦大神的强化学习教程写的不错。所以,特意仔细研究莫烦的RL代码。在这贴上自己的理解。

莫烦RL教程:https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/

代码:https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents

1.伪代码

下图的是Sarsa算法的伪代码:


Sarsa算法的伪代码

下图对比了Sarsa算法和Q-Learning算法


对比Sarsa算法和Q-Learning算法

2.迷宫游戏——Sarsa算法

2.1Sarsa算法实现

Qlearing和Sarsa更新Q表的不同之处在于,QLearning使用的Q现实是用的Q(S_)中的最大值(下一步不一定使用该(S_,A_),只是想象的),
而Sarsa使用的是下一步将要采用的Q(S_,A_)

# 编写了一个RL父类,方便Q-learning 和 Sarsa 子类继承。
# RL类中前一节都已解读过
class RL(object):
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space  # a list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy

        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append new state to q table
            self.q_table = self.q_table.append(
                pd.Series(
                    [0]*len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            )

    def choose_action(self, observation):
        self.check_state_exist(observation)
        # action selection
        if np.random.rand() < self.epsilon:
            # choose best action
            state_action = self.q_table.loc[observation, :]
            # some actions may have the same value, randomly choose on in these actions
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose random action
            action = np.random.choice(self.actions)
        return action
    # Q-learning 和 Sarsa的learn函数都不一样,所以需要各自编写
    def learn(self, *args):
        pass


# off-policy
class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        # super()继承QLearning的父类RL
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update


# on-policy
class SarsaTable(RL):

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        # super()继承Sarsa的父类RL
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        # 检查S_是否在表中
        self.check_state_exist(s_)
        # Q现实
        q_predict = self.q_table.loc[s, a]
        # Q估计
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        # 更新Q表
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update

2.2主函数

Q-learning是获取S,根据S选择A,使用A得到R和S_。之后使用max(Q(S_))来更新Q(S,A)。
更新使用的Q(S_,A_),下一步时不一定使用,这里只是想象的。

Sarsa是通过S、A,使用A得到R和S_,再根据S_选择A_。这个A_下一步肯定会使用。
哈哈,一个有趣的事,Sarsa使用的(S,A,R,S_,A_),连起来刚好就是Sarsa的拼写。

from maze_env import Maze
from RL_brain import SarsaTable

def update():
    # 训练次数
    for episode in range(100):
        # 初始化环境
        observation = env.reset()

        # 根据state选择action
        action = RL.choose_action(str(observation))

        # 时间步,回合
        while True:

            # 刷新环境
            env.render()

            # 采取action,返回 下一步的状态S_, 奖励R, 游戏结束flag
            observation_, reward, done = env.step(action)

            # 根据S_选择动作
            action_ = RL.choose_action(str(observation_))

            # Sarsa根据(S,A,R,S_,A_)来学习
            RL.learn(str(observation), action, reward, str(observation_), action_)

            # 交换 S and A
            observation = observation_
            action = action_

            # 如果游戏结束,结束本次训练
            if done:
                break

    # end of game
    print('game over')
    env.destroy()

if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(actions=list(range(env.n_actions)))

    env.after(100, update)
    env.mainloop()

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。