Sarsa
Sarsa原理
Sarsa的决策过程和Q-Learning类似,都是在Q表中挑选值较大的动作值施加在环境中来换取奖惩。不同之处在于更新方式。
如下图所示,在状态s采取行动a,到达下一个状态s'时,Q-Learning算法会根据Q(s')的最大值,假设自己走能使maxQ(s')这条路,来更新刚刚走过的Q(s,a)。此时在s‘的agent还没做出任何决策。
与Q-Learning不同,Sarsa会在状态s’做出实际的行为(该行为并不一定能使Q(s')最大化),并根据实际做出行为的Q值来更新刚刚走过的Q(s,a)。
Sarsa是一种on-policy在线学习算法,Q-learning是一种off-policy离线学习算法。
Q-Learning更新状态s时只看到maxQ(s'),忽视掉该行为可能带来的惩罚,因此它是一个大胆、贪婪的策略。Sarsa算法在接近收敛时,允许对探索性的行动进行可能的惩罚(Q-Learning会直接忽略)这使得Sarsa算法更加保守。
Sarsa算法更新
还是agent走迷宫的例子
与Q-Learning的不同在于
- 两个choose_action,第一个在循环外面(真的做出了行为并刷新了环境
- RL.learn(str(observation), action, reward, str(observation_), action_)这里多了一个action_
def update():
for episode in range(100):
# 初始化环境
observation = env.reset()
# Sarsa 根据 state 观测选择行为
action = RL.choose_action(str(observation))
while True:
# 刷新环境
env.render()
# 在环境中采取行为, 获得下一个 state_ (obervation_), reward, 和是否终止
observation_, reward, done = env.step(action)
# 根据下一个 state (obervation_) 选取下一个 action_
action_ = RL.choose_action(str(observation_))
# 从 (s, a, r, s, a) 中学习, 更新 Q_tabel 的参数 ==> Sarsa
RL.learn(str(observation), action, reward, str(observation_), action_)
# 将下一个当成下一步的 state (observation) and action
observation = observation_
action = action_
# 终止时跳出循环
if done:
break
# 大循环完毕
print('game over')
env.destroy()
if __name__ == "__main__":
env = Maze()
RL = SarsaTable(actions=list(range(env.n_actions)))
env.after(100, update)
env.mainloop()
Sarsa思维决策
定义一个RL父类
import numpy as np
import pandas as pd
class RL(object):
def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
... # 和 QLearningTable 中的代码一样
def check_state_exist(self, state):
... # 和 QLearningTable 中的代码一样
def choose_action(self, observation):
... # 和 QLearningTable 中的代码一样
def learn(self, *args):
pass # 每种的都有点不同, 所以用 pass
定义Q-Learning子类
class QLearningTable(RL): # 继承了父类 RL
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy) # 表示继承关系
def learn(self, s, a, r, s_): # learn 的方法在每种类型中有不一样, 需重新定义
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, :].max()
else:
q_target = r
self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
定义Sarsa子类
class SarsaTable(RL): # 继承 RL class
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy) # 表示继承关系
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # q_target 基于选好的 a_ 而不是 Q(s_) 的最大值
else:
q_target = r # 如果 s_ 是终止符
self.q_table.loc[s, a] += self.lr * (q_target - q_predict) # 更新 q_table
Sarsa(lambda)
Sarsa(lambda)是一种Sarsa的提速算法
上述Sarsa算法是单步更新法,即每次获取到reward,只更新获取到reward的前一步
Sarsa-lambda就是更新获取到reward的前lambda步
lambda就是一个衰变值,取值在0-1之间。当lambda取0,就变成了Sarsa的单步更新,当lambda取 1,就变成了回合更新。lambda取值越大,离宝藏越近的步更新力度越大。
Sarsa(lambda)例子
还是那个走迷宫的例子
SarsaLambdaTable在算法更新迭代的部分,和SarsaTable 是一样的,思维决策部分有所不同,如下图所示:
从上图可以看出,和Sarsa相比,Sarsa(lambda)算法中多了一个矩阵E (eligibility trace),它是用来保存在路径中所经历的每一步,因此在每次更新时也会对之前经历的步进行更新
"""
This part of code is the Q learning brain, which is a brain of the agent.
All decisions are made in here.
View more on my tutorial page: https://morvanzhou.github.io/tutorials/
"""
import numpy as np
import pandas as pd
# 预设值里增加了trace_decay=0.9,也就是lambda的值
class SarsaLambdaTable(RL): # 继承RL父类
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
# backward view, eligibility trace.
self.lambda_ = trace_decay
self.eligibility_trace = self.q_table.copy()
def check_state_exist(self, state):
if state not in self.q_table.index:
# append new state to q table
to_be_append = pd.Series(
[0] * len(self.actions),
index=self.q_table.columns,
name=state,
)
self.q_table = self.q_table.append(to_be_append)
# also update eligibility trace
self.eligibility_trace = self.eligibility_trace.append(to_be_append)
def learn(self, s, a, r, s_, a_):
self.check_state_exist(s_)
q_predict = self.q_table.loc[s, a]
if s_ != 'terminal':
q_target = r + self.gamma * self.q_table.loc[s_, a_] # next state is not terminal
else:
q_target = r # next state is terminal
error = q_target - q_predict
# increase trace amount for visited state-action pair
# Method 1:
# self.eligibility_trace.loc[s, a] += 1
# Method 2:
self.eligibility_trace.loc[s, :] *= 0
self.eligibility_trace.loc[s, a] = 1
# Q update
self.q_table += self.lr * error * self.eligibility_trace
# decay eligibility trace after update
self.eligibility_trace *= self.gamma*self.lambda_
除了图中和上面代码这种更新方式, 还有一种会更加有效率. 我们可以将上面的这一步替换成下面这样:
# 上面代码中的方式:
self.eligibility_trace.ix[s, a] += 1
# 更有效的方式:
self.eligibility_trace.ix[s, :] *= 0
self.eligibility_trace.ix[s, a] = 1
他们的不同之处可以用这张图来概括:
这是针对于一个 state-action 值按经历次数的变化,最上面是经历 state-action 的时间点,第二张图是使用这种方式所带来的 “不可或缺性值”:
self.eligibility_trace.ix[s, a] += 1
下面图是使用这种方法带来的 “不可或缺性值”:
self.eligibility_trace.ix[s, :] *= 0; self.eligibility_trace.ix[s, a] = 1
最后不要忘了,eligibility trace只是记录每个回合的每一步,新回合(episode)开始的时候需要将 Trace 清零
for episode in range(100):
...
# 新回合, 清零
RL.eligibility_trace *= 0
while True: # 开始回