{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Policy Evaluation\n", "\n", "Let's evaluate a policy using two methods: an exact algebraic solution and value iteration.\n", "\n", "The world and the policy are the following:\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 0. Defining the environment\n", "## - Kinds of states\n", "Define the state matrix, with one entry per cell: 1 means a terminal state, 0 a normal state, and -1 an impossible (blocked) state." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "State Matrix:\n", "[[ 0. 0. 0. 1.]\n", " [ 0. -1. 0. 1.]\n", " [ 0. 0. 0. 0.]]\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "#Define the state matrix\n", "state_matrix = np.zeros((3,4))\n", "state_matrix[0, 3] = 1\n", "state_matrix[1, 3] = 1\n", "state_matrix[1, 1] = -1\n", "print(\"State Matrix:\")\n", "print(state_matrix)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## - Reward function\n", "Define the reward matrix, with one entry per cell: every state has reward -0.04 except the terminal states. The reward is 1 for the corner terminal state and -1 for the other terminal state." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reward Matrix:\n", "[[-0.04 -0.04 -0.04 1. ]\n", " [-0.04 -0.04 -0.04 -1. 
]\n", " [-0.04 -0.04 -0.04 -0.04]]\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUgAAAEECAYAAABKjq0kAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAACQFJREFUeJzt3VGIpXd5x/Hfs9kdFSzkYksqs1kii4SGQBVLEAJeBKGrluamFwbqlbAUVGJYG9Ir8b4YBb1ZTFCoRAQVRCwSaEoo2NaYRkm6FYKUZFcxlLSNChriPr2YuYjJPp53zZmc48znAwNzZv7n5eG/w3fec868Z6u7A8CrHdv0AADbSiABBgIJMBBIgIFAAgwEEmCw1kBW1dmq+mFVPV1V963z2IdJVT1YVc9V1ZObnmWbVdWNVfVIVV2sqqeq6u5Nz7SNquqNVfVvVfX9/X365KZn2mZVdV1V/XtVfXPV2rUFsqquS/K5JO9NckuSu6rqlnUd/5D5QpKzmx7i98BLSc539x8neVeSD/uZuqpfJbmju/8kyduTnK2qd214pm12d5KLSxau8wzytiRPd/ePuvvFJF9Ocucaj39odPejSZ7f9Bzbrrt/0t2P73/+s+z9UO9udqrt03t+vn/zxP6HK0CuoqpOJXl/ks8vWb/OQO4mefZlty/FDzNrUlU3JXlHkn/d7CTbaf9h4xNJnkvycHfbp6v7dJJ7k1xZsnidgayrfM1vMV6zqnpzkq8m+Vh3v7DpebZRd/+6u9+e5FSS26rq1k3PtG2q6s+TPNfd31t6n3UG8lKSG192+1SSH6/x+BxBVXUie3H8Und/bdPzbLvu/t8k/xTPcV/N7Un+oqr+K3tPAd5RVX//2+6wzkB+N8nbquqtVbWT5ANJvrHG43PEVFUleSDJxe7+1Kbn2VZV9YdVdf3+529K8p4k/7nZqbZPd/9td5/q7puy16d/7O6/+m33WVsgu/ulJB9J8u3sPZn+le5+al3HP0yq6qEk30lyc1VdqqoPbXqmLXV7kg9m7zf9E/sf79v0UFvoLUkeqaofZO9E5eHuXvknLKxW3u4M4OpcSQMwEEiAgUACDAQSYLAykC6EB46qJWeQ13whfFWdW8dwh519WsY+LWevllm6TysD+TteCO8faRn7tIx9Ws5eLbOeQCYuhAeOpmv6Q/H9y5m+nuSj3f3kK753LvtV3nnDG955w6438lnl2JUruXLM62Sr2Kfljvevc+L4ojeqOdIuX76Un//iytXeYOc3XPOVNFX1iSS/6O6/m9acPnOmT/z1h6/puEfRPad3c/8zlzc9xtazT8s98I6dvPvWz2x6jK132589m8e+/8uVgVzyKrYL4YEj6fiCNW9J8sX9/1LhWPbehMKF8MChtzKQ3f2D7L2TM8CR4plvgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgxWBrKqbqyqR6rqYlU9VVV3vx6DAWza8QVrXkpyvrsfr6o/SPK9qnq4u//jgGcD2KiVZ5Dd/ZPufnz/858luZhk96AHA9i06u7li6tuSvJoklu7+4VXfO9cknNJcvLkyXd+4rOfXd+Uh9QNOzv56YsvbnqMrWeflvujHM//XPq/TY+x9c5//ON5oZ+vVeuWPMROklTVm5N8Ncn
HXhnHJOnuC0kuJMnpM2f6/mcuX8O4R9M9p3djn1azT8vdd+z6fOVv/mHTYxwai17FrqoT2Yvjl7r7awc7EsB2WPIqdiV5IMnF7v7UwY8EsB2WnEHenuSDSe6oqif2P953wHMBbNzK5yC7+5+TrHwyE+CwcSUNwEAgAQYCCTAQSICBQAIMBBJgIJAAA4EEGAgkwEAgAQYCCTAQSICBQAIMBBJgIJAAA4EEGAgkwEAgAQYCCTAQSICBQAIMBBJgIJAAA4EEGAgkwEAgAQYCCTAQSICBQAIMBBJgIJAAA4EEGAgkwEAgAQYCCTAQSICBQAIMBBJgIJAAA4EEGAgkwEAgAQYCCTAQSICBQAIMBBJgIJAAA4EEGAgkwEAgAQYCCTAQSICBQAIMBBJgIJAAA4EEGAgkwEAgAQYCCTAQSICBQAIMBBJgIJAAA4EEGAgkwGBlIKvqwap6rqqefD0GAtgWS84gv5Dk7AHPAbB1Vgayux9N8vzrMAvAVjm+rgNV1bkk55Lk5MmTuff07roOfWjdsLOTe+zTSvZpuet3dnLnQ3+56TG23qPn/2XRurUFsrsvJLmQJKfPnOn7n7m8rkMfWvec3o19Ws0+LWev1sur2AADgQQYLPkzn4eSfCfJzVV1qao+dPBjAWzeyucgu/uu12MQgG3jITbAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAQCABBgIJMBBIgIFAAgwEEmAgkAADgQQYCCTAYFEgq+psVf2wqp6uqvsOeiiAbbAykFV1XZLPJXlvkluS3FVVtxz0YACbtuQM8rYkT3f3j7r7xSRfTnLnwY4FsHlLArmb5NmX3b60/zWAQ+34gjV1la/1qxZVnUtybv/mr3Lv+Sdfy2BHwUeTk0n+e9NzbDv7tJy9WuzmJYuWBPJSkhtfdvtUkh+/clF3X0hyIUmq6rHu/tMlAxxl9mkZ+7ScvVqmqh5bsm7JQ+zvJnlbVb21qnaSfCDJN17LcAC/D1aeQXb3S1X1kSTfTnJdkge7+6kDnwxgw5Y8xE53fyvJt67huBd+t3GOHPu0jH1azl4ts2ifqvtVr7cAEJcaAowEEmAgkAADgQQYCCTAQCABBgIJMPh/9EW6o+bXfk0AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#Define the reward matrix\n", "reward_matrix = np.full((3,4), -0.04)\n", "reward_matrix[0, 3] = 1\n", "reward_matrix[1, 3] = -1\n", "print(\"Reward Matrix:\")\n", "print(reward_matrix)\n", "plt.matshow(reward_matrix,extent=[0, 4, 0, 3])\n", "plt.grid()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## - Definition of the policy and how to pretty-print it. \n", "Each action is represented by a number in the policy matrix. Action (Up) is represented by 0, (Rigth) by 1, (Down) by 2 and, finally, (Left) by 3. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAADMCAYAAABTJB73AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4wLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvFvnyVgAADo5JREFUeJzt3XlwnPV9x/H3V2vJumJbPjCOD2BI7MYxgxs7LsEEggOOiZXQI9MBJsdM3bowMTXXUNI0pWGY6bRMSD2JZxgPRxrqQiCkBSTAlrBB4cYGQwQOxEDAsjEK8aldHavVt39IBJtg7Vrso2ef/X1eMx4k/Gjn+3j3vc+zq+fZx9wdEQlLRdwDiMjoU/giAVL4IgFS+CIBUvgiAVL4IgEqavhmtszMXjGzHWZ2TTFvO25mdquZdZpZe9yzRMHMZprZZjPbbmYvmdnquGcqJjOrNrNnzOyFofX7ftwzFZuZpczseTNryrds0cI3sxSwFjgPmAtcaGZzi3X7JeAnwLK4h4hQP3Clu38KOA34dpndf73AEnc/FZgPLDOz02KeqdhWA9sLWbCYW/xFwA53f93d+4A7gfOLePuxcvc2YG/cc0TF3d929+eGvj7E4ANoerxTFY8P6hr6tnLoT9kcvWZmM4DlwM2FLF/M8KcDOw/7voMyeuCExMxOBP4UeDreSYpraFd4G9AJtLh7Oa3ffwJXAwOFLFzM8O1D/l/ZPKOGwszqgXuAy9z9YNzzFJO759x9PjADWGRm8+KeqRjMrBHodPethf5MMcPvAGYe9v0MYHcRb18iZmaVDEa/3t1/Efc8UXH3/cAjlM97NouBr5rZbxl8ib3EzP57uB8oZvjPAp80s5PMrAq4ALiviLcvETIzA24Btrv7jXHPU2xmNsXMJgx9XQOcA/w63qmKw92/4+4z3P1EBrvb5O5fH+5niha+u/cDq4ANDL4xdJe7v1Ss24+bmd0BPAnMMbMOM1sR90xFthj4BoNbi21Df74c91BFNA3YbGYvMriRanH3vL/2Klem03JFwqMj90QCpPBFAqTwRQKk8EUClDf8kZzcYGYrizNe6SnndQOtX9IVun6FbPFHcnJDOf/jlvO6gdYv6QpavzH5FvDB3/eV7ckNIiEq6Pf4Q6fcbgU+Aax193/8kGVWMvRsM3Zs1YITZk0t8qilIdtfQeWYgs6DSCStX7Lt2tVBV3rgw86bOcIxHcAzdMjj/wKXuvtRP5BizuxZvr1tbMG3myRt7as5c96a
uMeIjNYv2RZ9aSdbXujJG/4xvatfhic3iASpkHf1y/bkBpFQ5X1zj8GTG/5r6HV+BYMn3wR7coNIOSjkXf0XGfw0FhEpEzpyTyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRAZR9+V3qAbLZ8L/yz/0COY7k2QtLs25+Le4SyVPbh9/U5sz/3JhdevIf19xxi777yeiC98lqWExb8lkuu7qSpJU13d3ldJebOe7uYd+abXHP9uzz2dDe5XPk+yY2mQj5eu6Q9tCnNI090D7tMXa1x171d3HVvF6kULP5sNY1L61h+Th1zPlGJWd4Lj8Tm5vUH2PFGdthlMt3OutsPsu72g9RUG188s4bGc+toPLeOaVNL+y6+7gd7yQzzZNXb62z/TZbtv9nPDWv3M7GhgvOW1NG4tJYvfaGW8eNSozjtsXloU5rp08ZwyqdK76pSpf2oKMAvn+rhhrX7C14+l4O2p3p4/NkemjamufziBr6ytLZk47/7vi5a24Z/Yjtcd4/TtDHDw23dPNCa4Z8vb2DBqdURTvjR/OiW/ezdV/heyt59A6y/5xAPbUpz/nn1/MsVDcycXhnhhCN37X/s5QuLa/j37yn8orvi4gn8zUXjhl3mm6v28NTWXsaPq2DZkloaz61j2dm1TGwo3a3Fe25bM5XunqPv3vb2Omf9RQd79w0wfVqK5efU0bi0jiWLa6ipKf1Xcs88NJOBYbp//c0syy7YDcDc2VU0Lq1l+Tl1fG5hNalUaT5Zv+eJ5hlUlOhdkPjwJ01MMWni0QPeuSvLaQuquf47kzhjUQ2VlaX9YPmgjx8//F306BPd/MPfTqDx3Drmz6sq2T2Xozlp1vBb68ee6eaH101m+bl1nHxiaW7Zj6aUn5gSH34+M6dX8oPvT4l7jMicdXoNZ51eE/cYkfnWXw+/NycjU6I7IiISJYUvEiCFLxIghS8SIIUvEiCFLxIghS8SIIUvEiCFLxIghS8SIIUvEiCFLxIghS8SIIUvEiCFLxIghS8SIIUvEiCFLxIghS8SIIUvEiCFLxKgvOGb2Uwz22xm283sJTNbPRqDiUh0Cvl47X7gSnd/zsw+Bmw1sxZ3fzni2UQkInm3+O7+trs/N/T1IWA7MD3qwUQkOnYsl1g2sxOBNmCeux/8wN+tBFYCTJkyecFd/3Nt8aYsIV3dU6mveSfuMSKj9Uu2q668ii0v9OS9hE/BV9Ixs3rgHuCyD0YP4O7rgHUAc2bP8jPnrTmGcZOjrX015bpuoPULRUHv6ptZJYPRr3f3X0Q7kohErZB39Q24Bdju7jdGP5KIRK2QLf5i4BvAEjPbNvTnyxHPJSIRyvsa390fA0r3er8icsx05J5IgBS+SIAUvkiAFL5IgBS+SIASH/6Bg7m4R5ARGhhwDh4aiHuMICU+/H+9YW/cI8gIbXmhl/s3puMeI0iJDr9jdz9rbzvA3n3a6idRc0ua5laFH4dEh9/UkiaXgwc3ZeIeRUagqSXNQ5syZLOFnyEqxZH48A//ryTHzl1ZtrX3ceDgAI890x33OMFJbPjpzACbHht8wGzYrK1G0jS1vr+X1qTX+aMuseG3tmXo7R2MXVuN5Dk89qaWDMfygTDy0SU2/OaWI1/Xa6uRHOnMAJsff/+JescbWV59LRvjROFJZPjuTk2NcfWqCQB874oGBj82QJJgW3sv1145kc+cMpZJDRXcdMMUXtzeF/dYQSn4o7dKiZmx5vop3H734CeAzZ83lj8/rz7mqaRQixfVsHhRDQ+0pkmljL/7+vi4RwpOIrf4IvLRKHyRACl8kQApfJEAKXyRACl8kQAp/BK2ZVsPV1/3btxjyAjdeNO+kj37MJG/x3/PXy2v54ufr6VhfHk+f+3a089zL/aSzTqVleV3gNI9t06jP1e+h+pua++jckxp3m+JDr+2toLa2vKM
HuD8ZfWcv6x8D0yaPCkV9wiR+umPp8Y9wlGVbzUiclQKXyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRAecM3s1vNrNPM2kdjIBGJXiFb/J8AyyKeQ0RGUd7w3b0N2DsKs4jIKCnaJbTMbCWwEmDKlMm0tV9drJsuKV3dU2lrXx33GJHR+iXdVQUtVbTw3X0dsA5gzuxZfua8NcW66ZLS1r6acl030PqFQu/qiwRI4YsEqJBf590BPAnMMbMOM1sR/VgiEqW8r/Hd/cLRGERERo929UUCpPBFAqTwRQKk8EUCpPBFApT48N/cmY17BJHESXT47s4V174b9xgiiZPo8F9+tY//ezDNWx3a6osci0SH37QxA0BzaybmSUSSJdHhN7ekAbh/YzrmSUSSJbHhv/v7HE9u7QFg8+MZutIDMU8kkhyJDf+Bh9MMDLXe1wetbdrdFylUYsNvakljNvi12eD3IlKYRIbv7px9Ri0/+rcpANz8w+P47PzqmKcSSY5Ehm9mXPKt8dTXDm7yJ4yr4O+/OT7mqUSSI5Hhi8hHo/BFAqTwRQKk8EUCpPBj8sZbWW5ef4BczuMepejcned/1cvP7j0U9yiRyGadR57IJPrYkaJdUEOGl8s5T23toaklTdPGDC+/2sdlK8eTSlncoxVFd/cAmx7vprklTXNrho7d/fz0x1PjHqto9u3P8eCmDE0taTZsznDg4AAvbJ4Z91gjlujwa2oqmDY1RfXY0oynr8+5d0Oapo1pHnw4ze/3HXlY8brbD3LbncNvFX9+y/EsOaM2yjFHbP+BHD9v6qKpJcPDbRky3UfuvXz7mk4u/affDXsbO546gYkNqSjHHLG3OrLcfX8XTRvTPP5sD7nckX//+a/uGvbnZ00fw7ZNsyKccOQSHf7XGuv5WmN93GMcVVWVsfDUsbzT2c+ezn4efbKb7GFnEJ80q5KpU4Z/0I+rL91XY+PHVbDw1Gre6czxTmc/zzzfe8Tff3pOFbU1w8+fKs3mAfj48WNYeGo1ezpz7Pldjldfe//OS6XgM6eM/cPRox/muDz3bZwSHX4SnDSrklUrJrBqxQQOHhpg46MZmlvSPNCa5tN/UsUdNx0f94gjZmbMnzeW+fPG8t3LJ7Kns5/m1gxNG9O0tmX4y+X1XHlJQ9xjjtiYMcZZp9dw1uk13HDtZF59rY/mljRNLRl++XQ33728gbMXl+beWD4KfxSN+1jFH/ZScjnn2W295HJeNq/zjz9uDCsuGseKi8bR0zPAr37dF/dIRTX75Cpmn1zF5Rc3sG9/jo63++MeacQUfkxSKeO0BeV7fkF1dUVZnz/RMCFFw4TS3ZXPp3RfQIpIZBS+SIAUvkiAFL5IgBS+SIAUvkiAFL5IgBS+SIAUvkiAFL5IgBS+SIAUvkiAFL5IgBS+SIAUvkiAFL5IgBS+SIAUvkiAFL5IgBS+SIAUvkiACgrfzJaZ2StmtsPMrol6KBGJVt7wzSwFrAXOA+YCF5rZ3KgHE5HoFLLFXwTscPfX3b0PuBM4P9qxRCRKhVxQYzqw87DvO4A/++BCZrYSWDn0bW9qGu0ffbxSdOlk4N24p4iO1i/h5hSyUCHhf9j1nf7oou7uvg5YB2BmW9x9YSEDJE05rxto/ZLOzLYUslwhu/odwOEXAp8B7B7JUCJSGgoJ/1ngk2Z2kplVARcA90U7lohEKe+uvrv3m9kqYAOQAm5195fy/Ni6YgxXosp53UDrl3QFrZ+5/9HLdREpczpyTyRACl8kQApfJEAKXyRACl8kQApfJEAKXyRA/w9xvy3WLtAcEwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "policy_matrix = np.array([[1, 1, 1, -1],\n", " [0, np.NaN, 0, -1],\n", " [0, 3, 3, 3]])\n", "\n", "# Don't care about this function. Only prints the policy matrix\n", "def print_policy(policy_matrix2, V=np.array([]),lstates=[]):\n", " \"\"\" Policy_matrix: is the policy to print\n", " V: is a value function that we want to superpose with colors to the policy\n", " lstates: is list of states to highligth if I want to stress something\n", " \"\"\"\n", " policy_matrix = policy_matrix2[::-1]\n", " shape = policy_matrix.shape\n", " U=np.zeros(policy_matrix.shape)\n", " R=np.zeros(policy_matrix.shape)\n", " for row in range(shape[0]):\n", " for col in range(shape[1]):\n", " if(policy_matrix[row,col] == -1): \n", " U[row,col]=0 \n", " R[row,col]=0 \n", " elif(policy_matrix[row,col] == 0): \n", " U[row,col]=0 \n", " R[row,col]=1 \n", " #policy_string += \" ^ \"\n", " elif(policy_matrix[row,col] == 1): \n", " U[row,col]=1 \n", " R[row,col]=0 \n", " #policy_string += \" > \"\n", " elif(policy_matrix[row,col] == 2): \n", " U[row,col]=0 \n", " R[row,col]=-1 \n", " #policy_string += \" v \" \n", " elif(policy_matrix[row,col] == 3):\n", " U[row,col]=-1 \n", " R[row,col]=0 \n", " #policy_string += \" < \"\n", " elif(np.isnan(policy_matrix[row,col])): \n", " U[row,col]=0 \n", " R[row,col]=0 \n", " #policy_string += \" # \"\n", " plt.rcParams['figure.figsize'] = (4,3)\n", " if V.size==0:\n", " V=np.ones(policy_matrix.shape)\n", " for x in lstates:\n", " V[x.multi_index]=0.7\n", " plt.matshow(V,extent=[0, 4, 0, 3],vmin=0, vmax=1)\n", " else:\n", " plt.matshow(V,extent=[0, 4, 0, 3])\n", " plt.grid()\n", " X, Y = np.meshgrid(np.arange(0.5, 4.5, 1), np.arange(0.5, 3.5, 1))\n", " Q = plt.quiver(X, Y,U,R)\n", "\n", " plt.show()\n", "\n", "print_policy(policy_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## - Transition probabilities\n", "We will represent transition probabilities with a matrix 
where each row corresponds to the action executed and each column gives the probability of actually moving in one direction. Column 0 is Up, column 1 is Right, column 2 is Down and column 3 is Left. For example, row 0 says that action Up moves Up with probability 0.8 and slips Right or Left with probability 0.1 each." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "transition_matrix = np.array([[0.8, 0.1, 0.0, 0.1],\n", "                              [0.1, 0.8, 0.1, 0.0],\n", "                              [0.0, 0.1, 0.8, 0.1],\n", "                              [0.1, 0.0, 0.1, 0.8]])\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Discount factor (gamma) for this problem\n", "gamma = 0.999\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## - Several auxiliary functions" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def prob_next_state(position,world_row=3,world_col=4):\n", "    \"\"\"\n", "    Given a position in the grid-world, returns a list of possible next states\n", "    SORTED by direction. Note that when the move bumps into a wall,\n", "    the state in that direction remains the same. 
\n", " \"\"\"\n", " posible=[]\n", " for action in range(4):\n", " position_def = position\n", " if(action == 0): new_position = (position[0]-1, position[1]) #UP\n", " elif(action == 1): new_position = (position[0], position[1]+1) #RIGHT\n", " elif(action == 2): new_position = (position[0]+1, position[1]) #DOWN\n", " elif(action == 3): new_position = (position[0], position[1]-1) #LEFT\n", " else: raise ValueError('The action is not included in the action space.')\n", "\n", " #Check if the new position is a valid position\n", " if (new_position[0]>=0 and new_position[0]=0 and new_position[1]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def Print_V(world_row=3,world_col=4):\n", " graph = \"\"\n", " for row in range(world_row):\n", " row_string = \"\"\n", " for col in range(world_col):\n", " row_string += str(round(V[One_index((row,col))],4))+' '\n", " row_string += '\\n'\n", " graph += row_string \n", " print(graph) \n", "\n", "Print_V()\n", "\n", "def Return_V(world_row=3,world_col=4):\n", " graph = []\n", " for row in range(world_row):\n", " row_l = []\n", " for col in range(world_col):\n", " row_l.append(V[One_index((row,col))])\n", " graph.append(row_l)\n", " return graph \n", "\n", "plt.matshow(Return_V(),extent=[0, 4, 0, 3])\n", "plt.grid()\n", "plt.show()\n", "\n", "V_exact = Return_V()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2- Evaluating the policy using value iteration\n", "\n", "" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0.84881224 0.90630541 0.95748948 0. ]\n", " [ 0.79776395 0. 0.69906187 0. 
]\n", " [ 0.74042274 0.68950862 0.64536082 0.4219169 ]]\n" ] } ], "source": [ "error=[]\n", "V = np.zeros((3,4))\n", "delta=1000\n", "# Iterate until V is within 1e-5 (max absolute error) of the exact solution V_exact\n", "while delta> 0.00001:\n", "    old_V = V.copy()\n", "    delta = 0\n", "    state = np.nditer(state_matrix, flags=['multi_index'])\n", "    while not state.finished:\n", "        action = policy_matrix[state.multi_index]\n", "        # Only normal states (value 0 in state_matrix) are updated\n", "        if state[0]==0:\n", "            acum=0\n", "            # Expected one-step return under the policy, summed over the 4 directions\n", "            for x in zip(prob_next_state(state.multi_index,3,4),transition_matrix[int(action),:]):\n", "                acum = acum + x[1]*(reward_matrix[x[0]]+gamma*old_V[x[0]])\n", "            V[state.multi_index] = acum\n", "        state.iternext()\n", "    # Signed sum of the remaining error over all states\n", "    error.append(np.sum(V_exact - V))\n", "    delta = np.max(abs(V_exact - V))\n", "print(V)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot how the error decreases with the number of iterations." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(error)\n", "plt.ylabel('sum of errors for all states')\n", "plt.xlabel('iterations')\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" } }, "nbformat": 4, "nbformat_minor": 2 }