Q-Learning算法
整个算法就是一直不断更新 Q table 里的值, 然后再根据新的值来判断要在某个 state 采取怎样的 action. Qlearning 是一个
off-policy 的算法, 因为里面的 max action 让 Q table
的更新可以不基于正在经历的经验(可以是现在学习着很久以前的经验,甚至是学习他人的经验).
Q-learning中的Q函数
- s: 当前状态state
- a: 从当前状态下,采取的行动action
- s’: 今次行动所产生的新一轮state
- a’: 次回action
- R: 本次行动的奖励reward
- α: 学习速率,比如取0.01
- γ : 折扣因数,表示牺牲当前收益,换区长远收益的程度。比如取0.9
要走的迷宫矩阵
是一个 5*6 的矩阵其中 0 表示可走,1 表示障碍物
代码整体实现
代码中 q_table 样式
up down left right (0, 0) -0.550747 -0.533564 -0.644566 -0.410420 (0, 1) -
0.811724 -0.344330 -0.362692 -0.354689 (0, 2) -0.510908 -0.571715 -0.354768 -
0.354741 (1, 1) -0.297905 -0.247055 -0.478024 -0.537521 (0, 3) -0.599642 -
0.512899 -0.354843 -0.354771 (0, 4) -0.546996 -0.470504 -0.354866 -0.354824 (0,
5) -0.370004 -0.361741 -0.354866 -0.397040 (2, 1) -0.259938 -0.109431 -0.464743
-0.526687 (3, 1) -0.176143 -0.403094 -0.368366 0.076880 (3, 2) -0.369096 -
0.115697 -0.109689 0.296391 (4, 2) -0.069825 -0.237857 -0.136630 -0.087706 (4, 3
) -0.018432 -0.078908 -0.068174 -0.066634 (4, 4) -0.117762 -0.079410 -0.066807 -
0.066656 (3, 3) 0.533487 -0.066857 -0.045965 -0.223937 (2, 3) -0.164942 0.020808
-0.152385 0.767553 (4, 5) -0.069677 -0.069658 -0.066724 -0.098813 (2, 4) -
0.049835 -0.063313 0.059299 0.993430 (2, 5) 0.000000 0.000000 0.000000 0.000000
q-table 为 DataFrame 类型,index 表示状态( state ),对应迷宫矩阵的索引,columns 表示动作( action )
首先运行 train()
import numpy as np import pandas as pd import random import pickle from
sklearn.utilsimport shuffle # 迷宫矩阵 maze = np.array( [[0, 0, 0, 0, 0, 0, ], [1, 0
,1, 1, 1, 1, ], [1, 0, 1, 0, 0, 0, ], [1, 0, 0, 0, 1, 1, ], [0, 1, 0, 0, 0, 0,
]] ) print(pd.DataFrame(maze))# 起点 start_state = (0, 0) # 终点 target_state = (2,
5) # 要保存的q_table的文件路径 q_learning_table_path = 'q_learning_table.pkl' class
QLearningTable: def __init__(self, alpha=0.01, gamma=0.9): # self.alpha
self.gamma 是Q函数中需要用到的两个参数 self.alpha = alpha self.gamma = gamma # 奖励(惩罚)值
self.reward_dict = {'reward_0': -1, 'reward_1': -0.1, 'reward_2': 1} # 动作
self.actions = ('up', 'down', 'left', 'right') self.q_table =
pd.DataFrame(columns=self.actions)def get_next_state_reward(self,
current_state, action): """ :param current_state: 当前状态 :param action: 动作
:return: next_state下个状态,reward奖励值,done游戏是否结束 """ done = False if action == 'up'
: next_state = (current_state[0] - 1, current_state[1]) elif action == 'down':
next_state = (current_state[0] + 1, current_state[1]) elif action == 'left':
next_state = (current_state[0], current_state[1] - 1) else: next_state =
(current_state[0], current_state[1] + 1) if next_state[0] < 0 or next_state[0]
>= maze.shape[0] or next_state[1] < 0 or next_state[1] >= maze.shape[1] \ or
maze[next_state[0], next_state[1]] == 1: # 如果出界或者遇到1,保持原地不动 next_state =
current_state reward = self.reward_dict.get('reward_0') #
此处done=True,可理解为进入陷阱,游戏结束,done=False,可理解为在原地白走一步,受到了一次惩罚,但游戏还未结束 # done = True
elif next_state == target_state: # 到达目标 reward = self.reward_dict.get('reward_2'
) done =True else: # maze[next_state[0],next_state[1]] == 0 reward =
self.reward_dict.get('reward_1') return next_state, reward, done #
根据返回的reward和next_state更新q_table def learn(self, current_state, action, reward,
next_state): self.check_state_exist(next_state) q_sa =
self.q_table.loc[current_state, action] max_next_q_sa =
self.q_table.loc[next_state, :].max()# 套用公式:Q函数 new_q_sa = q_sa + self.alpha *
(reward + self.gamma * max_next_q_sa - q_sa)# 更新q_table
self.q_table.loc[current_state, action] = new_q_sa#
如果state不在q_table中,在q_tabel中添加该state def check_state_exist(self, state): if state
not in self.q_table.index: self.q_table.loc[state] =
pd.Series(np.zeros(len(self.actions)), index=self.actions)# 旋转执行动作 def
choose_action(self, state, random_num=0.8): series =
pd.Series(self.q_table.loc[state])#
以0.8的概率执行action,尝试更多的可能性。总是做最好的选择,意味着你可能会错过一些从未探索的道路。 #
为了避免这种情况,可以添加一个随机项,而未必总是选择对当前来说最好的action。 if random.random() > random_num:
action = random.choice(self.actions)else: #
因为pd.Series数据的最大值可能出现多个,而argmax()只取第一个,故使用sklearn中的shuffle将其打乱顺序, #
随机选取最大值的索引,选取最大值的action有利于q_table快速收敛 ss = shuffle(series) action = ss.argmax()
return action # 训练 def train(): q_learning_table = QLearningTable() # 迭代次数
iterate_num =500 for _ in range(iterate_num): # 每次迭代 从start_state开始
current_state = start_statewhile True: #
先检查current_state是否已在q_table中,注意将current_state以为字符串的形式存到q_table中
q_learning_table.check_state_exist(str(current_state))# 获取当前状态的执行动作 action =
q_learning_table.choose_action(str(current_state))#
根据当前状态current_state和动作action,获取下个状态next_state,奖励值reward以及游戏是否结束done next_state,
reward, done = q_learning_table.get_next_state_reward(current_state, action)#
开始学习,更新q_table q_learning_table.learn(str(current_state), action, reward,
str(next_state))# 如果游戏结束,跳出while循环,进入下次迭代 if done: break # current_state跳转到下个状态
current_state = next_state print('game over') #
保存对象q_learning_table到文件q_learning_table_path with open(q_learning_table_path,
'wb') as pkl_file: pickle.dump(q_learning_table, pkl_file)
tain() 运行完后生成一个文件 q_learning_table.pkl,里面存放到是训练好的 QLearningTable 对象模型
然后运行下面一段代码 predict() 用来测试模型
# 预测 def predict(): # 读取q_table with open(q_learning_table_path, 'rb') as
pkl_file: q_learning_table = pickle.load(pkl_file) print('start_state:{}'
.format(start_state)) current_state = start_state step =0 while True: step =
step +1 action = q_learning_table.choose_action(str(current_state), random_num=1
)# 预测阶段,reward用不到了,故使用_代替 next_state, _, done =
q_learning_table.get_next_state_reward(current_state, action)# 输出动作和下个状态 print(
'step:{step}, action: {action}, state: {state}'.format(step=step,
action=action, state=next_state))# 如果done或者步数超过100,游戏结束退出 if done or step > 100:
if next_state == target_state: print('success') else: print('fail') break #
跳转到下个状态 else: current_state = next_state
运行结果
start_state:(0, 0) step:1, action: right, state: (0, 1) step:2, action: down,
state: (1, 1) step:3, action: down, state: (2, 1) step:4, action: down, state: (
3, 1) step:5, action: right, state: (3, 2) step:6, action: right, state: (3, 3)
step:7, action: up, state: (2, 3) step:8, action: right, state: (2, 4) step:9,
action:right, state: (2, 5) success
参考
https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/2_Q_Learning_maze
<https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/tree/master/contents/2_Q_Learning_maze>
热门工具 换一换