Reinforcement learning environment: gym, 'CartPole-v1'
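Before the tabular agent below, it can help to sanity-check the environment with a random policy. This is a minimal sketch, assuming the classic gym (pre-0.26) API that the rest of the script uses, where reset() returns the observation and step() returns four values:

import gym

env = gym.make('CartPole-v1')
observation = env.reset()
total_reward = 0
while True:
    action = env.action_space.sample()  # uniformly random action: 0 (push left) or 1 (push right)
    observation, reward, done, _ = env.step(action)
    total_reward += reward
    if done:
        break
print('random-policy episode reward:', total_reward)
env.close()

A random policy usually only survives a few dozen steps, which gives a baseline for the trained Q-table below.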
from collections import defaultdict
import gym
import numpy as np
import joblib
from pathlib import Path
from random import random
# Bin edges for discretizing the 4-dimensional continuous observation
# (cart position, cart velocity, pole angle, pole angular velocity).
# linspace(..., num=k)[1:-1] keeps only the interior edges, so np.digitize
# maps each dimension to one of (k - 1) integer buckets.
cart_pos_bin = np.linspace(-2.4, 2.4, num=6)[1:-1]
cart_vel_bin = np.linspace(-3, 3, num=4)[1:-1]
pole_ang_bin = np.linspace(-0.21, 0.21, num=8)[1:-1]
pole_vel_bin = np.linspace(-2.0, 2.0, num=6)[1:-1]
def state_coding(observation):
    """Map a continuous observation to a discrete state tuple via the bin edges above."""
    cart_pos = np.digitize([observation[0]], cart_pos_bin)[0]
    cart_vel = np.digitize([observation[1]], cart_vel_bin)[0]
    pole_ang = np.digitize([observation[2]], pole_ang_bin)[0]
    pole_vel = np.digitize([observation[3]], pole_vel_bin)[0]
    return (cart_pos, cart_vel, pole_ang, pole_vel)
def choose_action(s):
    # Stochastic policy: pick action 0 with probability softmax(Q[s][:2])[0], else action 1.
    return 0 if random() < softmax_for_choose(Q_table[s][:2])[0] else 1
def softmax_for_backoff(x):
    # Softmin-style weights (larger Q-values get smaller weight), used for the
    # conservative state-value estimate in E(s).
    return np.exp(-x) / np.exp(-x).sum()

def softmax_for_choose(x):
    # Standard softmax over the two action values, used as action probabilities.
    return np.exp(x) / np.exp(x).sum()
def E(s):
    # Expected value of state s: the two action values weighted by the softmin
    # distribution above (a pessimistic estimate of the state value).
    qs = Q_table[s][:2]
    return qs.dot(softmax_for_backoff(qs))
def update_action_chain(action_chain):
    # Back up values along a finished episode. Each entry is [state, action chosen
    # in that state, reward received on entering it]; slot -1 of every Q_table row
    # caches E(state).
    s_last = action_chain[-1][0]
    Q_table[s_last] = np.array([0., 0., -5.])  # terminal state: cached value is the -5 penalty
    for i in reversed(range(len(action_chain) - 1)):
        s, action, reward = action_chain[i]
        s_, action_, reward_ = action_chain[i + 1]
        Q_table[s][action] = reward_ + Q_table[s_][-1]  # step reward + successor's cached value
        Q_table[s][-1] = E(s)
def test(n=50, visible=False):
    # Run n greedy episodes (argmax over the two action values) and return the mean score.
    # Uses the classic gym API, where reset() returns the observation and
    # step() returns (observation, reward, done, info).
    scores = []
    for _ in range(n):
        score = 0
        observation = env.reset()
        s = state_coding(observation)
        action = np.argmax(Q_table[s][:2])
        while True:
            if visible:
                env.render()
            observation_, _, done, _ = env.step(action)
            s_ = state_coding(observation_)
            action_ = np.argmax(Q_table[s_][:2])
            score += 1
            if done:
                scores.append(score)
                break
            s, action = s_, action_
    return np.mean(scores)
def train(n=100):
    # Run n training episodes with the stochastic policy, then report the greedy test score.
    for _ in range(n):
        action_chain = []
        observation = env.reset()
        s = state_coding(observation)
        action, reward = choose_action(s), 1
        while True:
            observation_, reward_, done, _ = env.step(action)
            s_ = state_coding(observation_)
            action_ = choose_action(s_)
            if done:
                reward_ = -5  # penalty for dropping the pole or leaving the track
            action_chain.append([s_, action_, reward_])
            if done:
                update_action_chain(action_chain)
                break
            s, action, reward = s_, action_, reward_
    return test()
env = gym.make('CartPole-v1')
# One row per discretized state: [Q(s, a=0), Q(s, a=1), cached E(s)].
Q_table = defaultdict(lambda: np.zeros((3,)))

# Keep training until the greedy policy solves CartPole-v1 (mean score 500,
# the episode cap), then render episodes forever.
while True:
    if train(100) == 500 and test(n=20000) == 500:
        while True:
            test(n=1, visible=True)
env.close()
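The joblib and Path imports at the top are not used anywhere in the script; presumably they were meant for persisting the learned Q_table. The snippet below is a sketch of one way to do that, reusing the imports already at the top of the script; the file name q_table.joblib is my own choice, not from the original. Because a defaultdict built around a lambda cannot be pickled, the table is converted to a plain dict before saving and wrapped in a fresh defaultdict on load.

model_path = Path('q_table.joblib')

# Save: convert the defaultdict (its lambda default_factory is not picklable)
# into a plain dict of state-tuple -> value array before dumping.
joblib.dump(dict(Q_table), model_path)

# Load: rebuild the defaultdict around the restored entries so that unseen
# states still default to np.zeros((3,)).
Q_table = defaultdict(lambda: np.zeros((3,)), joblib.load(model_path))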