{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Applying Tabular Methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Comparing Q Learning And Double Q Learning - Deterministic" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import Statements" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gym as gym\n", "from gym import spaces\n", "import matplotlib.pyplot as plt\n", "from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n", "import numpy as np\n", "import pickle" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Frozen Lake Environment\n", "\n", "Description: The agent is on a frozen lake and must navigate to the goal while avoiding holes and collecting Gems." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class FrozenLakeEnv(gym.Env):\n", " metadata = {'render.modes': []}\n", " def __init__(self, gamma=0.9, alpha=0.1, max_timestamp=10):\n", " self.gamma = gamma\n", " self.alpha = alpha\n", "\n", " self.obs_space = spaces.Discrete(16)\n", " self.action_space = spaces.Discrete(4)\n", " self.max_timestamp = max_timestamp\n", " self.timestep = 0\n", "\n", " ## STATES SET #############\n", " self.state = np.zeros((4, 4))\n", " self.myskater = np.asarray([0, 0])\n", " self.goal_loc = np.asarray([3, 3])\n", " self.gem_loc = [np.asarray([0, 2]), np.asarray([3, 2])]\n", " self.hole_loc = [np.asarray([1, 3]), np.asarray([2, 0])]\n", "\n", " self.state[tuple(self.myskater)] = 0.2\n", " self.state[tuple(self.goal_loc)] = 0.8\n", " for pos in self.gem_loc:\n", " self.state[tuple(pos)] = 0.5\n", " for pos in self.hole_loc:\n", " self.state[tuple(pos)] = 0.4\n", " ##########################\n", " \n", " self.prev_state = np.zeros((4, 4))\n", " self.prev_action = None\n", " self.penalty_counter = 0 \n", " self.flag_out_grid = 0 \n", "\n", "\n", " def step(self, action):\n", " self.prev_state = np.copy(self.state) \n", " self.prev_action = action \n", " self.flag_out_grid = 0 \n", "\n", " ## ACTIONS SET #############\n", " if action == 0: # moves right\n", " self.myskater[0] += 1\n", " elif action == 1: # moves left\n", " self.myskater[0] -= 1\n", " elif action == 2: # moves up\n", " self.myskater[1] += 1 \n", " elif action == 3: # moves down\n", " self.myskater[1] -= 1\n", " ##########################\n", " \n", " self.myskater = np.clip(self.myskater, 0, 3)\n", "\n", " # If the agent is in the same position as the previous step, choose a different action\n", " prev_state_positions = np.argwhere(self.prev_state == 0.2)\n", " if len(prev_state_positions) > 0 and np.array_equal(self.myskater, prev_state_positions[0]):\n", " while action == self.prev_action:\n", " action = self.action_space.sample()\n", " self.flag_out_grid = 1\n", "\n", " self.state = np.zeros((4, 4))\n", " self.state[tuple(self.myskater)] = 0.2\n", " self.state[tuple(self.goal_loc)] = 0.8\n", " for pos in self.gem_loc:\n", " self.state[tuple(pos)] = 0.5\n", " for pos in self.hole_loc:\n", " self.state[tuple(pos)] = 0.4\n", "\n", " obs = self.state.flatten()\n", " reward = self.calculate_reward()\n", " penalty = any(np.array_equal(self.myskater, pos) for pos in self.hole_loc)\n", " if penalty:\n", " self.penalty_counter += 1 \n", " self.timestep += 1\n", "\n", " terminated = True if np.array_equal(self.myskater, self.goal_loc) else self.timestep >= self.max_timestamp\n", " truncated = True if np.any((self.myskater < 0) | (self.myskater > 3)) else False\n", " if truncated:\n", " self.flag_out_grid = 1 \n", "\n", " info = {}\n", "\n", " return self.state.flatten(), reward, terminated, truncated, info\n", "\n", " def reset(self, **kwargs):\n", " self.state = np.zeros((4, 4))\n", " self.myskater = np.asarray([0, 0])\n", " self.state[tuple(self.myskater)] = 0.2\n", " self.state[tuple(self.goal_loc)] = 0.8\n", "\n", " for pos in self.gem_loc:\n", " self.state[tuple(pos)] = 0.5\n", " for pos in self.hole_loc:\n", " self.state[tuple(pos)] = 0.4\n", "\n", " self.prev_state = np.zeros((4, 4))\n", " self.prev_action = None\n", " self.flag_out_grid = 0 \n", " \n", " obs = self.state.flatten()\n", " self.timestep = 0\n", " info = {}\n", " self.penalty_counter = 0\n", " return obs, info\n", "\n", " def calculate_reward(self):\n", " prev_myskateritions = np.argwhere(self.prev_state == 0.2)\n", " if prev_myskateritions.size == 0:\n", " prev_myskaterition = self.myskater\n", " else:\n", " prev_myskaterition = prev_myskateritions[0]\n", "\n", " # Calculating distance to goal before and after the step\n", " prev_distance_to_goal = np.linalg.norm(self.goal_loc - prev_myskaterition)\n", " current_distance_to_goal = np.linalg.norm(self.goal_loc - self.myskater)\n", "\n", "\n", " ## REWARDS SET #############\n", " if np.array_equal(self.myskater, self.goal_loc):\n", " reward = 10 # Positive reward for reaching goal\n", " elif np.array_equal(self.myskater, self.hole_loc[0]):\n", " reward = -5 # negative reward for reaching holes 1\n", " elif np.array_equal(self.myskater, self.hole_loc[1]):\n", " reward = -6 # negative reward for reaching holes 2\n", " elif np.array_equal(self.myskater, self.gem_loc[0]):\n", " reward = 5 # positive reward for reaching gems 1\n", " elif np.array_equal(self.myskater, self.gem_loc[1]):\n", " reward = 6 # positive reward for reaching gems 2\n", " elif current_distance_to_goal < prev_distance_to_goal:\n", " reward = 1 # Positive reward for moving closer to goal\n", " elif current_distance_to_goal > prev_distance_to_goal:\n", " reward = -1 # Negative reward for moving away to goal\n", " else:\n", " reward = -0.1 # Slight negative reward for no change\n", " ##########################\n", " \n", " return reward\n", " \n", " def get_penalty_count(self):\n", " return self.penalty_counter\n", "\n", " def render(self):\n", " fig, ax = plt.subplots()\n", " plt.title('Frozen Lake Environment')\n", "\n", " # Load and display the background image\n", " background_img = plt.imread('images/frozen_lake.jpg')\n", " ax.imshow(background_img, extent=(-0.5, 3.5, -0.5, 3.5), origin='upper')\n", "\n", " \n", " skater_img = plt.imread('images/icons8-skateboard-100.png')\n", " hole_img = plt.imread('images/icons8-hole-100.png')\n", " gem_img = plt.imread('images/icons8-gems-100.png')\n", " goal_img = plt.imread('images/icons8-flag-100.png')\n", " skater_hole_drown_img = plt.imread('images/agent_hole_drown.png')\n", " skater_gem_lottery_img = plt.imread('images/agent_gems_lottery.png')\n", " agent_flag_winner_img = plt.imread('images/agent_flag_winner.png')\n", " agent_grid_cross_img = plt.imread('images/agent_grid_cross.png')\n", "\n", " # Plot Skater\n", " myskater = self.myskater\n", " if self.flag_out_grid:\n", " skater_img = agent_grid_cross_img\n", " agent_box = AnnotationBbox(OffsetImage(skater_img, zoom=0.4), myskater, frameon=False)\n", " ax.add_artist(agent_box)\n", "\n", " # Plot Holes\n", " for hole_loc in self.hole_loc:\n", " hole_loc = hole_loc\n", " if np.array_equal(self.myskater, hole_loc):\n", " hole_img = skater_hole_drown_img\n", " else:\n", " hole_img = plt.imread('images/icons8-hole-100.png')\n", " rock_box = AnnotationBbox(OffsetImage(hole_img, zoom=0.4), hole_loc, frameon=False)\n", " ax.add_artist(rock_box)\n", "\n", " # Plot Gems\n", " for gem_loc in self.gem_loc:\n", " gem_loc = gem_loc\n", " if np.array_equal(self.myskater, gem_loc):\n", " gem_img = skater_gem_lottery_img\n", " else:\n", " gem_img = plt.imread('images/icons8-gems-100.png')\n", " battery_box = AnnotationBbox(OffsetImage(gem_img, zoom=0.4), gem_loc, frameon=False)\n", " ax.add_artist(battery_box)\n", "\n", " # Plot goal\n", " goal_loc = self.goal_loc\n", " goal_loc = self.goal_loc\n", " if np.array_equal(self.myskater, goal_loc):\n", " goal_img = agent_flag_winner_img\n", " else:\n", " goal_img = plt.imread('images/icons8-flag-100.png')\n", " goal_box = AnnotationBbox(OffsetImage(goal_img, zoom=0.4), goal_loc, frameon=False)\n", " ax.add_artist(goal_box)\n", "\n", " plt.xticks(np.arange(-0.5, 4.5, 1))\n", " plt.yticks(np.arange(-0.5, 4.5, 1))\n", " plt.gca().set_xticklabels(np.arange(-0.5, 4.5, 1))\n", " plt.gca().set_yticklabels(np.arange(-0.5, 4.5, 1))\n", " plt.show()\n", "\n", " def obs_space_to_index(self, obs):\n", " myskater = np.argwhere(obs.reshape(4, 4) == 0.2)\n", " if myskater.size == 0:\n", " return 0 \n", " return myskater[0, 0] * 4 + myskater[0, 1]\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Episode: 100\n", "Q-table:\n", "[[ 1.42723997 6.78999302 2.68193215 4.43591888]\n", " [ 2.16578485 3.54087052 8.75969117 0.87362594]\n", " [ 0. 0. 0. 0. ]\n", " [-1.63683016 0.22204377 1.00913149 6.96256755]\n", " [-2.77629507 0.62622993 2.11969003 -0.7488315 ]\n", " [ 0.80695219 1.65217441 2.99716171 -0.93977439]\n", " [ 1.09276051 7.67247227 -2.60672768 -0.19942972]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.27226932 -0.5518115 1.29397071 -3.69723977]\n", " [ 1.99804948 -0.32304096 2.27838089 -0.48436502]\n", " [ 6.34819007 -1.11210861 1.39558805 0.02939428]\n", " [-0.2681253 -1.96242421 0.63446175 0.08252013]\n", " [-0.12551501 -0.06515363 0. -0.81139608]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.82\n", "Episode: 100, Average Steps: 9.87\n", "Episode: 200\n", "Q-table:\n", "[[ 7.76114247 14.75020672 10.29057532 9.41702406]\n", " [ 8.25202923 7.00407673 14.318215 7.86753844]\n", " [ 0. 0. 0. 0. ]\n", " [ 1.35529533 7.92267886 7.71827863 14.77420634]\n", " [-1.23055949 5.1462758 8.57466945 1.66018347]\n", " [ 1.02030019 11.67040558 3.810906 0.02750318]\n", " [ 1.97149531 14.19346777 1.00367498 2.72217745]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.27226932 -0.5518115 1.71231574 -3.69723977]\n", " [ 1.99804948 0.99650451 4.53845252 -0.31770528]\n", " [10.23998695 -1.00222599 1.39558805 0.38675183]\n", " [-0.15249571 -1.96242421 0.68929249 0.08252013]\n", " [-0.12551501 -0.06515363 2.01870384 -0.81139608]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.38\n", "Episode: 200, Average Steps: 9.95\n", "Episode: 300\n", "Q-table:\n", "[[11.72990371 17.57163019 22.31972985 16.38415418]\n", " [11.80562931 13.56446015 24.23251788 11.94251339]\n", " [ 0. 0. 0. 0. ]\n", " [ 7.47338295 17.82802859 17.36951117 24.29601776]\n", " [ 2.01645929 9.46511197 10.18029736 4.02458681]\n", " [ 0.91088794 14.40912846 5.43496824 2.88880912]\n", " [ 3.91711126 19.07928347 1.00367498 2.72217745]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.27226932 -0.5518115 3.04409633 -2.21027338]\n", " [ 1.99804948 3.11979613 8.3910041 -0.31770528]\n", " [16.7191283 -1.00222599 1.39558805 1.45297177]\n", " [-0.15249571 -1.96242421 0.68929249 0.08252013]\n", " [-0.12551501 -0.06515363 2.01870384 -0.81139608]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.25\n", "Episode: 300, Average Steps: 9.96\n", "Episode: 400\n", "Q-table:\n", "[[19.08163981 25.26514522 29.05382365 24.88959976]\n", " [16.23877087 20.76757171 31.07429527 16.68371581]\n", " [ 0. 0. 0. 0. ]\n", " [15.69939879 22.81457104 25.98643935 31.52129309]\n", " [ 2.01645929 9.46511197 15.07275018 5.85208184]\n", " [ 0.91088794 23.14122492 8.23219173 2.88880912]\n", " [ 3.91711126 26.23963573 1.00367498 2.72217745]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.27226932 -0.5518115 3.04409633 -2.21027338]\n", " [ 1.99804948 3.11979613 8.3910041 -0.31770528]\n", " [16.7191283 -1.00222599 1.39558805 1.45297177]\n", " [-0.15249571 -1.96242421 0.68929249 0.08252013]\n", " [-0.12551501 -0.06515363 2.01870384 -0.81139608]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.28\n", "Episode: 400, Average Steps: 10.0\n", "Episode: 500\n", "Q-table:\n", "[[28.0535013 30.95033359 28.87726968 29.97831582]\n", " [17.76113626 27.37436546 38.33654742 19.03959187]\n", " [ 0. 0. 0. 0. ]\n", " [20.81854895 31.79428961 27.92280058 37.76245159]\n", " [ 2.01645929 9.46511197 17.8680711 5.85208184]\n", " [ 1.35803848 28.12321075 8.23219173 4.63761884]\n", " [ 3.91711126 36.31921542 1.00367498 2.72217745]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.27226932 -0.5518115 3.93319997 -2.21027338]\n", " [ 1.99804948 3.11979613 10.03760369 -0.31770528]\n", " [21.86636748 -1.00222599 1.39558805 1.45297177]\n", " [-0.15249571 -1.96242421 0.68929249 0.08252013]\n", " [-0.12551501 -0.06515363 2.01870384 -0.81139608]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.17\n", "Episode: 500, Average Steps: 9.93\n", "Episode: 600\n", "Q-table:\n", "[[26.20356258 33.69642231 40.32572392 33.58290496]\n", " [17.76113626 27.37436546 42.94805753 21.98863199]\n", " [ 0. 0. 0. 0. ]\n", " [26.15357041 35.37290026 32.13593276 43.37934436]\n", " [ 2.01645929 9.46511197 20.74991027 5.85208184]\n", " [ 1.35803848 30.45098359 8.23219173 4.63761884]\n", " [ 3.91711126 38.38553034 5.50661414 2.72217745]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.27226932 -0.5518115 3.93319997 -2.21027338]\n", " [ 1.99804948 3.11979613 10.03760369 -0.31770528]\n", " [21.86636748 -1.00222599 1.39558805 1.45297177]\n", " [-0.15249571 -1.96242421 0.68929249 0.08252013]\n", " [-0.12551501 -0.06515363 2.01870384 -0.81139608]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.29\n", "Episode: 600, Average Steps: 10.0\n", "Episode: 700\n", "Q-table:\n", "[[29.17170987 39.23945077 44.27100087 36.78533842]\n", " [19.84489019 27.37436546 46.38175909 24.45948712]\n", " [ 0. 0. 0. 0. ]\n", " [28.13146944 39.72300782 37.95522193 46.84851913]\n", " [ 2.01645929 9.46511197 22.12668889 5.85208184]\n", " [ 1.35803848 33.90894028 8.23219173 4.63761884]\n", " [ 3.91711126 39.88096436 5.50661414 2.72217745]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.27226932 -0.5518115 3.93319997 -2.21027338]\n", " [ 1.99804948 3.11979613 10.03760369 -0.31770528]\n", " [21.86636748 -1.00222599 1.39558805 1.45297177]\n", " [-0.15249571 -1.96242421 0.68929249 0.08252013]\n", " [-0.12551501 -0.06515363 2.01870384 -0.81139608]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.05\n", "Episode: 700, Average Steps: 10.0\n", "Episode: 800\n", "Q-table:\n", "[[29.17170987 40.8264187 46.92531664 41.74577505]\n", " [19.84489019 27.37436546 49.15595132 24.45948712]\n", " [ 0. 0. 0. 0. ]\n", " [30.59593056 40.55859133 41.01500447 49.49826475]\n", " [ 2.01645929 9.46511197 22.12668889 5.85208184]\n", " [ 1.35803848 33.90894028 8.23219173 4.63761884]\n", " [ 3.91711126 39.88096436 5.50661414 2.72217745]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.27226932 -0.5518115 3.93319997 -2.21027338]\n", " [ 1.99804948 3.11979613 10.03760369 -0.31770528]\n", " [21.86636748 -1.00222599 1.39558805 1.45297177]\n", " [-0.15249571 -1.96242421 0.68929249 0.08252013]\n", " [-0.12551501 -0.06515363 2.01870384 -0.81139608]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.05\n", "Episode: 800, Average Steps: 10.0\n", "Episode: 900\n", "Q-table:\n", "[[29.86512739 40.67390223 47.58131528 39.26699296]\n", " [19.84489019 27.37436546 49.93719026 24.45948712]\n", " [ 0. 0. 0. 0. ]\n", " [33.25051423 42.31258557 41.01500447 50.17920522]\n", " [ 2.01645929 9.46511197 25.48004289 5.85208184]\n", " [ 1.35803848 37.43257001 8.23219173 4.63761884]\n", " [ 3.91711126 42.21281234 5.50661414 2.72217745]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.27226932 -0.5518115 3.93319997 -2.21027338]\n", " [ 1.99804948 3.11979613 10.03760369 -0.31770528]\n", " [21.86636748 -1.00222599 1.39558805 1.45297177]\n", " [-0.15249571 -1.96242421 0.68929249 0.08252013]\n", " [-0.12551501 -0.06515363 2.01870384 -0.81139608]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.11\n", "Episode: 900, Average Steps: 10.0\n", "Episode: 1000\n", "Q-table:\n", "[[29.16626439 40.67390223 49.11854347 41.48576772]\n", " [19.84489019 30.49336587 51.66496546 24.45948712]\n", " [ 0. 0. 0. 0. ]\n", " [33.25051423 42.31258557 41.01500447 51.67166517]\n", " [ 2.01645929 9.46511197 27.14217769 5.85208184]\n", " [ 1.35803848 39.01322785 8.23219173 4.63761884]\n", " [ 3.91711126 42.21281234 5.50661414 2.72217745]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.27226932 -0.5518115 3.93319997 -2.21027338]\n", " [ 1.99804948 3.11979613 10.03760369 -0.31770528]\n", " [21.86636748 -1.00222599 1.39558805 1.45297177]\n", " [-0.15249571 -1.96242421 0.68929249 0.08252013]\n", " [-0.12551501 -0.06515363 2.01870384 -0.81139608]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.0\n", "Episode: 1000, Average Steps: 10.0\n" ] } ], "source": [ "env_det = FrozenLakeEnv()\n", "\n", "epsilon = 1.0 # Initial exploration rate\n", "epsilon_min = 0.01 # Minimum exploration rate\n", "gamma = 0.95 # Discount factor\n", "alpha = 0.15 # Learning rate\n", "decay_rate = 0.995 # Epsilon decay rate per episode\n", "total_episodes = 1000 \n", "max_timestamp = 10\n", "qt = np.zeros((env_det.obs_space.n, env_det.action_space.n))\n", "\n", "rewards_epi = []\n", "epsilon_values = []\n", "steps_per_episode = []\n", "penalties_per_episode = []\n", "\n", "final_state = None\n", "for episode in range(total_episodes):\n", " state, _ = env_det.reset()\n", " state_index = env_det.obs_space_to_index(state)\n", " total_rewards = 0\n", " total_steps = 0 \n", " action = env_det.action_space.sample() if np.random.uniform(0, 1) < epsilon else np.argmax(qt[state_index])\n", "\n", " while True:\n", " next_state, reward, terminated, truncated, _ = env_det.step(action)\n", " total_steps += 1 \n", " next_strt_idx = env_det.obs_space_to_index(next_state)\n", " next_action = env_det.action_space.sample() if np.random.uniform(0, 1) < epsilon else np.argmax(qt[next_strt_idx])\n", " qt[state_index, action] = qt[state_index, action] + alpha * (reward + gamma * qt[next_strt_idx, next_action] - qt[state_index, action])\n", " state_index, action = next_strt_idx, next_action\n", " total_rewards += reward\n", " \n", " if terminated or truncated:\n", " break\n", " \n", " penalties_per_episode.append(env_det.get_penalty_count()) \n", " \n", " # Q-table for every 100 episodes\n", " if (episode + 1) % 100 == 0:\n", " print(f\"Episode: {episode + 1}\")\n", " print(\"Q-table:\")\n", " print(qt)\n", " avg_penalty = np.mean(penalties_per_episode[-100:])\n", " print(f\"Average Penalties in Last 100 Episodes: {avg_penalty}\")\n", "\n", " epsilon = max(epsilon_min, epsilon * decay_rate)\n", " epsilon_values.append(epsilon)\n", " rewards_epi.append(total_rewards)\n", " steps_per_episode.append(total_steps)\n", "\n", " if (episode + 1) % 100 == 0:\n", " average_steps = np.mean(steps_per_episode[-100:])\n", " print(f\"Episode: {episode + 1}, Average Steps: {average_steps}\")\n", "\n", " if episode == total_episodes - 1:\n", " final_state = env_det.state\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Episode: 100\n", "Q-table 1:\n", "[[ 4.44203487 8.18633458 11.85416368 8.64172487]\n", " [ 3.46093516 9.34344823 12.0930265 6.09384082]\n", " [ 0. 0. 0. 0. ]\n", " [ 1.20764956 2.26181567 5.08217587 9.07809385]\n", " [-0.89981292 2.7856339 5.11531748 2.19007779]\n", " [ 1.77219039 5.87305158 5.09193859 0.84162314]\n", " [ 0.60248264 8.51872763 -0.14901717 1.71720962]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0.15 -0.128625 3.88382836 -0.7673582 ]\n", " [ 9.48435053 0. 0.2775 -0.01879452]\n", " [ 2.6021783 0. -0.015 0. ]\n", " [ 0. -0.69493363 0.7068246 0. ]\n", " [-0.015 0.06291875 3.25998486 -0.15 ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Q-table 2:\n", "[[ 3.9912651 8.14661117 9.98783913 8.92610903]\n", " [ 5.21168058 8.75845551 14.36755081 4.5665347 ]\n", " [ 0. 0. 0. 0. ]\n", " [ 1.18601867 0.87867675 2.18465028 11.31758028]\n", " [-1.730722 1.47031361 4.64908018 0.32229384]\n", " [ 2.32121037 6.5567668 4.36935714 1.28520755]\n", " [ 4.74265278 8.27745111 0. 0.70413959]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 1.24060633 1.07199756 3.82046765 -0.36473917]\n", " [ 5.61270839 0.86559191 0.2775 0.40554839]\n", " [ 3.18921832 0.21453231 -0.015 1.09363834]\n", " [ 0. -0.9 0. 0. ]\n", " [-0.015 -0.15 3.58739299 0. ]\n", " [ 0. 0. 0. 0. ]\n", " [ 0. 0. 0. 0. ]]\n", "Average Penalties in Last 100 Episodes: 0.63\n", "Episode: 100, Average Steps: 9.84\n", "Episode: 200\n", "Q-table 1:\n", "[[ 1.39789116e+01 1.96866109e+01 2.52138339e+01 2.01193086e+01]\n", " [ 9.28079170e+00 1.54739286e+01 2.70077807e+01 1.42619102e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 8.55149103e+00 1.71647164e+01 1.74316445e+01 2.77465107e+01]\n", " [ 1.45012419e+00 6.05798637e+00 1.12256552e+01 4.32342283e+00]\n", " [ 2.96736561e+00 1.87194734e+01 5.09193859e+00 1.83468459e+00]\n", " [ 2.19244396e+00 2.03509966e+01 2.71113559e+00 5.12336019e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.73453746e+00 1.90337244e+00 4.56321624e+00 -7.67358201e-01]\n", " [ 1.41238669e+01 1.43401394e+00 2.77500000e-01 -1.87945243e-02]\n", " [ 6.15051248e+00 0.00000000e+00 -1.50000000e-02 0.00000000e+00]\n", " [ 0.00000000e+00 -6.94933632e-01 1.47980765e+00 0.00000000e+00]\n", " [ 7.01256733e-01 7.37726371e-01 3.25998486e+00 -1.50000000e-01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Q-table 2:\n", "[[ 1.27170507e+01 1.94081929e+01 2.54761991e+01 1.84004083e+01]\n", " [ 1.20338211e+01 1.82021705e+01 2.65039382e+01 1.54584383e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.10971094e+01 1.73432149e+01 1.68476021e+01 2.78658386e+01]\n", " [ 1.50845999e+00 2.74039284e+00 1.12629796e+01 3.23813277e+00]\n", " [ 3.28745820e+00 1.49882581e+01 7.89607262e+00 2.03866494e+00]\n", " [ 6.78695050e+00 2.08777560e+01 0.00000000e+00 4.18041181e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.24060633e+00 3.34255858e+00 5.85435392e+00 7.31260487e-01]\n", " [ 5.61270839e+00 8.65591906e-01 7.56685408e-01 7.48161673e-01]\n", " [ 3.18921832e+00 2.14532309e-01 -1.50000000e-02 1.09363834e+00]\n", " [ 0.00000000e+00 -9.00000000e-01 6.14547842e-01 0.00000000e+00]\n", " [-1.50000000e-02 2.75945541e-01 7.35869343e+00 -4.92774940e-02]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Average Penalties in Last 100 Episodes: 0.52\n", "Episode: 200, Average Steps: 9.95\n", "Episode: 300\n", "Q-table 1:\n", "[[ 2.26445916e+01 3.14530691e+01 3.83381786e+01 3.15398629e+01]\n", " [ 1.47152601e+01 2.98881625e+01 3.79091435e+01 2.31309389e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 2.00264342e+01 2.70831729e+01 2.88605051e+01 3.78573171e+01]\n", " [ 5.50892226e+00 9.03139984e+00 1.31272139e+01 4.32342283e+00]\n", " [ 4.69293056e+00 2.21794945e+01 1.08643794e+01 1.83468459e+00]\n", " [ 4.91801938e+00 3.39018279e+01 8.85944990e+00 6.61130804e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.73453746e+00 1.90337244e+00 8.94323321e+00 3.05046158e+00]\n", " [ 2.45573697e+01 1.43401394e+00 2.77500000e-01 -1.87945243e-02]\n", " [ 1.10870317e+01 0.00000000e+00 -1.50000000e-02 0.00000000e+00]\n", " [ 0.00000000e+00 -6.94933632e-01 1.47980765e+00 0.00000000e+00]\n", " [ 7.01256733e-01 7.37726371e-01 3.25998486e+00 -1.50000000e-01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Q-table 2:\n", "[[ 2.05111990e+01 3.28298515e+01 3.53707149e+01 3.21698348e+01]\n", " [ 1.41227368e+01 2.19960490e+01 4.03850398e+01 2.32262421e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 2.24377565e+01 3.24428942e+01 3.02211916e+01 4.10078116e+01]\n", " [ 5.84522529e+00 2.74039284e+00 1.64316830e+01 3.23813277e+00]\n", " [ 3.28745820e+00 2.79853867e+01 9.76167874e+00 2.03866494e+00]\n", " [ 8.32831500e+00 3.36132167e+01 6.29755450e+00 6.18756684e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.24060633e+00 5.47539160e+00 7.13885187e+00 7.31260487e-01]\n", " [ 1.33963700e+01 8.65591906e-01 2.37308461e+00 1.13619574e+00]\n", " [ 1.37092719e+01 2.65324355e+00 -1.50000000e-02 1.09363834e+00]\n", " [ 0.00000000e+00 -9.00000000e-01 6.14547842e-01 0.00000000e+00]\n", " [-1.50000000e-02 2.75945541e-01 7.35869343e+00 -4.92774940e-02]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Average Penalties in Last 100 Episodes: 0.52\n", "Episode: 300, Average Steps: 9.75\n", "Episode: 400\n", "Q-table 1:\n", "[[ 2.23728574e+01 4.25559783e+01 4.32527684e+01 3.89853674e+01]\n", " [ 1.92874612e+01 3.37270276e+01 4.60032090e+01 2.57017572e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 2.73063333e+01 3.78448573e+01 3.88628590e+01 4.61972997e+01]\n", " [ 5.50892226e+00 9.03139984e+00 2.12334452e+01 6.12195437e+00]\n", " [ 5.15627737e+00 2.93174891e+01 1.08643794e+01 3.75099673e+00]\n", " [ 6.23929919e+00 4.01858516e+01 8.85944990e+00 6.61130804e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.73453746e+00 1.90337244e+00 1.02937762e+01 3.05046158e+00]\n", " [ 2.45573697e+01 1.43401394e+00 2.77500000e-01 -1.87945243e-02]\n", " [ 2.14007315e+01 0.00000000e+00 -1.50000000e-02 0.00000000e+00]\n", " [ 0.00000000e+00 -6.94933632e-01 1.47980765e+00 0.00000000e+00]\n", " [ 7.01256733e-01 7.37726371e-01 3.25998486e+00 -1.50000000e-01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Q-table 2:\n", "[[ 2.93540529e+01 4.16097620e+01 4.38043046e+01 3.87516158e+01]\n", " [ 1.73454108e+01 2.76210764e+01 4.57014336e+01 2.97070977e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.12790262e+01 3.95241788e+01 3.88617187e+01 4.57746771e+01]\n", " [ 5.84522529e+00 7.38694553e+00 1.72775085e+01 3.23813277e+00]\n", " [ 3.28745820e+00 3.33714387e+01 9.76167874e+00 2.03866494e+00]\n", " [ 8.32831500e+00 3.88570371e+01 6.29755450e+00 6.18756684e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.24060633e+00 5.47539160e+00 7.13885187e+00 7.31260487e-01]\n", " [ 2.21061125e+01 8.65591906e-01 2.37308461e+00 1.13619574e+00]\n", " [ 1.37092719e+01 2.65324355e+00 -1.50000000e-02 1.09363834e+00]\n", " [ 0.00000000e+00 -9.00000000e-01 6.14547842e-01 0.00000000e+00]\n", " [-1.50000000e-02 2.75945541e-01 7.35869343e+00 -4.92774940e-02]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Average Penalties in Last 100 Episodes: 0.28\n", "Episode: 400, Average Steps: 9.88\n", "Episode: 500\n", "Q-table 1:\n", "[[ 2.97299837e+01 4.32640943e+01 4.54112016e+01 4.01884464e+01]\n", " [ 2.80555288e+01 3.69113158e+01 4.83684979e+01 2.57017572e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.42281553e+01 4.26250449e+01 4.21720334e+01 4.84591105e+01]\n", " [ 5.50892226e+00 9.03139984e+00 2.94784092e+01 6.12195437e+00]\n", " [ 5.55012215e+00 3.66115669e+01 1.08643794e+01 3.75099673e+00]\n", " [ 6.23929919e+00 4.47266227e+01 8.85944990e+00 6.61130804e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.73453746e+00 1.90337244e+00 1.20498308e+01 3.05046158e+00]\n", " [ 2.45573697e+01 1.43401394e+00 2.77500000e-01 -1.87945243e-02]\n", " [ 2.14007315e+01 0.00000000e+00 -1.50000000e-02 0.00000000e+00]\n", " [ 0.00000000e+00 -6.94933632e-01 1.47980765e+00 0.00000000e+00]\n", " [ 7.01256733e-01 7.37726371e-01 3.25998486e+00 -1.50000000e-01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Q-table 2:\n", "[[ 3.08508623e+01 4.54806344e+01 4.56512756e+01 4.27120890e+01]\n", " [ 2.05383824e+01 3.58378638e+01 4.75922680e+01 3.39970292e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.44228431e+01 4.10364374e+01 4.18445104e+01 4.77250416e+01]\n", " [ 5.84522529e+00 7.38694553e+00 2.10814471e+01 3.23813277e+00]\n", " [ 3.28745820e+00 4.09111023e+01 9.76167874e+00 2.03866494e+00]\n", " [ 8.32831500e+00 4.12351486e+01 6.29755450e+00 6.18756684e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.24060633e+00 5.47539160e+00 7.13885187e+00 7.31260487e-01]\n", " [ 2.61573828e+01 8.65591906e-01 2.37308461e+00 1.13619574e+00]\n", " [ 1.37092719e+01 2.65324355e+00 -1.50000000e-02 1.09363834e+00]\n", " [ 0.00000000e+00 -9.00000000e-01 6.14547842e-01 0.00000000e+00]\n", " [-1.50000000e-02 2.75945541e-01 7.35869343e+00 -4.92774940e-02]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Average Penalties in Last 100 Episodes: 0.27\n", "Episode: 500, Average Steps: 9.97\n", "Episode: 600\n", "Q-table 1:\n", "[[ 3.66253754e+01 4.60894827e+01 4.81774366e+01 4.27744128e+01]\n", " [ 2.80555288e+01 3.97324613e+01 5.02251400e+01 2.57017572e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.51229233e+01 4.40886438e+01 4.45511246e+01 5.01028707e+01]\n", " [ 5.50892226e+00 9.03139984e+00 2.94784092e+01 6.12195437e+00]\n", " [ 5.55012215e+00 3.66115669e+01 1.08643794e+01 3.75099673e+00]\n", " [ 6.23929919e+00 4.72274625e+01 8.85944990e+00 6.61130804e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.73453746e+00 1.90337244e+00 1.20498308e+01 3.05046158e+00]\n", " [ 2.45573697e+01 1.43401394e+00 2.77500000e-01 -1.87945243e-02]\n", " [ 2.14007315e+01 0.00000000e+00 -1.50000000e-02 0.00000000e+00]\n", " [ 0.00000000e+00 -6.94933632e-01 1.47980765e+00 0.00000000e+00]\n", " [ 7.01256733e-01 7.37726371e-01 3.25998486e+00 -1.50000000e-01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Q-table 2:\n", "[[ 3.68575628e+01 4.69064466e+01 4.75294278e+01 4.37735633e+01]\n", " [ 2.05383824e+01 3.74987586e+01 5.04889643e+01 3.39970292e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.65729931e+01 4.18107866e+01 4.26152008e+01 5.06714372e+01]\n", " [ 5.84522529e+00 7.38694553e+00 2.10814471e+01 3.23813277e+00]\n", " [ 3.28745820e+00 4.09111023e+01 9.76167874e+00 2.03866494e+00]\n", " [ 8.32831500e+00 4.58137544e+01 6.29755450e+00 6.18756684e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.24060633e+00 5.47539160e+00 7.13885187e+00 7.31260487e-01]\n", " [ 2.61573828e+01 8.65591906e-01 2.37308461e+00 1.13619574e+00]\n", " [ 1.37092719e+01 2.65324355e+00 -1.50000000e-02 1.09363834e+00]\n", " [ 0.00000000e+00 -9.00000000e-01 6.14547842e-01 0.00000000e+00]\n", " [-1.50000000e-02 2.75945541e-01 7.35869343e+00 -4.92774940e-02]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Average Penalties in Last 100 Episodes: 0.1\n", "Episode: 600, Average Steps: 10.0\n", "Episode: 700\n", "Q-table 1:\n", "[[ 3.99112365e+01 4.79142613e+01 4.89639003e+01 4.39191551e+01]\n", " [ 2.99472539e+01 3.97324613e+01 5.14299087e+01 2.86071210e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.64898533e+01 4.46770306e+01 4.56438034e+01 5.14963289e+01]\n", " [ 5.50892226e+00 9.03139984e+00 3.10364799e+01 8.19276742e+00]\n", " [ 5.55012215e+00 3.95287694e+01 1.08643794e+01 3.75099673e+00]\n", " [ 6.23929919e+00 4.87118279e+01 8.85944990e+00 6.61130804e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.73453746e+00 1.90337244e+00 1.20498308e+01 3.05046158e+00]\n", " [ 2.45573697e+01 1.43401394e+00 2.77500000e-01 -1.87945243e-02]\n", " [ 2.14007315e+01 0.00000000e+00 -1.50000000e-02 0.00000000e+00]\n", " [ 0.00000000e+00 -6.94933632e-01 1.47980765e+00 0.00000000e+00]\n", " [ 7.01256733e-01 7.37726371e-01 3.25998486e+00 -1.50000000e-01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Q-table 2:\n", "[[ 3.90929524e+01 4.68227773e+01 4.90281279e+01 4.43472889e+01]\n", " [ 2.30488581e+01 3.89892761e+01 5.15350252e+01 3.39970292e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.72489956e+01 4.52933095e+01 4.50922471e+01 5.15578267e+01]\n", " [ 5.84522529e+00 7.38694553e+00 2.10814471e+01 3.23813277e+00]\n", " [ 3.28745820e+00 4.17547682e+01 9.76167874e+00 2.03866494e+00]\n", " [ 8.32831500e+00 4.76911695e+01 6.29755450e+00 6.18756684e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.24060633e+00 5.47539160e+00 7.13885187e+00 7.31260487e-01]\n", " [ 2.61573828e+01 8.65591906e-01 2.37308461e+00 1.13619574e+00]\n", " [ 1.37092719e+01 2.65324355e+00 -1.50000000e-02 1.09363834e+00]\n", " [ 0.00000000e+00 -9.00000000e-01 6.14547842e-01 0.00000000e+00]\n", " [-1.50000000e-02 2.75945541e-01 7.35869343e+00 -4.92774940e-02]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Average Penalties in Last 100 Episodes: 0.08\n", "Episode: 700, Average Steps: 10.0\n", "Episode: 800\n", "Q-table 1:\n", "[[ 4.17044501e+01 4.79830072e+01 4.61690734e+01 4.39191551e+01]\n", " [ 3.15552203e+01 4.11301428e+01 5.11292934e+01 2.86071210e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.71680177e+01 4.46770306e+01 4.66941219e+01 5.10261617e+01]\n", " [ 5.50892226e+00 9.03139984e+00 3.24810624e+01 8.19276742e+00]\n", " [ 5.55012215e+00 4.20110841e+01 1.08643794e+01 3.75099673e+00]\n", " [ 6.23929919e+00 4.91809771e+01 8.85944990e+00 6.61130804e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.73453746e+00 1.90337244e+00 1.20498308e+01 3.05046158e+00]\n", " [ 2.45573697e+01 1.43401394e+00 2.77500000e-01 -1.87945243e-02]\n", " [ 2.14007315e+01 0.00000000e+00 -1.50000000e-02 0.00000000e+00]\n", " [ 0.00000000e+00 -6.94933632e-01 1.47980765e+00 0.00000000e+00]\n", " [ 7.01256733e-01 7.37726371e-01 3.25998486e+00 -1.50000000e-01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Q-table 2:\n", "[[ 3.98255381e+01 4.69726127e+01 4.84325782e+01 4.63036235e+01]\n", " [ 2.30488581e+01 3.89892761e+01 5.23685857e+01 3.39970292e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.80402988e+01 4.58619543e+01 4.56897334e+01 5.21768635e+01]\n", " [ 5.84522529e+00 7.38694553e+00 2.10814471e+01 3.23813277e+00]\n", " [ 3.28745820e+00 4.17547682e+01 9.76167874e+00 2.03866494e+00]\n", " [ 8.32831500e+00 4.94193299e+01 6.29755450e+00 6.18756684e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.24060633e+00 5.47539160e+00 7.13885187e+00 7.31260487e-01]\n", " [ 2.61573828e+01 8.65591906e-01 2.37308461e+00 1.13619574e+00]\n", " [ 1.37092719e+01 2.65324355e+00 -1.50000000e-02 1.09363834e+00]\n", " [ 0.00000000e+00 -9.00000000e-01 6.14547842e-01 0.00000000e+00]\n", " [-1.50000000e-02 2.75945541e-01 7.35869343e+00 -4.92774940e-02]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Average Penalties in Last 100 Episodes: 0.09\n", "Episode: 800, Average Steps: 10.0\n", "Episode: 900\n", "Q-table 1:\n", "[[ 4.17044501e+01 4.79830072e+01 4.95636716e+01 4.50528430e+01]\n", " [ 3.15552203e+01 4.23323656e+01 5.21406707e+01 2.86071210e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.71680177e+01 4.60039692e+01 4.66941219e+01 5.21684117e+01]\n", " [ 5.50892226e+00 9.03139984e+00 3.37089575e+01 8.19276742e+00]\n", " [ 5.55012215e+00 4.29461657e+01 1.08643794e+01 3.75099673e+00]\n", " [ 6.23929919e+00 4.91809771e+01 8.85944990e+00 6.61130804e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.73453746e+00 1.90337244e+00 1.20498308e+01 3.05046158e+00]\n", " [ 2.45573697e+01 1.43401394e+00 2.77500000e-01 -1.87945243e-02]\n", " [ 2.14007315e+01 0.00000000e+00 -1.50000000e-02 0.00000000e+00]\n", " [ 0.00000000e+00 -6.94933632e-01 1.47980765e+00 0.00000000e+00]\n", " [ 7.01256733e-01 7.37726371e-01 3.25998486e+00 -1.50000000e-01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Q-table 2:\n", "[[ 3.86302588e+01 4.64908138e+01 4.95894967e+01 4.63036235e+01]\n", " [ 2.30488581e+01 3.89892761e+01 5.20788060e+01 3.39970292e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.80402988e+01 4.64008384e+01 4.64737471e+01 5.20272653e+01]\n", " [ 5.84522529e+00 7.38694553e+00 2.10814471e+01 3.23813277e+00]\n", " [ 3.28745820e+00 4.17547682e+01 9.76167874e+00 2.03866494e+00]\n", " [ 8.32831500e+00 4.94193299e+01 6.29755450e+00 6.18756684e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.24060633e+00 5.47539160e+00 7.13885187e+00 7.31260487e-01]\n", " [ 2.61573828e+01 8.65591906e-01 2.37308461e+00 1.13619574e+00]\n", " [ 1.37092719e+01 2.65324355e+00 -1.50000000e-02 1.09363834e+00]\n", " [ 0.00000000e+00 -9.00000000e-01 6.14547842e-01 0.00000000e+00]\n", " [-1.50000000e-02 2.75945541e-01 7.35869343e+00 -4.92774940e-02]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Average Penalties in Last 100 Episodes: 0.0\n", "Episode: 900, Average Steps: 10.0\n", "Episode: 1000\n", "Q-table 1:\n", "[[ 4.24760371e+01 4.79830072e+01 4.93927507e+01 4.56666637e+01]\n", " [ 3.15552203e+01 4.23323656e+01 5.20419674e+01 3.11722406e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.79469883e+01 4.60039692e+01 4.66941219e+01 5.20375974e+01]\n", " [ 5.50892226e+00 9.03139984e+00 3.37089575e+01 8.19276742e+00]\n", " [ 5.55012215e+00 4.29461657e+01 1.08643794e+01 3.75099673e+00]\n", " [ 6.23929919e+00 4.91809771e+01 8.85944990e+00 6.61130804e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.73453746e+00 1.90337244e+00 1.20498308e+01 3.05046158e+00]\n", " [ 2.45573697e+01 1.43401394e+00 2.77500000e-01 -1.87945243e-02]\n", " [ 2.14007315e+01 0.00000000e+00 -1.50000000e-02 0.00000000e+00]\n", " [ 0.00000000e+00 -6.94933632e-01 1.47980765e+00 0.00000000e+00]\n", " [ 7.01256733e-01 7.37726371e-01 3.25998486e+00 -1.50000000e-01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Q-table 2:\n", "[[ 3.86302588e+01 4.64908138e+01 4.95021316e+01 4.63036235e+01]\n", " [ 2.30488581e+01 3.89892761e+01 5.18368218e+01 3.39970292e+01]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 3.80402988e+01 4.67163197e+01 4.71785032e+01 5.19302899e+01]\n", " [ 5.84522529e+00 7.38694553e+00 2.10814471e+01 3.23813277e+00]\n", " [ 3.28745820e+00 4.17547682e+01 9.76167874e+00 2.03866494e+00]\n", " [ 8.32831500e+00 4.97547528e+01 6.29755450e+00 6.18756684e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 1.24060633e+00 5.47539160e+00 7.13885187e+00 7.31260487e-01]\n", " [ 2.61573828e+01 8.65591906e-01 2.37308461e+00 1.13619574e+00]\n", " [ 1.37092719e+01 2.65324355e+00 -1.50000000e-02 1.09363834e+00]\n", " [ 0.00000000e+00 -9.00000000e-01 6.14547842e-01 0.00000000e+00]\n", " [-1.50000000e-02 2.75945541e-01 7.35869343e+00 -4.92774940e-02]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]\n", " [ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]]\n", "Average Penalties in Last 100 Episodes: 0.05\n", "Episode: 1000, Average Steps: 10.0\n" ] } ], "source": [ "env_det = FrozenLakeEnv()\n", "epsilon = 1.0 \n", "epsilon_min = 0.01 \n", "gamma = 0.95 \n", "alpha = 0.15 \n", "decay_rate = 0.995 \n", "total_episodes = 1000\n", "max_timestamp = 10\n", "\n", "qt1 = np.zeros((env_det.obs_space.n, env_det.action_space.n)) # Q-table 1 initialization\n", "qt2 = np.zeros((env_det.obs_space.n, env_det.action_space.n)) # Q-table 2 initialization\n", "\n", "rewards_epi = []\n", "epsilon_values = []\n", "steps_per_episode = []\n", "penalties_per_episode = []\n", "\n", "final_state = None\n", "\n", "for episode in range(total_episodes):\n", " state, _ = env_det.reset()\n", " state_index = env_det.obs_space_to_index(state)\n", " total_rewards = 0\n", " total_steps = 0\n", "\n", " while True:\n", " total_steps += 1\n", " action = env_det.action_space.sample() if np.random.uniform(0, 1) < epsilon else np.argmax((qt1[state_index] + qt2[state_index]) / 2)\n", " next_state, reward, terminated, truncated, _ = env_det.step(action)\n", " next_strt_idx = env_det.obs_space_to_index(next_state)\n", " if np.random.uniform(0, 1) < 0.5:\n", " qt1[state_index, action] += alpha * (reward + gamma * qt2[next_strt_idx, np.argmax(qt1[next_strt_idx])] - qt1[state_index, action])\n", " else:\n", " qt2[state_index, action] += alpha * (reward + gamma * qt1[next_strt_idx, np.argmax(qt2[next_strt_idx])] - qt2[state_index, action])\n", " state_index = next_strt_idx\n", " total_rewards += reward\n", " if terminated or truncated or total_steps >= max_timestamp:\n", " break\n", "\n", " penalties_per_episode.append(env_det.get_penalty_count())\n", " if (episode + 1) % 100 == 0:\n", " print(f\"Episode: {episode + 1}\")\n", " print(\"Q-table 1:\")\n", " print(qt1)\n", " print(\"Q-table 2:\")\n", " print(qt2)\n", " avg_penalty = np.mean(penalties_per_episode[-100:])\n", " print(f\"Average Penalties in Last 100 Episodes: {avg_penalty}\")\n", "\n", " epsilon = max(epsilon_min, epsilon * decay_rate)\n", " epsilon_values.append(epsilon)\n", " rewards_epi.append(total_rewards)\n", " steps_per_episode.append(total_steps)\n", "\n", " if (episode + 1) % 100 == 0:\n", " average_steps = np.mean(steps_per_episode[-100:])\n", " print(f\"Episode: {episode + 1}, Average Steps: {average_steps}\")\n", "\n", " if episode == total_episodes - 1:\n", " final_state = env_det.state" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Q-Learning Training Loop\n", "q_rewards = []\n", "for episode in range(total_episodes):\n", " state, _ = env_det.reset()\n", " state_index = env_det.obs_space_to_index(state)\n", " total_rewards = 0\n", " total_steps = 0\n", " action = env_det.action_space.sample() if np.random.uniform(0, 1) < epsilon else np.argmax(qt[state_index])\n", " while True:\n", " next_state, reward, terminated, truncated, _ = env_det.step(action)\n", " total_steps += 1 \n", " next_strt_idx = env_det.obs_space_to_index(next_state)\n", " next_action = env_det.action_space.sample() if np.random.uniform(0, 1) < epsilon else np.argmax(qt[next_strt_idx])\n", " qt[state_index, action] = qt[state_index, action] + alpha * (reward + gamma * qt[next_strt_idx, next_action] - qt[state_index, action])\n", " state_index, action = next_strt_idx, next_action\n", " total_rewards += reward\n", " if terminated or truncated:\n", " break\n", " q_rewards.append(total_rewards)\n", "\n", "\n", "# Double Q-learning Training Loop\n", "double_q_rewards = []\n", "for episode in range(total_episodes):\n", " state, _ = env_det.reset()\n", " state_index = env_det.obs_space_to_index(state)\n", " total_rewards = 0\n", " total_steps = 0\n", " while True:\n", " total_steps += 1\n", " action = env_det.action_space.sample() if np.random.uniform(0, 1) < epsilon else np.argmax((qt1[state_index] + qt2[state_index]) / 2)\n", " next_state, reward, terminated, truncated, _ = env_det.step(action)\n", " next_strt_idx = env_det.obs_space_to_index(next_state)\n", " if np.random.uniform(0, 1) < 0.5:\n", " qt1[state_index, action] += alpha * (reward + gamma * qt2[next_strt_idx, np.argmax(qt1[next_strt_idx])] - qt1[state_index, action])\n", " else:\n", " qt2[state_index, action] += alpha * (reward + gamma * qt1[next_strt_idx, np.argmax(qt2[next_strt_idx])] - qt2[state_index, action])\n", " state_index = next_strt_idx\n", " total_rewards += reward\n", " if terminated or truncated or total_steps >= max_timestamp:\n", " break\n", " double_q_rewards.append(total_rewards)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(q_rewards, label='Q-Learning')\n", "plt.plot(double_q_rewards, label='Double Q-learning')\n", "plt.xlabel('Episode')\n", "plt.ylabel('Total Reward')\n", "plt.legend()\n", "plt.title('Comparison of Q-Learning and Double Q-learning - Deterministic')\n", "plt.show()" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 }