Contents

The Sarsa-lambda algorithm: an improvement on Sarsa

Problems with Sarsa

Improvement 2: Sarsa Lambda

References


My advisor has started making me write weekly reports every day, so there is no more slacking off from now on; my mind is exploding...

The Sarsa-lambda algorithm: an improvement on Sarsa

I will skip the full algorithm flow and the mathematical derivation here; the key is to be clear about what lambda means:

  • If lambda = 0, Sarsa-lambda degenerates to plain Sarsa: only the very last step taken before receiving the reward gets updated.
  • If lambda = 1, Sarsa-lambda updates every step experienced before receiving the reward.

lambda controls how many of the preceding steps you want to credit; it is a decay value. Just like the reward decay value gamma, lambda decays the footprint of the steps already taken (the eligibility trace).
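For reference, this is the standard Sarsa(λ) update with a replacing eligibility trace E, which is also the form implemented in the RL_brain.py code further down (α is the learning rate, γ the reward decay):

\[
\begin{aligned}
\delta_t &= r_{t+1} + \gamma\, Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \\
E(s_t, \cdot) &\leftarrow 0, \quad E(s_t, a_t) \leftarrow 1 && \text{(replacing trace)} \\
Q(s, a) &\leftarrow Q(s, a) + \alpha\, \delta_t\, E(s, a) && \text{for all } (s, a) \\
E(s, a) &\leftarrow \gamma \lambda\, E(s, a) && \text{(decay all traces)}
\end{aligned}
\]

With λ = 0 the trace dies immediately and only (s_t, a_t) is updated, which is exactly plain Sarsa; with λ = 1 the whole trail of visited state-action pairs keeps receiving credit.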

Problems with Sarsa

Based on the previous training run: https://xduwq.blog.csdn.net/article/details/105826501

That run exposed a serious problem with Sarsa: because Sarsa is a conservative algorithm, it often gets stuck in a local optimum (you could also describe it as a kind of overfitting), hesitating at some step and refusing to move forward. When the experiment succeeds, training finishes within roughly 1,000 to 3,000 steps; but once the agent falls into a local optimum, even tens of thousands of steps will not get it there, and the run can be considered a failure!

Improvement 1: reduce the value of the negative reward, which noticeably improves the experimental results!

Here the negative reward is reduced to -0.5!
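The exact change depends on how the maze environment assigns rewards; as a purely hypothetical sketch (the function and parameter names below are illustrative, not the actual maze_env.py), softening the trap penalty might look like this:

# Hypothetical reward function for the maze task -- names and the goal/trap values are assumptions.
def maze_reward(next_state, goal_state, trap_states):
    """Return (reward, done) for a transition into next_state."""
    if next_state == goal_state:
        return 1.0, True      # reached the goal: positive reward, episode ends
    if next_state in trap_states:
        return -0.5, True     # improvement 1: penalty softened from -1 to -0.5
    return 0.0, False         # ordinary move: no reward, episode continues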

Improvement 2: Sarsa Lambda

The main changes are in RL_brain.py; the rest of the code is the same as for Sarsa!

import numpy as np
import pandas as pd


class RL(object):
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space  # a list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append new state to q table
            self.q_table = self.q_table.append(
                pd.Series(
                    [0] * len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            )

    def choose_action(self, observation):
        self.check_state_exist(observation)
        # action selection
        if np.random.rand() < self.epsilon:
            # choose best action
            state_action = self.q_table.loc[observation, :]
            # some actions may have the same value, randomly choose one of them
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # choose random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        pass


# off-policy Q-Learning
class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':  # next state is not terminal
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # Q-Learning takes the max over the next state's actions
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update


# on-policy Sarsa(lambda)
class SarsaLambdaTable(RL):
    # initialization
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9, trace_decay=0.9):
        super(SarsaLambdaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)
        self.lambda_ = trace_decay
        self.eligibility_trace = self.q_table.copy()

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # add the new state to both the q table and the eligibility trace
            to_be_append = pd.Series(
                [0] * len(self.actions),
                index=self.q_table.columns,
                name=state,
            )
            self.q_table = self.q_table.append(to_be_append)
            self.eligibility_trace = self.eligibility_trace.append(to_be_append)

    # learn and update the parameters
    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)  # make sure the next state exists in both tables
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':  # next state is not terminal
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # Sarsa uses the value of the action actually chosen next
        else:
            q_target = r  # next state is terminal
        # plain Sarsa would update only one cell:
        # self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
        error = q_target - q_predict

        # eligibility trace: reset this state's trace, then mark the taken action (replacing trace)
        self.eligibility_trace.loc[s, :] *= 0
        self.eligibility_trace.loc[s, a] = 1

        # update the whole Q table, weighted by each state-action pair's eligibility
        self.q_table += self.lr * error * self.eligibility_trace

        # decay all traces after the update
        self.eligibility_trace *= self.gamma * self.lambda_
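For completeness, here is a minimal sketch of the training loop that drives SarsaLambdaTable. It assumes the file above is saved as RL_brain.py and that the maze environment exposes env.reset() and env.step(action) returning (observation, reward, done), as in Morvan's tutorial; those interface details are assumptions, not part of this post's code. The important difference from Q-Learning is that the next action a_ is chosen before calling learn:

# minimal training-loop sketch -- the env interface (reset/step) is assumed, not shown in this post
from RL_brain import SarsaLambdaTable

def update(env, RL, episodes=100):
    for episode in range(episodes):
        observation = env.reset()                    # start a new episode
        action = RL.choose_action(str(observation))  # choose the first action
        RL.eligibility_trace *= 0                    # clear traces at the start of each episode
        while True:
            observation_, reward, done = env.step(action)      # interact with the environment
            action_ = RL.choose_action(str(observation_))      # pick the next action first (on-policy)
            RL.learn(str(observation), action, reward, str(observation_), action_)
            observation, action = observation_, action_        # continue with the chosen action
            if done:
                break

# usage: RL = SarsaLambdaTable(actions=list(range(n_actions))); update(env, RL)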

References

This post mainly reproduces Morvan Python (莫烦Python)'s tutorial: https://zhuanlan.zhihu.com/p/24860793

And, as always, a final bow to the great Morvan.
