{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Finding the optimal policy\n",
"\n",
"Let's try to find the optimal policy using Policy iteration and Value iteration in a simple grid world.\n",
"\n",
"The world and the optimal policy to be found are the following:\n",
"\n",
"\n",
"\n",
"Next we have the implementation of the gridworld (see previous notebook for explanations)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import random\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"#Define the state matrix\n",
"state_matrix = np.zeros((3,4))\n",
"state_matrix[0, 3] = 1\n",
"state_matrix[1, 3] = 1\n",
"state_matrix[1, 1] = -1\n",
"\n",
"#Define the reward matrix\n",
"reward_matrix = np.full((3,4), -0.04)\n",
"reward_matrix[0, 3] = 1\n",
"reward_matrix[1, 3] = -1\n",
"\n",
"transition_matrix = np.array([[0.8, 0.1, 0.0, 0.1],\n",
" [0.1, 0.8, 0.1, 0.0],\n",
" [0.0, 0.1, 0.8, 0.1],\n",
" [0.1, 0.0, 0.1, 0.8]])\n",
"\n",
"gamma = 0.999\n",
"\n",
"def prob_next_state(position,world_row=3,world_col=4):\n",
" \"Given a position in the grid-world, returns a list of possible next states\"\n",
" posible=[]\n",
" for action in range(4):\n",
" position_def = position\n",
" if(action == 0): new_position = (position[0]-1, position[1]) #UP\n",
" elif(action == 1): new_position = (position[0], position[1]+1) #RIGHT\n",
" elif(action == 2): new_position = (position[0]+1, position[1]) #DOWN\n",
" elif(action == 3): new_position = (position[0], position[1]-1) #LEFT\n",
" else: raise ValueError('The action is not included in the action space.')\n",
"\n",
" #Check if the new position is a valid position\n",
" if (new_position[0]>=0 and new_position[0]=0 and new_position[1] \"\n",
" elif(policy_matrix[row,col] == 2): \n",
" U[row,col]=0 \n",
" R[row,col]=-1 \n",
" #policy_string += \" v \" \n",
" elif(policy_matrix[row,col] == 3):\n",
" U[row,col]=-1 \n",
" R[row,col]=0 \n",
" #policy_string += \" < \"\n",
" elif(np.isnan(policy_matrix[row,col])): \n",
" U[row,col]=0 \n",
" R[row,col]=0 \n",
" #policy_string += \" # \"\n",
" plt.rcParams['figure.figsize'] = (4,3)\n",
" if V.size==0:\n",
" V=np.ones(policy_matrix.shape)\n",
" for x in lstates:\n",
" V[x.multi_index]=0.7\n",
" plt.matshow(V,extent=[0, 4, 0, 3],vmin=0, vmax=1)\n",
" else:\n",
" plt.matshow(V,extent=[0, 4, 0, 3])\n",
" plt.grid()\n",
" X, Y = np.meshgrid(np.arange(0.5, 4.5, 1), np.arange(0.5, 3.5, 1))\n",
" Q = plt.quiver(X, Y,U,R)\n",
"\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def return_expected_action(state):\n",
"    \"\"\"Return the index of the action with the largest expected one-step value.\n",
"\n",
"    For each action a in 0..3, the candidate next positions produced by\n",
"    prob_next_state(state.multi_index, 3, 4) are paired with the probabilities\n",
"    in row a of transition_matrix, and the score\n",
"    sum(p * (reward_matrix[s] + gamma * V[s])) is accumulated over those pairs.\n",
"\n",
"    Reads the globals transition_matrix, reward_matrix, gamma and V; V must\n",
"    already be defined when this function is called.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    state : object exposing a multi_index attribute with the (row, col) cell;\n",
"        presumably an np.nditer positioned on the grid -- confirm with callers.\n",
"\n",
"    Returns\n",
"    -------\n",
"    The np.argmax of the four scores, i.e. the first action whose score is maximal.\n",
"    \"\"\"\n",
"    action_values = []\n",
"    for action in range(4):\n",
"        successors = prob_next_state(state.multi_index, 3, 4)\n",
"        probabilities = transition_matrix[action, :]\n",
"        action_values.append(\n",
"            sum(prob * (reward_matrix[nxt] + gamma * V[nxt])\n",
"                for nxt, prob in zip(successors, probabilities))\n",
"        )\n",
"    return np.argmax(action_values)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Policy iteration algorithm"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Initial policy:\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAADMCAYAAABTJB73AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAADuZJREFUeJzt3X1wVOW9B/DvL5sXEmJIIAlvAakKucW0cKuX64ilirSNlCmdW6+j9tbpFC/j1DIpaquda2s73le9emundDqMUu3UW4bbaqsbi1kkmnpBeTNgAsKgtZdAgAAJSXY32WTzu3/sFsPb5pCek+ecfb6fmQxZ9uyZ75Pku8/Zs+fsEVUFEdklx3QAIhp7LD6RhVh8Igux+EQWYvGJLMTiE1nI1eKLSK2I7BeRgyLykJvrNk1E1onIcRFpMZ3FCyIyQ0QaRWSfiLSKSJ3pTG4SkXEisk1EdqfH90PTmdwmIiEReUdEwiMt61rxRSQEYA2AWwDMBXCHiMx1a/0+8CyAWtMhPDQI4H5V/TiA6wDcm2W/v34Ai1V1HoD5AGpF5DrDmdxWB2CfkwXdnPEXADioqh+oagLAegDLXVy/UaraBOCU6RxeUdV2Vd2V/r4HqT+g6WZTuUdTetM389JfWXP0mohUAfgCgKedLO9m8acDODTsdhuy6A/HJiIyC8BfA3jbbBJ3pTeFmwEcBxBR1Wwa348AfAfAkJOF3Sy+XOD/suYZ1RYiUgzgNwC+pardpvO4SVWTqjofQBWABSJSYzqTG0RkGYDjqrrT6WPcLH4bgBnDblcBOOLi+sljIpKHVOmfV9UXTOfxiqp2AXgd2bPPZiGAL4rIh0i9xF4sIr/M9AA3i78dwGwR+ZiI5AO4HcBLLq6fPCQiAuAZAPtU9UnTedwmIhUiUpr+vhDAEgDvmU3lDlX9rqpWqeospHq3WVX/IdNjXCu+qg4C+CaAV5HaMbRBVVvdWr9pIvIrAFsBVItIm4isMJ3JZQsBfBWp2aI5/bXUdCgXTQXQKCJ7kJqkIqo64tte2Up4Wi6RfXjkHpGFWHwiC7H4RBZi8YksNGLxR3Nyg4isdCee/2Tz2ACOL+icjs/JjD+akxuy+YebzWMDOL6gczS+3JEW0NT7fVl7cgORjRy9j58+5XYngKsArFHVBy+wzEqkn20KCvKvuXzmZJej+sPAYA7ych2dBxFIHF+wHT7cht7o0IXOmznLJR3Akz7k8UUAq1T1oh9IUT1npu5rKnC83iBpaqnDopqnTMfwDMcXbAs+fwg7dveNWPxL2qufhSc3EFnJyV79rD25gchWI+7cQ+rkhufSr/NzkDr5xtqTG4iygZO9+nuQ+jQWIsoSPHKPyEIsPpGFWHwiC7H4RBZi8YksFPjiDwzwtAGiSxX44q9Zd9p0BKLACXTxO7uS+OETpxCPZ+9JF0ReCHTxNzbG0N0zhNfejJuOQhQogS5+OBIFANSn/yUiZwJb/IEBxcbNMQBA/aYoeH0AIucCW/wt2/vQdTr12v5wexLvvNtvOBFRcAS2+OFzNu/rN8UMJSEKnsAW/72DCdy2vBgAcOffFePdfZzxiZwKZPFVFS/+fCqW3lwEAPjysmI8/9MphlMRBYeTD+LwHRFB7jnJ8/JG/JgxIkoL5IxPRH8ZFp/IQiw+kYVYfIO6e4aQTGbvgUedXUnTEegiArlzL8g++NMA6iNRvJw+DuHV9dMMJ3KPqqK5JYFwJIpwQxSfvq4Q//mDctOx6AJYfI8lk4q3dvalyxDD3gOJM/d97sYifPdfTmZ8/NfvKMGcK/O9jjlq8fgQGrfEEW6Ion5TDG1HBs/cN7c6Hw/984mMj//efRMxvig7Nzw3bo5i+tRcfOLj/ruqVKCLf9PCIvz2uam4dp7/frAAkEgo1qzrws/X96B1f+K8+xtej6Hh9cxHHC6+oci3xe/sSuJfn+rEhpd6zyr8n/1iQ8+I6/
j2N8owvsiLdOY98tgp3LiwEP/xPf/9fQa6+FXTclE1zb9DyM8XrL6nDKvvKTtrE79paxwDA0DdP07AvV8vzbiOqZWhMUp76cpKQ3j8kXI89v1J2N2aQLghinAkiu3NqaMo1z5RiRuvL8y4jtIJ2TnbA8CW+irk+HR4/m1Nlrni8jysursUq+4uRXfPEBreiGHbrj7MmpGLUCjYBx+JCObXFGB+TQEevm8i2o8Non5TFF2nk7hyVp7peMb4+ffK4htQclkObl1WjFuXFZuO4ompk3Nx91cmmI5BGfh0Q4SIvMTiE1mIxSeyEItPZCEWn8hCLD6RhVh8Igux+EQWYvGJLMTiE1mIxSeyEItPZCEWn8hCLD6RhVh8Igux+EQWYvGJLMTiE1mIxSeyEItPZCEWn8hCIxZfRGaISKOI7BORVhGpG4tgROQdJx+vPQjgflXdJSKXAdgpIhFV3etxNiLyyIgzvqq2q+qu9Pc9APYBmO51MCLyjqg6v0yziMwC0ASgRlW7z7lvJYCVAFBRUX7Nhv9+xL2UPtIbn4ziwmOmY3iG4wu2B+5/ADt29414CR/HV9IRkWIAvwHwrXNLDwCquhbAWgConjNTF9U8dQlxg6OppQ7ZOjaA47OFo736IpKHVOmfV9UXvI1ERF5zsldfADwDYJ+qPul9JCLympMZfyGArwJYLCLN6a+lHuciIg+N+BpfVd8E4N/r/RLRJeORe0QWYvGJLMTiE1mIxSeyEItPZCEWn8hCLD6RhVh8Igux+EQWYvGJLMTiE1mIxSeyEItPZCEWn8hCLD6RhVh8Igux+EQWYvGJLMTiE1nIt8VXVbTu78fJU0nTUWgU4vEhbHunz3QMughfFT+RUETeiKHu4Q7Mvu5PuHXFUZRO8FVEyqD92CCefv40vvS1dlRe/Uf84a246Uh0EY6vpOOVjhNJvLI5inBDFJE3Yujp/eiSXjOm5aL29iMZH/+120vwlS9f5nVMI3Y092HDS7147PvlpqNckKqiuSWB+kgU4UgU25v7z7r/hfpebNwcy7iOXz8zBRNKQl7GNObJn3Wi+qp8fGHJeNNRzmO0+KqKd9/rx649qa/hpQeA4yeT6O4dyriOW24e9DKiUYePDmLXnn4MDCjy8vz3Ceenu4ew691+7NzTj9b9ifPub92fQE5O5tzJLH4l19ySQF6u/35vgOHiiwgW31CExTcU4UePlmPvgQTCDTHUR6LYurMPC+YXoPHF6UhdzMc+y2uLsby22HSMiyqdEMKKO0uw4s4SxONDaNwST8/+MbQdGcSaf6/M2q0xJ37xk8mmI1yU8U39PxMRXF1dgKurC/DgqjKcOJnEK69F0XEyicpy38SkiygszMHSm8dj6c3j8ZN/U+xuTWD/++dvBZA/+LZR5ZNCuOu2EtMxaBREBPNrCjC/psB0FLoI7jInshCLT2QhFp/IQiw+kYVYfCILsfhEFmLxiSzE4hNZiMUnshCLT2QhFp/IQiw+kYVYfCILsfhEFmLxiSzE4hNZiMUnshCLT2QhFp/IQiw+kYVYfCILjVh8EVknIsdFpGUsAhGR95zM+M8CqPU4BxGNoRGLr6pNAE6NQRYiGiOuXVBDRFYCWAkAFRXlaGr5jlur9pXe+GQ0tdSZjuEZji/oHnC0lGvFV9W1ANYCQPWcmbqo5im3Vu0rTS11yNaxARyfLbhXn8hCLD6RhZy8nfcrAFsBVItIm4is8D4WEXlpxNf4qnrHWAQhorHDTX0iC7H4RBZi8YksxOITWYjFJ7IQi0/GJBKK9mODpmNYicUnY97YGsdvfx81HcNKLD4ZE45E8XIDi2+CayfpEF0KVUW4IYojxwbRGx1C8XjOQWOJP20yYu+BBD48NIhEAtjUFDMdxzosPhkRbvio7OEIN/fHGotPRgwv+yubYhgaUoNp7MPi05jrOJHE1h19EEndPtaRxPbmfrOhLMPi05g7dGQAG9dPw/XXjkNleQjbNlahp3fIdCyrcK
8+jblPfXIcAODRJ1Kf4XrNvHEm41iJMz6RhVh8Igux+EQWYvGJLOTb4necSOK5Dd04foJnb5G/DAwoXt8SC/QRh77Zq6+qaN2fQLghivpNMWzd0YcbFozDXX9/meloROjsSuL3m2MIR6J4tTGG091D2N04w3SsUTNafFXF5jfj+N3GKOojUXx46OzZfVtzPyb91R8zruPh1WW4754yL2Mac7o7iWhMMbkihFBITMdx3cSyECZX+Pf9+/9rG8D/vNyLcEMU/7u9D8nk2fd/+ouHMz5+5vRcNG+e6WHC0TNafBHBvLkFONw+iKPHB3GyM4me3o8O3aycFMLsK/IyrmNyhW82Wlz3g8dP4cdPn8aHOy7HjOmZfw5B9OKzU01HyGjalFxcO28cjh5P4mhHEgfeHzhzXygEfOoTBWeOPryQyorQGKQcHeOtKZ8Uwl23leCu20qQSCia3ooj3BBFOBJFXp5g4/ppWTnbkf/l5go+c30hPnN9IR5/pBwH3k+gPhJFOBLDH96O459Wl+GmhUWmY46K8eIPl58vWLKoCEsWFeG/Hi3H3gMJdJ0ewqSJ/n3mJHvMuTIfc67Mx+p7ytDZlURbe3B3PPuq+MOJCK6uLjAdg+iCykpDKCsN7oTk27fziMg7LD6RhVh8Igux+EQWYvF9qONEEnUPd+BoR+qIkVdei+HJn3UaTkXZhMX3oYryEN58O44Nv+sFAHzjwQ5MqfTtGzAUQCy+Ty377Pgz34dCQO1NwTxQhPyJxfepZZ/7qPgL/2YcJpYF9z1j8h8W36eu+WQBplSmyj78SYDIDSy+T+XkCJYuSRV++GY/kRtYfB9b9tkizL4iD9VX5ZuOQlmGu4p9bMmiIty6jBeaIPdxxvex8UU5+Pa9paZjUBZi8X1uQgn35pP7WHwiC7H4RBZi8YksxOITWYjFJ7IQi09kIRafyEKOii8itSKyX0QOishDXociIm+NWHwRCQFYA+AWAHMB3CEic70ORkTecTLjLwBwUFU/UNUEgPUAlnsbi4i85OQknekADg273Qbgb89dSERWAliZvtkfmoqWvzyeH60qB3DCdArvcHwBV+1kISfFv9CF6/S8/1BdC2AtAIjIDlW91kmAoMnmsQEcX9CJyA4nyznZ1G8DMPxC4FUAjowmFBH5g5PibwcwW0Q+JiL5AG4H8JK3sYjISyNu6qvqoIh8E8CrAEIA1qlq6wgPW+tGOJ/K5rEBHF/QORqfqJ73cp2IshyP3COyEItPZCEWn8hCLD6RhVh8Igux+EQWYvGJLPT/HOJNQvqySkMAAAAASUVORK5CYII=\n",
"text/plain": [
"