{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Deep Policy Gradient (PG)\n", "\n", "\n", "In this notebook, our goal is to implement the REINFORCE algorithm for policy gradient using [JAX](https://jax.readthedocs.io/en/latest/). We will apply this RL algorithm to control a single quantum bit of information (qubit). \n", "\n", "## The REINFORCE Algorithm\n", "\n", "The reinforcement learning objective $J$ is the expected total return obtained by following the policy $\\pi$. If the transition probability is denoted by $p(s'|s,a)$ and the initial state distribution by $p(s_0)$, the probability for a trajectory $\\tau = (s_0,a_0,r_1,s_1,a_1,\\dots,s_{T-1},a_{T-1},r_T,s_T)$ to occur can be written as\n", "\n", "$$\n", "P_\\pi(\\tau) = p(s_0)\\prod_{t=0}^{T-1} \\pi(a_t|s_t)p(s_{t+1}|s_t,a_t). \n", "$$\n", "\n", "The RL objective then takes the form\n", "\n", "$$\n", "J = \\mathrm{E}_{\\tau\\sim P_\\pi} \\left[ G(\\tau) | S_{t=0}=s_0 \\right],\\quad G(\\tau)=\\sum_{t=1}^T r_t,\n", "$$\n", "\n", "where $r_t$ is the reward received after taking action $a_{t-1}$ in state $s_{t-1}$.\n", "\n", "Policy gradient methods in RL directly approximate the policy $\\pi\\approx\\pi_\\theta$ using a variational ansatz, parametrized by the unknown parameters $\\theta$. The goal is then to find the optimal parameters $\\theta$ that maximize the RL objective $J(\\theta)$. To define an update rule for $\\theta$, we may use gradient ascent. This requires us to evaluate the gradient of the RL objective w.r.t. the parameters $\\theta$:\n", "\n", "$$\n", "\\nabla_\\theta J(\\theta) = \\nabla_\\theta \\mathrm{E}_{\\tau\\sim P_\\pi} \\left[ G(\\tau) | S_{t=0}=s_0 \\right] \n", "= \\int\\mathrm{d}\\tau \\nabla_\\theta P_{\\pi_\\theta}(\\tau) G(\\tau).\n", "$$\n", "In a model-free setting, we do not have access to the transition probabilities $p(s'|s,a)$, so we must estimate the gradients from samples. 
This can be accomplished by noticing that $\\nabla_\\theta P_{\\pi_\\theta} = P_{\\pi_\\theta} \\nabla_\\theta \\log P_{\\pi_\\theta}$ (for all trajectories with $P_{\\pi_\\theta}(\\tau)>0$):\n", "\n", "$$\n", "\\nabla_\\theta J(\\theta) = \\int\\mathrm{d}\\tau \\nabla_\\theta P_{\\pi_\\theta}(\\tau) G(\\tau) = \\int\\mathrm{d}\\tau P_{\\pi_\\theta}(\\tau) \\nabla_\\theta \\log P_{\\pi_\\theta}(\\tau) G(\\tau) = \\mathrm{E}_{\\tau\\sim P_\\pi} \\left[\\nabla_\\theta \\log P_{\\pi_\\theta}(\\tau) G(\\tau)\\right].\n", "$$\n", "Since the initial state distribution and the transition probabilities are independent of $\\theta$, the definition of $P_{\\pi_\\theta}$ implies $\\nabla_\\theta \\log P_{\\pi_\\theta}(\\tau) = \\nabla_\\theta \\log \\pi_\\theta(\\tau)$, where $\\pi_\\theta(\\tau) = \\prod_{t=0}^{T-1} \\pi_\\theta(a_t|s_t)$. \n", "\n", "We can now use Monte Carlo (MC) sampling to estimate the gradients directly from a sample of trajectories $\\{\\tau_j\\}_{j=1}^N$:\n", "$$\n", "\\nabla_\\theta J(\\theta) = \\mathrm{E}_{\\tau\\sim P_\\pi} \\left[\\nabla_\\theta \\log P_{\\pi_\\theta}(\\tau) G(\\tau)\\right]\n", "\\approx \\frac{1}{N}\\sum_{j=1}^N \\nabla_\\theta \\log \\pi_\\theta(\\tau_j) G(\\tau_j).\n", "$$\n", "\n", "In class, we discussed the problems that arise due to the large variance of this gradient estimate. In particular, we showed that one can use causality and a baseline to reduce the variance. The PG estimate then takes the form\n", "\n", "$$\n", "\\nabla_\\theta J(\\theta)\n", "\\approx \\frac{1}{N}\\sum_{j=1}^N \\sum_{t=0}^{T-1} \\nabla_\\theta \\log \\pi_\\theta(a^j_t|s^j_t) \\left[\\sum_{t'=t+1}^T r^j_{t'} - b\\right].\n", "$$\n", "The corresponding gradient ascent update rule reads\n", "\n", "$$\n", "\\theta \\leftarrow \\theta + \\alpha \\nabla_\\theta J(\\theta),\n", "$$\n", "for some step size (or learning rate) $\\alpha$. 
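\n",
"\n",
"The bracket in the estimate above can be computed for a whole batch of trajectories with a few array operations. Below is a minimal NumPy sketch (not the implementation used later in this notebook), assuming the rewards of $N$ trajectories are stored in an array of shape (N, T) whose entry [j, t] is the reward received after action $a^j_t$. As in the pseudo loss defined further below, it uses a time-dependent baseline $b_t$ (the sample average of the reward-to-go over the $N$ trajectories) in place of the constant $b$. The helper name `reward_to_go` and the toy rewards are hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def reward_to_go(rewards):\n",
"    # rewards[j, t] holds the reward received after action a^j_t\n",
"    # reversed cumulative sum along time: G[j, t] = sum over t2 >= t of rewards[j, t2]\n",
"    return np.cumsum(rewards[:, ::-1], axis=1)[:, ::-1]\n",
"\n",
"# hypothetical batch of N=2 trajectories with T=3 steps each\n",
"rewards = np.array([[0.1, 0.5, 0.9],\n",
"                    [0.3, 0.3, 0.3]])\n",
"returns = reward_to_go(rewards)\n",
"# time-dependent baseline b_t: sample average of the reward-to-go over trajectories\n",
"baseline = returns.mean(axis=0)\n",
"# weights multiplying grad log pi(a_t|s_t) in the PG estimate\n",
"weights = returns - baseline\n",
"```\n",
"\n",
"Because the baseline is the sample mean of the returns, the weights average to zero at every time step: actions followed by a better-than-average reward-to-go are reinforced, and the others are suppressed.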
\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Qubit Environment\n", "\n", "Let us recall the qubit environment we defined in Notebook 2.\n", "\n", "### Basic Definitions\n", "\n", "The state of a qubit $|\\psi\\rangle\\in\\mathbb{C}^2$ is modeled by a two-dimensional complex-valued vector with unit norm: $\\langle\\psi|\\psi\\rangle=|\\psi_1|^2+|\\psi_2|^2=1$. Up to an overall phase $\\alpha$, every qubit state is described by two angles $\\theta\\in[0,\\pi]$ and $\\varphi\\in[0,2\\pi)$:\n", "\n", "\\begin{eqnarray}\n", "|\\psi\\rangle=\n", "\\begin{pmatrix}\n", "\\psi_1 \\\\ \\psi_2\n", "\\end{pmatrix}=\n", "\\mathrm{e}^{i\\alpha}\n", "\\begin{pmatrix}\n", "\\cos\\frac{\\theta}{2} \\\\\n", "\\mathrm{e}^{i\\varphi}\\sin\\frac{\\theta}{2}\n", "\\end{pmatrix}\n", "\\end{eqnarray}\n", "\n", "The overall phase $\\alpha$ of a single quantum state has no physical meaning.\n", "Thus, any qubit state can be pictured as an arrow on the unit sphere (called the Bloch sphere) with coordinates $(\\theta,\\varphi)$. \n", "\n", "To operate on qubits, we use quantum gates. Quantum gates are represented as unitary transformations $U\\in \\mathrm{U(2)}$, where $\\mathrm{U(2)}$ is the group of $2\\times 2$ unitary matrices. Gates act on qubit states by matrix multiplication, transforming an input state $|\\psi\\rangle$ into the output state $|\\psi'\\rangle=U|\\psi\\rangle$. 
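\n",
"\n",
"The Bloch sphere parametrization above translates directly into code. The following is a minimal NumPy sketch, assuming a qubit state is stored as a length-2 complex array. The helper names `state_from_angles` and `angles_from_state` are hypothetical and are not part of `QubitEnv2`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def state_from_angles(theta, phi):\n",
"    # Bloch sphere parametrization with the overall phase alpha set to zero\n",
"    return np.array([np.cos(theta / 2), np.exp(1j * phi) * np.sin(theta / 2)])\n",
"\n",
"def angles_from_state(psi):\n",
"    # remove the overall phase so that psi[0] becomes real and non-negative\n",
"    psi = psi * np.exp(-1j * np.angle(psi[0]))\n",
"    theta = 2 * np.arccos(np.clip(np.abs(psi[0]), 0.0, 1.0))\n",
"    phi = np.angle(psi[1]) % (2 * np.pi)\n",
"    return theta, phi\n",
"\n",
"psi = state_from_angles(np.pi / 3, np.pi / 4)\n",
"# an arbitrary overall phase does not change the recovered Bloch angles\n",
"theta, phi = angles_from_state(np.exp(0.7j) * psi)\n",
"```\n",
"\n",
"At the poles $\\theta=0$ and $\\theta=\\pi$, the angle $\\varphi$ is not well defined, since every value of $\\varphi$ describes the same physical state.\n",
"\n",
"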
For this problem, we consider four gates\n", "\n", "$$\n", "U_0=\\boldsymbol{1},\\qquad \n", "U_x=\\mathrm{exp}(-i\\delta t \\sigma^x/2),\\qquad\n", "U_y=\\mathrm{exp}(-i\\delta t \\sigma^y/2),\\qquad \n", "U_z=\\mathrm{exp}(-i\\delta t \\sigma^z/2),\n", "$$\n", "\n", "where $\\delta t$ is a fixed time step and $\\mathrm{exp}(\\cdot)$ is the matrix exponential. The identity $\\boldsymbol{1}$ and the Pauli matrices are defined as\n", "\n", "$$\n", "\\boldsymbol{1}=\\begin{pmatrix}\n", "1 & 0 \\\\ 0 & 1\n", "\\end{pmatrix}\n", ",\\qquad\n", "\\sigma^x=\\begin{pmatrix}\n", "0 & 1 \\\\ 1 & 0\n", "\\end{pmatrix}\n", ",\\qquad\n", "\\sigma^y=\\begin{pmatrix}\n", "0 & -i \\\\ i & 0\n", "\\end{pmatrix}\n", ",\\qquad\n", "\\sigma^z=\\begin{pmatrix}\n", "1 & 0 \\\\ 0 & -1\n", "\\end{pmatrix}\n", "$$\n", "\n", "To determine if a qubit, described by the state $|\\psi\\rangle$, is in a desired target state $|\\psi_\\mathrm{target}\\rangle$, we compute the fidelity\n", "\n", "\\begin{eqnarray}\n", "F=|\\langle\\psi_\\mathrm{target}|\\psi\\rangle|^2 = |(\\psi_\\mathrm{target})^\\ast_1 \\psi_1 + (\\psi_\\mathrm{target})^\\ast_2 \\psi_2|^2,\\qquad F\\in[0,1]\n", "\\end{eqnarray}\n", "\n", "where $\\ast$ stands for complex conjugation. Physically, the fidelity is determined by the angle $\\Theta$ between the arrows representing the two qubit states on the Bloch sphere, $F=\\cos^2(\\Theta/2)$: maximizing the fidelity is therefore equivalent to minimizing the angle between the states.\n", "\n", "### Constructing the Qubit Environment\n", "\n", "Now, let us define an episodic RL environment, which contains the laws of physics that govern the dynamics of the qubit (i.e. the application of the gate operations to the qubit state). Our RL agent will later interact with this environment to learn how to bring the qubit from an initial state to a prescribed target state. \n", "\n", "We define the RL states $s=(\\theta,\\varphi)$ as an array containing the Bloch sphere angles of the quantum state. 
At each step of an episode, the agent can apply one of four actions, corresponding to the four gates $(\\boldsymbol{1},U_x,U_y,U_z)$. We use the instantaneous fidelity w.r.t. the target state as a reward: $r_t=F=|\\langle\\psi_\\mathrm{target}|\\psi(t)\\rangle|^2$. \n", "\n", "**state space:** $\\mathcal{S} = \\{(\\theta,\\varphi)|\\theta\\in[0,\\pi],\\varphi\\in[0,2\\pi)\\}$. Unlike in Notebook 2, there are no terminal states here. Instead, we consider a fixed number of time steps, after which the episode terminates deterministically. The target state (i.e. the qubit state we want to prepare) is $|\\psi_\\mathrm{target}\\rangle=(1,0)^\\top$, with Bloch sphere coordinates $s_\\mathrm{target}=(0,0)$.\n", "\n", "**action space:** $\\mathcal{A} = \\{\\boldsymbol{1},U_x,U_y,U_z\\}$. Actions act on RL states as follows: \n", "1. if the current state is $s=(\\theta,\\varphi)$, we first create the quantum state $|\\psi(s)\\rangle$; \n", "2. we apply the gate $U_a$ corresponding to action $a$ to the quantum state and obtain the new quantum state $|\\psi(s')\\rangle = U_a|\\psi(s)\\rangle$; \n", "3. finally, we compute the Bloch sphere coordinates, which define the next state $s'=(\\theta',\\varphi')$, using the Bloch sphere parametrization for qubits given above.\n", "\n", "Note that all actions are allowed from every state. \n", "\n", "\n", "**reward space:** $\\mathcal{R}=[0,1]$. We use the fidelity between the next state $s'$ and the target state $s_\\mathrm{target}$ as a reward at every episode step: \n", "\n", "$$r(s,s',a)= F = |\\langle\\psi_\\mathrm{target}|U_a|\\psi(s)\\rangle|^2=|\\langle\\psi_\\mathrm{target}|\\psi(s')\\rangle|^2$$\n", "\n", "for all states $s,s'\\in\\mathcal{S}$ and actions $a\\in\\mathcal{A}$. 
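\n",
"\n",
"The gates, the state update and the reward defined above can be illustrated with a short NumPy sketch. This is not the `QubitEnv2` implementation from Notebook 2. The action ordering (0: identity, 1: $U_x$, 2: $U_y$, 3: $U_z$), the value of $\\delta t$, and the helper names `rotation` and `step` are hypothetical. The sketch also works with the state vector directly instead of converting to Bloch angles after every step. Because every Pauli matrix squares to the identity, the matrix exponential has the closed form $\\exp(-i\\delta t\\,\\sigma/2)=\\cos(\\delta t/2)\\,\\boldsymbol{1}-i\\sin(\\delta t/2)\\,\\sigma$, which the sketch uses instead of a general matrix exponential.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# identity and Pauli matrices\n",
"identity = np.eye(2, dtype=complex)\n",
"sigma_x = np.array([[0, 1], [1, 0]], dtype=complex)\n",
"sigma_y = np.array([[0, -1j], [1j, 0]], dtype=complex)\n",
"sigma_z = np.array([[1, 0], [0, -1]], dtype=complex)\n",
"\n",
"def rotation(sigma, dt):\n",
"    # exp(-i dt sigma / 2) = cos(dt/2) 1 - i sin(dt/2) sigma, valid since sigma @ sigma = 1\n",
"    return np.cos(dt / 2) * identity - 1j * np.sin(dt / 2) * sigma\n",
"\n",
"dt = 2 * np.pi / 60  # hypothetical time step, not taken from QubitEnv2\n",
"gates = [identity, rotation(sigma_x, dt), rotation(sigma_y, dt), rotation(sigma_z, dt)]\n",
"psi_target = np.array([1, 0], dtype=complex)\n",
"\n",
"def step(psi, action):\n",
"    # apply the gate U_a, then use the fidelity with the target state as the reward\n",
"    psi_next = gates[action] @ psi\n",
"    reward = np.abs(np.vdot(psi_target, psi_next)) ** 2\n",
"    return psi_next, reward\n",
"\n",
"# start on the equator of the Bloch sphere: theta = pi/2, phi = 0\n",
"psi0 = np.array([np.cos(np.pi / 4), np.sin(np.pi / 4)], dtype=complex)\n",
"psi1, r1 = step(psi0, 2)  # apply U_y\n",
"```\n",
"\n",
"Starting from this state, where $F=1/2$, a single application of $U_y$ yields the reward $(1-\\sin\\delta t)/2$. In other words, it moves the qubit slightly away from the target. Note that `np.vdot` conjugates its first argument, matching the bra $\\langle\\psi_\\mathrm{target}|$ in the fidelity.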
" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import import_ipynb\n", "from Notebook_2_RL_environments import QubitEnv2 # import environment, notebooks must be in same directory" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# set seed of rng (for reproducibility of the results)\n", "n_time_steps = 60\n", "seed=0 \n", "np.random.seed(seed)\n", "\n", "# create environment class\n", "env=QubitEnv2(n_time_steps, seed=seed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Policy Gradient Implementation\n", "\n", "The implementation of PG follows similar steps as the MNIST problem from Notebook 7:\n", "\n", "1. Define the a SoftMax model for the discrete policy $\\pi_\\theta$.\n", "2. Define the pseudo loss function to easily compute $\\nabla_\\theta J(\\theta)$.\n", "3. Define generalized gradient descent optimizer.\n", "4. Define the PG training loop and train the policy.\n", "\n", "\n", "### Define the a SoftMax model for the discrete policy $\\pi_\\theta$\n", "\n", "Use JAX to construct a feed-forward fully-connected deep neural network with neuron acrchitecture $(M_s, 512, 256, |\\mathcal{A}|)$, where there are $512$ ($256$) neurons in the first (second) hidden layer, respectively, and $M_s$ and $|\\mathcal{A}|$ define the input and output sizes.\n", "\n", "The input data into the neural network should have the shape input_shape = (-1, n_time_steps, M_s), where M_s is the number of features/components in the RL state $s=(\\theta,\\varphi)$. The output data should have the shape output_shape = (-1, n_time_steps, abs_A), where abs_A$=|\\mathcal{A}|$. In this way, we can use the neural network to process simultaneously all time steps and MC samples, generated in a single training iteration. \n", "\n", "Check explicitly the output shape and test that the network runs on some fake data (e.g. 
a small batch of vectors of ones with the appropriate shape). " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "output shape of the policy network is (-1, 60, 4).\n", "\n", "(3, 60, 4)\n" ] } ], "source": [ "import jax.numpy as jnp # jax's numpy version with GPU support\n", "from jax import random # used to define a RNG key to control the random input in JAX\n", "try:\n", "    # recent JAX versions ship the stax neural network library in jax.example_libraries\n", "    from jax.example_libraries import stax\n", "except ImportError:\n", "    # older JAX versions (such as the one that produced this cell's output)\n", "    from jax.experimental import stax\n", "Dense, Relu, LogSoftmax = stax.Dense, stax.Relu, stax.LogSoftmax # neural network layers\n", "\n", "# set key for the RNG (see JAX docs)\n", "rng = random.PRNGKey(seed)\n", "\n", "# define functions which initialize the parameters and evaluate the model\n", "initialize_params, predict = stax.serial(\n", "    ### fully connected DNN\n", "    Dense(512), # 512 hidden neurons\n", "    Relu,\n", "    Dense(256), # 256 hidden neurons\n", "    Relu,\n", "    Dense(env.n_actions), # 4 output neurons\n", "    LogSoftmax # NB: outputs log-probabilities\n", "    )\n", "\n", "# initialize the model parameters\n", "input_shape = (-1,env.n_time_steps,2) # -1: number of MC points, number of time steps, size of state vector\n", "output_shape, inital_params = initialize_params(rng, input_shape) # returns the output shape and the initial network parameters\n", "\n", "print('\\noutput shape of the policy network is {}.\\n'.format(output_shape))\n", "\n", "\n", "# test network\n", "states=np.ones((3,env.n_time_steps,2), dtype=np.float32)\n", "\n", "predictions = predict(inital_params, states)\n", "# check the output shape\n", "print(predictions.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the pseudo loss function to easily compute $\\nabla_\\theta J(\\theta)$\n", "\n", "In class, we defined a scalar pseudoloss 
function whose gradient gives $\\nabla_\\theta J(\\theta)$. Note that this pseudoloss does ***NOT*** correspond to the RL objective $J(\\theta)$: the difference stems from the fact that the two operations of taking the derivative and performing the MC approximation are not interchangeable (do you see why?). \n", "\n", "$$\n", "J_\\mathrm{pseudo}(\\theta) = \n", "\\frac{1}{N}\\sum_{j=1}^N \\sum_{t=0}^{T-1} \\log \\pi_\\theta(a^j_t|s^j_t) \\left[G_t(\\tau_j) - b_t\\right],\\qquad \n", "G_t(\\tau_j) = \\sum_{t'=t+1}^T r^j_{t'},\\qquad \n", "b_t = \\frac{1}{N}\\sum_{j=1}^N G_t(\\tau_j).\n", "$$\n", "Here, $G_t(\\tau_j)$ is the reward-to-go (return) collected in trajectory $\\tau_j$ after taking action $a^j_t$, and the baseline $b_t$ is its sample average over the $N$ MC trajectories.\n", "\n", "We want to perform gradient **a**scent on $J$, but the optimizer performs gradient **de**scent. Therefore, do **NOT** forget to add an extra minus sign to the output of the pseudoloss (or else your agent will end up minimizing the return). \n", "\n", "Below, we also add an L2 regularizer to the pseudoloss function to prevent overfitting. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "### define the L2 regularizer and the pseudo loss function\n", "\n", "from jax import grad\n", "from jax.tree_util import tree_flatten # jax params are stored as nested tuples; tree_flatten turns them into a flat list of arrays\n", "\n", "\n", "def l2_regularizer(params, lmbda):\n", "    \"\"\"\n", "    Define the L2 regularizer lmbda * sum_j ||theta_j||^2, where the sum runs over all parameter arrays theta_j of the model.\n", "    \n", "    \"\"\"\n", "    return lmbda*jnp.sum(jnp.array([jnp.sum(jnp.abs(theta)**2) for theta in tree_flatten(params)[0] ]))\n", "\n", "\n", "def pseudo_loss(params, trajectory_batch):\n", "    \"\"\"\n", "    Define the pseudo loss function for policy gradient. 
\n", "    \n", "    params: object (jax pytree)\n", "        parameters of the deep policy network.\n", "    trajectory_batch: tuple (states, actions, returns) containing the RL states, actions and returns (not the rewards!):\n", "        states: np.array of shape (N_MC, env.n_time_steps, 2)\n", "        actions: np.array of shape (N_MC, env.n_time_steps)\n", "        returns: np.array of shape (N_MC, env.n_time_steps)\n", "    \n", "    Returns:\n", "        -J_pseudo(theta) plus the L2 penalty\n", "\n", "    \"\"\"\n", "    # extract data from the batch\n", "    states, actions, returns = trajectory_batch\n", "    # compute the log-probabilities of all actions\n", "    preds = predict(params, states)\n", "    # compute the time-dependent baseline b_t: average return over the MC samples\n", "    baseline = jnp.mean(returns, axis=0)\n", "    # select the log-probabilities of the actions actually taken along each trajectory\n", "    preds_select = jnp.take_along_axis(preds, jnp.expand_dims(actions, axis=2), axis=2).squeeze(axis=2)\n", "    # sum over time steps, average over MC samples; the minus sign turns gradient DEscent into ascent on J\n", "    return -jnp.mean(jnp.sum(preds_select * (returns - baseline), axis=1)) + l2_regularizer(params, 0.001)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define a generalized gradient descent optimizer\n", "\n", "Define the optimizer and the update function, which computes the gradient of the pseudo loss function and performs the parameter update. \n", "\n", "We use the Adam optimizer with step_size = 0.001 and default values for the remaining hyperparameters. 
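\n",
"\n",
"For reference, the following minimal NumPy sketch performs one step of the standard Adam update for a single parameter array. The helper name `adam_step` is hypothetical. The hyperparameters b1=0.9, b2=0.999 and eps=1e-8 are the commonly used defaults, which, to our knowledge, match the defaults of `optimizers.adam`. That implementation applies the same rule to every array in the parameter pytree and keeps the moment estimates m and v in `opt_state`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def adam_step(theta, g, m, v, t, step_size=0.001, b1=0.9, b2=0.999, eps=1e-8):\n",
"    # exponential moving averages of the gradient g and of its elementwise square\n",
"    m = b1 * m + (1 - b1) * g\n",
"    v = b2 * v + (1 - b2) * g ** 2\n",
"    # bias correction for the zero initialization of m and v (t counts steps from 1)\n",
"    m_hat = m / (1 - b1 ** t)\n",
"    v_hat = v / (1 - b2 ** t)\n",
"    # descent step on the loss (here: pseudo_loss, which already contains the minus sign)\n",
"    theta = theta - step_size * m_hat / (np.sqrt(v_hat) + eps)\n",
"    return theta, m, v\n",
"\n",
"theta = np.array([1.0, -2.0])\n",
"m = np.zeros_like(theta)\n",
"v = np.zeros_like(theta)\n",
"g = np.array([0.5, -0.1])  # hypothetical gradient\n",
"theta, m, v = adam_step(theta, g, m, v, t=1)\n",
"```\n",
"\n",
"On the first step, the bias-corrected ratio m_hat / sqrt(v_hat) equals the sign of the gradient (up to eps). Each parameter therefore moves by roughly step_size, independently of the gradient magnitude.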
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "### define generalized gradient descent optimizer and a function to update model parameters\n", "\n", "from jax.experimental import optimizers # gradient descent optimizers\n", "from jax import jit\n", "\n", "step_size = 0.001 # step size or learning rate \n", "\n", "# compute optimizer functions\n", "opt_init, opt_update, get_params = optimizers.adam(step_size)\n", "\n", "\n", "# define function which updates the parameters using the change computed by the optimizer\n", "@jit # Just In Time compilation speeds up the code; requires to use jnp everywhere; remove when debugging\n", "def update(i, opt_state, batch):\n", " \"\"\"\n", " i: int,\n", " counter to count how many update steps we have performed\n", " opt_state: object,\n", " the state of the optimizer\n", " batch: np.array\n", " batch containing the data used to update the model\n", " \n", " Returns: \n", " opt_state: object,\n", " the new state of the optimizer\n", " \n", " \"\"\"\n", " # get current parameters of the model\n", " current_params = get_params(opt_state)\n", " # compute gradients\n", " grad_params = grad(pseudo_loss)(current_params, batch)\n", " # use the optimizer to perform the update using opt_update\n", " return opt_update(i, grad_params, opt_state)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the PG training loop and train the policy\n", "\n", "Finally, we implement the REINFORCE algorithm for policy gradient. Follow the steps below\n", "\n", "1. Preallocate variables\n", " * Define the number of episodes N_episodes, and the batch size N_MC.\n", " * Preallocate arrays for the current state, and the states, actions, returns triple which defines the trajectory batch. \n", " * Preallocate arrays to compute the mean_final_reward, std_final_reward, min_final_reward, and , max_final_reward.\n", "2. Initialize the optimizer using the opt_init function. \n", "3. 
Loop over the episodes; for every episode:\n", " \n", " 3.1 get the current network parameters\n", " \n", " 3.2 loop to collect MC samples\n", " \n", " 3.2.1 reset the env and roll out the policy until the episode is over; collect the trajectory data\n", " 3.2.2 compute the returns (rewards to go)\n", " \n", " 3.3 compile the PG data into a trajectory batch\n", " \n", " 3.4 use the update function to update the network parameters\n", " \n", " 3.5 print the instantaneous performance" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Starting training...\n", "\n", "episode 0 in 8.90 sec\n", "mean reward: 0.4562\n", "return standard deviation: 0.2593\n", "min return: 0.0051; max return: 0.9465\n", "\n", "episode 1 in 9.14 sec\n", "mean reward: 0.4051\n", "return standard deviation: 0.2847\n", "min return: 0.0071; max return: 0.9868\n", "\n", "episode 2 in 8.67 sec\n", "mean reward: 0.4554\n", "return standard deviation: 0.2654\n", "min return: 0.0070; max return: 0.9843\n", "\n", "episode 3 in 8.77 sec\n", "mean reward: 0.5347\n", "return standard deviation: 0.2965\n", "min return: 0.0125; max return: 0.9942\n", "\n", "episode 4 in 9.10 sec\n", "mean reward: 0.4993\n", "return standard deviation: 0.2915\n", "min return: 0.0036; max return: 0.9948\n", "\n", "episode 5 in 9.29 sec\n", "mean reward: 0.5093\n", "return standard deviation: 0.2557\n", "min return: 0.0265; max return: 0.9842\n", "\n", "episode 6 in 8.86 sec\n", "mean reward: 0.5573\n", "return standard deviation: 0.2985\n", "min return: 0.0192; max return: 0.9903\n", "\n", "episode 7 in 9.31 sec\n", "mean reward: 0.4865\n", "return standard deviation: 0.2842\n", "min return: 0.0125; max return: 0.9935\n", "\n", "episode 8 in 9.03 sec\n", "mean reward: 0.5378\n", "return standard deviation: 0.2851\n", "min return: 0.0060; max return: 0.9813\n", "\n", "episode 9 in 9.46 sec\n", "mean reward: 0.5306\n", "return standard 
deviation: 0.3091\n", "min return: 0.0541; max return: 0.9978\n", "\n", "episode 10 in 9.03 sec\n", "mean reward: 0.5294\n", "return standard deviation: 0.2730\n", "min return: 0.0009; max return: 0.9724\n", "\n", "episode 11 in 9.41 sec\n", "mean reward: 0.5583\n", "return standard deviation: 0.2846\n", "min return: 0.0234; max return: 0.9857\n", "\n", "episode 12 in 8.96 sec\n", "mean reward: 0.6407\n", "return standard deviation: 0.2760\n", "min return: 0.0545; max return: 0.9864\n", "\n", "episode 13 in 9.20 sec\n", "mean reward: 0.6151\n", "return standard deviation: 0.2683\n", "min return: 0.0156; max return: 0.9967\n", "\n", "episode 14 in 9.36 sec\n", "mean reward: 0.5790\n", "return standard deviation: 0.2913\n", "min return: 0.0216; max return: 0.9975\n", "\n", "episode 15 in 8.87 sec\n", "mean reward: 0.5868\n", "return standard deviation: 0.2821\n", "min return: 0.0052; max return: 0.9910\n", "\n", "episode 16 in 8.99 sec\n", "mean reward: 0.7151\n", "return standard deviation: 0.2415\n", "min return: 0.1630; max return: 0.9982\n", "\n", "episode 17 in 9.41 sec\n", "mean reward: 0.6150\n", "return standard deviation: 0.2886\n", "min return: 0.0168; max return: 0.9989\n", "\n", "episode 18 in 9.68 sec\n", "mean reward: 0.6341\n", "return standard deviation: 0.2463\n", "min return: 0.0259; max return: 0.9963\n", "\n", "episode 19 in 9.30 sec\n", "mean reward: 0.6562\n", "return standard deviation: 0.2485\n", "min return: 0.0988; max return: 0.9996\n", "\n", "episode 20 in 9.23 sec\n", "mean reward: 0.6304\n", "return standard deviation: 0.2602\n", "min return: 0.0964; max return: 0.9991\n", "\n", "episode 21 in 9.55 sec\n", "mean reward: 0.7116\n", "return standard deviation: 0.2353\n", "min return: 0.0394; max return: 0.9997\n", "\n", "episode 22 in 9.52 sec\n", "mean reward: 0.6820\n", "return standard deviation: 0.2720\n", "min return: 0.0979; max return: 0.9998\n", "\n", "episode 23 in 9.68 sec\n", "mean reward: 0.6908\n", "return standard deviation: 
0.2357\n", "min return: 0.1425; max return: 0.9970\n", "\n", "episode 24 in 11.33 sec\n", "mean reward: 0.7206\n", "return standard deviation: 0.2125\n", "min return: 0.0277; max return: 0.9913\n", "\n", "episode 25 in 9.63 sec\n", "mean reward: 0.7112\n", "return standard deviation: 0.2254\n", "min return: 0.0384; max return: 0.9940\n", "\n", "episode 26 in 9.24 sec\n", "mean reward: 0.7421\n", "return standard deviation: 0.2210\n", "min return: 0.0871; max return: 0.9838\n", "\n", "episode 27 in 9.04 sec\n", "mean reward: 0.7187\n", "return standard deviation: 0.2057\n", "min return: 0.2195; max return: 0.9878\n", "\n", "episode 28 in 9.19 sec\n", "mean reward: 0.7141\n", "return standard deviation: 0.2373\n", "min return: 0.2032; max return: 0.9948\n", "\n", "episode 29 in 8.96 sec\n", "mean reward: 0.7788\n", "return standard deviation: 0.1877\n", "min return: 0.2063; max return: 0.9973\n", "\n", "episode 30 in 9.16 sec\n", "mean reward: 0.7236\n", "return standard deviation: 0.2281\n", "min return: 0.0355; max return: 0.9875\n", "\n", "episode 31 in 9.15 sec\n", "mean reward: 0.7176\n", "return standard deviation: 0.1989\n", "min return: 0.2246; max return: 0.9769\n", "\n", "episode 32 in 8.86 sec\n", "mean reward: 0.7703\n", "return standard deviation: 0.1802\n", "min return: 0.3438; max return: 0.9994\n", "\n", "episode 33 in 9.12 sec\n", "mean reward: 0.7312\n", "return standard deviation: 0.2225\n", "min return: 0.1938; max return: 0.9923\n", "\n", "episode 34 in 8.97 sec\n", "mean reward: 0.7701\n", "return standard deviation: 0.1705\n", "min return: 0.3199; max return: 0.9985\n", "\n", "episode 35 in 9.50 sec\n", "mean reward: 0.7752\n", "return standard deviation: 0.1667\n", "min return: 0.3696; max return: 0.9938\n", "\n", "episode 36 in 9.07 sec\n", "mean reward: 0.7569\n", "return standard deviation: 0.1618\n", "min return: 0.2729; max return: 0.9893\n", "\n", "episode 37 in 9.03 sec\n", "mean reward: 0.7762\n", "return standard deviation: 0.1813\n", 
"min return: 0.2023; max return: 0.9927\n", "\n", "episode 38 in 9.56 sec\n", "mean reward: 0.7589\n", "return standard deviation: 0.1825\n", "min return: 0.1899; max return: 0.9992\n", "\n", "episode 39 in 8.75 sec\n", "mean reward: 0.8207\n", "return standard deviation: 0.1517\n", "min return: 0.1449; max return: 0.9992\n", "\n", "episode 40 in 8.80 sec\n", "mean reward: 0.7550\n", "return standard deviation: 0.1653\n", "min return: 0.2631; max return: 0.9950\n", "\n", "episode 41 in 8.78 sec\n", "mean reward: 0.7910\n", "return standard deviation: 0.1739\n", "min return: 0.1389; max return: 0.9999\n", "\n", "episode 42 in 8.79 sec\n", "mean reward: 0.7976\n", "return standard deviation: 0.1599\n", "min return: 0.2042; max return: 0.9865\n", "\n", "episode 43 in 8.83 sec\n", "mean reward: 0.8185\n", "return standard deviation: 0.1551\n", "min return: 0.2409; max return: 0.9972\n", "\n", "episode 44 in 8.75 sec\n", "mean reward: 0.7915\n", "return standard deviation: 0.2075\n", "min return: 0.1613; max return: 0.9962\n", "\n", "episode 45 in 8.79 sec\n", "mean reward: 0.7909\n", "return standard deviation: 0.2015\n", "min return: 0.1553; max return: 0.9999\n", "\n", "episode 46 in 9.26 sec\n", "mean reward: 0.8140\n", "return standard deviation: 0.1590\n", "min return: 0.2167; max return: 0.9990\n", "\n", "episode 47 in 9.72 sec\n", "mean reward: 0.8372\n", "return standard deviation: 0.1374\n", "min return: 0.2924; max return: 0.9997\n", "\n", "episode 48 in 9.65 sec\n", "mean reward: 0.8210\n", "return standard deviation: 0.1420\n", "min return: 0.2590; max return: 0.9910\n", "\n", "episode 49 in 9.23 sec\n", "mean reward: 0.8402\n", "return standard deviation: 0.1232\n", "min return: 0.5231; max return: 0.9986\n", "\n", "episode 50 in 9.14 sec\n", "mean reward: 0.8365\n", "return standard deviation: 0.1377\n", "min return: 0.4923; max return: 0.9972\n", "\n", "episode 51 in 8.36 sec\n", "mean reward: 0.8628\n", "return standard deviation: 0.1150\n", "min 
return: 0.5252; max return: 0.9988\n", "\n", "episode 52 in 8.26 sec\n", "mean reward: 0.8809\n", "return standard deviation: 0.0942\n", "min return: 0.5694; max return: 0.9916\n", "\n", "episode 53 in 8.20 sec\n", "mean reward: 0.8907\n", "return standard deviation: 0.0919\n", "min return: 0.5806; max return: 0.9984\n", "\n", "episode 54 in 8.24 sec\n", "mean reward: 0.8687\n", "return standard deviation: 0.1057\n", "min return: 0.5968; max return: 0.9998\n", "\n", "episode 55 in 8.21 sec\n", "mean reward: 0.8456\n", "return standard deviation: 0.1166\n", "min return: 0.5156; max return: 0.9999\n", "\n", "episode 56 in 8.20 sec\n", "mean reward: 0.8457\n", "return standard deviation: 0.1084\n", "min return: 0.5797; max return: 0.9954\n", "\n", "episode 57 in 8.18 sec\n", "mean reward: 0.8819\n", "return standard deviation: 0.0976\n", "min return: 0.5934; max return: 0.9997\n", "\n", "episode 58 in 8.20 sec\n", "mean reward: 0.8878\n", "return standard deviation: 0.0924\n", "min return: 0.6578; max return: 0.9996\n", "\n", "episode 59 in 8.20 sec\n", "mean reward: 0.8851\n", "return standard deviation: 0.1024\n", "min return: 0.6078; max return: 0.9981\n", "\n", "episode 60 in 8.32 sec\n", "mean reward: 0.8834\n", "return standard deviation: 0.1038\n", "min return: 0.4849; max return: 0.9997\n", "\n", "episode 61 in 8.21 sec\n", "mean reward: 0.8864\n", "return standard deviation: 0.0946\n", "min return: 0.5364; max return: 0.9995\n", "\n", "episode 62 in 8.19 sec\n", "mean reward: 0.8756\n", "return standard deviation: 0.0889\n", "min return: 0.6705; max return: 0.9987\n", "\n", "episode 63 in 8.20 sec\n", "mean reward: 0.8888\n", "return standard deviation: 0.1026\n", "min return: 0.6283; max return: 0.9999\n", "\n", "episode 64 in 8.20 sec\n", "mean reward: 0.9136\n", "return standard deviation: 0.0786\n", "min return: 0.5912; max return: 0.9992\n", "\n", "episode 65 in 8.36 sec\n", "mean reward: 0.8897\n", "return standard deviation: 0.0774\n", "min return: 
0.6997; max return: 0.9962\n", "\n", "episode 66 in 8.22 sec\n", "mean reward: 0.9129\n", "return standard deviation: 0.0653\n", "min return: 0.7374; max return: 0.9991\n", "\n", "episode 67 in 8.24 sec\n", "mean reward: 0.9091\n", "return standard deviation: 0.0766\n", "min return: 0.5948; max return: 0.9991\n", "\n", "episode 68 in 8.21 sec\n", "mean reward: 0.9218\n", "return standard deviation: 0.0653\n", "min return: 0.6746; max return: 0.9991\n", "\n", "episode 69 in 8.22 sec\n", "mean reward: 0.9284\n", "return standard deviation: 0.0715\n", "min return: 0.6303; max return: 0.9999\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "episode 70 in 8.22 sec\n", "mean reward: 0.9421\n", "return standard deviation: 0.0653\n", "min return: 0.6942; max return: 0.9998\n", "\n", "episode 71 in 8.25 sec\n", "mean reward: 0.9335\n", "return standard deviation: 0.0581\n", "min return: 0.7803; max return: 0.9989\n", "\n", "episode 72 in 8.29 sec\n", "mean reward: 0.9331\n", "return standard deviation: 0.0591\n", "min return: 0.7460; max return: 0.9996\n", "\n", "episode 73 in 8.22 sec\n", "mean reward: 0.9406\n", "return standard deviation: 0.0512\n", "min return: 0.7864; max return: 0.9997\n", "\n", "episode 74 in 8.28 sec\n", "mean reward: 0.9360\n", "return standard deviation: 0.0658\n", "min return: 0.6442; max return: 0.9969\n", "\n", "episode 75 in 8.21 sec\n", "mean reward: 0.9306\n", "return standard deviation: 0.1241\n", "min return: 0.0649; max return: 0.9997\n", "\n", "episode 76 in 8.20 sec\n", "mean reward: 0.9414\n", "return standard deviation: 0.0589\n", "min return: 0.7616; max return: 0.9992\n", "\n", "episode 77 in 8.17 sec\n", "mean reward: 0.9313\n", "return standard deviation: 0.0571\n", "min return: 0.7284; max return: 0.9997\n", "\n", "episode 78 in 8.16 sec\n", "mean reward: 0.9351\n", "return standard deviation: 0.0800\n", "min return: 0.4576; max return: 0.9999\n", "\n", "episode 79 in 8.20 sec\n", "mean reward: 0.9483\n", 
"return standard deviation: 0.0439\n", "min return: 0.8036; max return: 0.9996\n", "\n", "episode 80 in 8.20 sec\n", "mean reward: 0.9413\n", "return standard deviation: 0.0550\n", "min return: 0.6899; max return: 0.9982\n", "\n", "episode 81 in 8.19 sec\n", "mean reward: 0.9439\n", "return standard deviation: 0.0487\n", "min return: 0.7783; max return: 0.9989\n", "\n", "episode 82 in 8.17 sec\n", "mean reward: 0.9374\n", "return standard deviation: 0.0513\n", "min return: 0.8131; max return: 0.9997\n", "\n", "episode 83 in 8.16 sec\n", "mean reward: 0.9390\n", "return standard deviation: 0.0580\n", "min return: 0.7053; max return: 0.9979\n", "\n", "episode 84 in 8.15 sec\n", "mean reward: 0.9378\n", "return standard deviation: 0.0516\n", "min return: 0.7748; max return: 0.9993\n", "\n", "episode 85 in 8.20 sec\n", "mean reward: 0.9363\n", "return standard deviation: 0.0452\n", "min return: 0.7753; max return: 0.9994\n", "\n", "episode 86 in 8.13 sec\n", "mean reward: 0.9408\n", "return standard deviation: 0.0506\n", "min return: 0.7392; max return: 0.9998\n", "\n", "episode 87 in 8.17 sec\n", "mean reward: 0.9413\n", "return standard deviation: 0.0510\n", "min return: 0.8213; max return: 1.0000\n", "\n", "episode 88 in 8.14 sec\n", "mean reward: 0.9276\n", "return standard deviation: 0.0665\n", "min return: 0.7419; max return: 0.9997\n", "\n", "episode 89 in 8.17 sec\n", "mean reward: 0.9267\n", "return standard deviation: 0.0626\n", "min return: 0.7629; max return: 0.9995\n", "\n", "episode 90 in 8.16 sec\n", "mean reward: 0.9204\n", "return standard deviation: 0.1075\n", "min return: 0.2549; max return: 0.9995\n", "\n", "episode 91 in 8.29 sec\n", "mean reward: 0.9205\n", "return standard deviation: 0.0975\n", "min return: 0.3478; max return: 0.9961\n", "\n", "episode 92 in 8.14 sec\n", "mean reward: 0.9312\n", "return standard deviation: 0.0591\n", "min return: 0.7153; max return: 0.9989\n", "\n", "episode 93 in 8.14 sec\n", "mean reward: 0.9343\n", "return 
standard deviation: 0.0664\n", "min return: 0.7087; max return: 0.9991\n", "\n", "episode 94 in 8.22 sec\n", "mean reward: 0.9189\n", "return standard deviation: 0.1261\n", "min return: 0.2715; max return: 0.9995\n", "\n", "episode 95 in 8.18 sec\n", "mean reward: 0.8925\n", "return standard deviation: 0.1021\n", "min return: 0.5659; max return: 0.9997\n", "\n", "episode 96 in 8.13 sec\n", "mean reward: 0.8824\n", "return standard deviation: 0.1471\n", "min return: 0.3693; max return: 0.9996\n", "\n", "episode 97 in 8.18 sec\n", "mean reward: 0.8992\n", "return standard deviation: 0.1243\n", "min return: 0.2676; max return: 0.9989\n", "\n", "episode 98 in 8.14 sec\n", "mean reward: 0.8647\n", "return standard deviation: 0.1670\n", "min return: 0.2451; max return: 0.9995\n", "\n", "episode 99 in 8.14 sec\n", "mean reward: 0.9135\n", "return standard deviation: 0.1198\n", "min return: 0.4102; max return: 0.9997\n", "\n", "episode 100 in 8.15 sec\n", "mean reward: 0.8907\n", "return standard deviation: 0.1430\n", "min return: 0.3313; max return: 0.9998\n", "\n", "episode 101 in 8.16 sec\n", "mean reward: 0.9076\n", "return standard deviation: 0.1043\n", "min return: 0.4901; max return: 0.9993\n", "\n", "episode 102 in 8.15 sec\n", "mean reward: 0.9323\n", "return standard deviation: 0.0698\n", "min return: 0.6613; max return: 0.9984\n", "\n", "episode 103 in 8.17 sec\n", "mean reward: 0.9249\n", "return standard deviation: 0.0901\n", "min return: 0.3862; max return: 0.9999\n", "\n", "episode 104 in 8.13 sec\n", "mean reward: 0.9185\n", "return standard deviation: 0.0726\n", "min return: 0.7143; max return: 0.9975\n", "\n", "episode 105 in 8.14 sec\n", "mean reward: 0.9379\n", "return standard deviation: 0.1014\n", "min return: 0.2571; max return: 0.9998\n", "\n", "episode 106 in 8.16 sec\n", "mean reward: 0.9465\n", "return standard deviation: 0.0448\n", "min return: 0.8116; max return: 0.9991\n", "\n", "episode 107 in 8.14 sec\n", "mean reward: 0.9408\n", "return 
standard deviation: 0.0590\n", "min return: 0.7569; max return: 0.9990\n", "\n", "episode 108 in 8.17 sec\n", "mean reward: 0.9483\n", "return standard deviation: 0.0427\n", "min return: 0.8300; max return: 1.0000\n", "\n", "episode 109 in 8.21 sec\n", "mean reward: 0.9419\n", "return standard deviation: 0.0681\n", "min return: 0.4777; max return: 0.9979\n", "\n", "episode 110 in 8.19 sec\n", "mean reward: 0.9614\n", "return standard deviation: 0.0410\n", "min return: 0.8126; max return: 0.9999\n", "\n", "episode 111 in 8.16 sec\n", "mean reward: 0.9611\n", "return standard deviation: 0.0412\n", "min return: 0.8131; max return: 0.9991\n", "\n", "episode 112 in 8.18 sec\n", "mean reward: 0.9590\n", "return standard deviation: 0.0425\n", "min return: 0.8133; max return: 0.9999\n", "\n", "episode 113 in 8.19 sec\n", "mean reward: 0.9672\n", "return standard deviation: 0.0345\n", "min return: 0.8631; max return: 0.9998\n", "\n", "episode 114 in 8.17 sec\n", "mean reward: 0.9613\n", "return standard deviation: 0.0362\n", "min return: 0.8465; max return: 0.9993\n", "\n", "episode 115 in 8.16 sec\n", "mean reward: 0.9525\n", "return standard deviation: 0.0498\n", "min return: 0.8003; max return: 0.9998\n", "\n", "episode 116 in 8.17 sec\n", "mean reward: 0.9564\n", "return standard deviation: 0.0440\n", "min return: 0.7952; max return: 0.9996\n", "\n", "episode 117 in 8.17 sec\n", "mean reward: 0.9609\n", "return standard deviation: 0.0384\n", "min return: 0.8364; max return: 0.9998\n", "\n", "episode 118 in 8.15 sec\n", "mean reward: 0.9631\n", "return standard deviation: 0.0474\n", "min return: 0.7872; max return: 0.9998\n", "\n", "episode 119 in 8.15 sec\n", "mean reward: 0.9661\n", "return standard deviation: 0.0302\n", "min return: 0.8693; max return: 0.9993\n", "\n", "episode 120 in 8.23 sec\n", "mean reward: 0.9635\n", "return standard deviation: 0.0347\n", "min return: 0.8605; max return: 0.9999\n", "\n", "episode 121 in 8.16 sec\n", "mean reward: 0.9666\n", 
"return standard deviation: 0.0334\n", "min return: 0.8362; max return: 0.9990\n", "\n", "episode 122 in 8.16 sec\n", "mean reward: 0.9650\n", "return standard deviation: 0.0318\n", "min return: 0.8783; max return: 0.9997\n", "\n", "episode 123 in 8.17 sec\n", "mean reward: 0.9689\n", "return standard deviation: 0.0277\n", "min return: 0.8868; max return: 0.9998\n", "\n", "episode 124 in 8.21 sec\n", "mean reward: 0.9685\n", "return standard deviation: 0.0282\n", "min return: 0.8813; max return: 0.9999\n", "\n", "episode 125 in 8.19 sec\n", "mean reward: 0.9623\n", "return standard deviation: 0.0416\n", "min return: 0.8181; max return: 0.9999\n", "\n", "episode 126 in 8.16 sec\n", "mean reward: 0.9767\n", "return standard deviation: 0.0244\n", "min return: 0.8786; max return: 0.9985\n", "\n", "episode 127 in 8.18 sec\n", "mean reward: 0.9755\n", "return standard deviation: 0.0247\n", "min return: 0.8597; max return: 0.9997\n", "\n", "episode 128 in 8.15 sec\n", "mean reward: 0.9712\n", "return standard deviation: 0.0274\n", "min return: 0.8736; max return: 0.9999\n", "\n", "episode 129 in 8.16 sec\n", "mean reward: 0.9723\n", "return standard deviation: 0.0294\n", "min return: 0.8423; max return: 0.9996\n", "\n", "episode 130 in 8.13 sec\n", "mean reward: 0.9757\n", "return standard deviation: 0.0273\n", "min return: 0.8886; max return: 0.9996\n", "\n", "episode 131 in 8.17 sec\n", "mean reward: 0.9706\n", "return standard deviation: 0.0350\n", "min return: 0.8081; max return: 1.0000\n", "\n", "episode 132 in 8.14 sec\n", "mean reward: 0.9736\n", "return standard deviation: 0.0254\n", "min return: 0.8619; max return: 0.9998\n", "\n", "episode 133 in 8.17 sec\n", "mean reward: 0.9690\n", "return standard deviation: 0.0399\n", "min return: 0.8256; max return: 1.0000\n", "\n", "episode 134 in 8.17 sec\n", "mean reward: 0.9744\n", "return standard deviation: 0.0267\n", "min return: 0.8564; max return: 0.9996\n", "\n", "episode 135 in 8.13 sec\n", "mean reward: 
0.9696\n", "return standard deviation: 0.0338\n", "min return: 0.8445; max return: 0.9999\n", "\n", "episode 136 in 8.16 sec\n", "mean reward: 0.9718\n", "return standard deviation: 0.0385\n", "min return: 0.7722; max return: 0.9997\n", "\n", "episode 137 in 8.18 sec\n", "mean reward: 0.9750\n", "return standard deviation: 0.0268\n", "min return: 0.8717; max return: 0.9993\n", "\n", "episode 138 in 8.28 sec\n", "mean reward: 0.9754\n", "return standard deviation: 0.0258\n", "min return: 0.8755; max return: 0.9997\n", "\n", "episode 139 in 8.17 sec\n", "mean reward: 0.9771\n", "return standard deviation: 0.0261\n", "min return: 0.8484; max return: 0.9997\n", "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "episode 140 in 8.16 sec\n", "mean reward: 0.9729\n", "return standard deviation: 0.0268\n", "min return: 0.8683; max return: 1.0000\n", "\n", "episode 141 in 8.16 sec\n", "mean reward: 0.9796\n", "return standard deviation: 0.0192\n", "min return: 0.9099; max return: 0.9999\n", "\n", "episode 142 in 8.15 sec\n", "mean reward: 0.9816\n", "return standard deviation: 0.0174\n", "min return: 0.9369; max return: 1.0000\n", "\n", "episode 143 in 8.19 sec\n", "mean reward: 0.9739\n", "return standard deviation: 0.0353\n", "min return: 0.7853; max return: 1.0000\n", "\n", "episode 144 in 8.15 sec\n", "mean reward: 0.9788\n", "return standard deviation: 0.0255\n", "min return: 0.8590; max return: 0.9996\n", "\n", "episode 145 in 8.21 sec\n", "mean reward: 0.9688\n", "return standard deviation: 0.0327\n", "min return: 0.8534; max return: 0.9994\n", "\n", "episode 146 in 8.15 sec\n", "mean reward: 0.9747\n", "return standard deviation: 0.0378\n", "min return: 0.8046; max return: 0.9998\n", "\n", "episode 147 in 8.17 sec\n", "mean reward: 0.9778\n", "return standard deviation: 0.0311\n", "min return: 0.8614; max return: 0.9995\n", "\n", "episode 148 in 8.26 sec\n", "mean reward: 0.9813\n", "return standard deviation: 0.0194\n", "min return: 0.9101; max 
return: 0.9996\n", "\n", "episode 149 in 8.15 sec\n", "mean reward: 0.9755\n", "return standard deviation: 0.0321\n", "min return: 0.8205; max return: 0.9999\n", "\n", "episode 150 in 8.13 sec\n", "mean reward: 0.9775\n", "return standard deviation: 0.0290\n", "min return: 0.8489; max return: 0.9994\n", "\n", "episode 151 in 8.17 sec\n", "mean reward: 0.9790\n", "return standard deviation: 0.0251\n", "min return: 0.8608; max return: 0.9998\n", "\n", "episode 152 in 8.12 sec\n", "mean reward: 0.9820\n", "return standard deviation: 0.0201\n", "min return: 0.9171; max return: 1.0000\n", "\n", "episode 153 in 8.25 sec\n", "mean reward: 0.9851\n", "return standard deviation: 0.0158\n", "min return: 0.9317; max return: 0.9999\n", "\n", "episode 154 in 8.18 sec\n", "mean reward: 0.9770\n", "return standard deviation: 0.0280\n", "min return: 0.8563; max return: 0.9996\n", "\n", "episode 155 in 8.16 sec\n", "mean reward: 0.9808\n", "return standard deviation: 0.0204\n", "min return: 0.8870; max return: 0.9998\n", "\n", "episode 156 in 8.15 sec\n", "mean reward: 0.9844\n", "return standard deviation: 0.0164\n", "min return: 0.9389; max return: 1.0000\n", "\n", "episode 157 in 8.13 sec\n", "mean reward: 0.9798\n", "return standard deviation: 0.0264\n", "min return: 0.8693; max return: 0.9999\n", "\n", "episode 158 in 8.16 sec\n", "mean reward: 0.9783\n", "return standard deviation: 0.0234\n", "min return: 0.8940; max return: 0.9998\n", "\n", "episode 159 in 8.14 sec\n", "mean reward: 0.9814\n", "return standard deviation: 0.0238\n", "min return: 0.8897; max return: 0.9999\n", "\n", "episode 160 in 8.17 sec\n", "mean reward: 0.9800\n", "return standard deviation: 0.0224\n", "min return: 0.8786; max return: 0.9991\n", "\n", "episode 161 in 8.14 sec\n", "mean reward: 0.9790\n", "return standard deviation: 0.0211\n", "min return: 0.9156; max return: 0.9999\n", "\n", "episode 162 in 8.15 sec\n", "mean reward: 0.9778\n", "return standard deviation: 0.0257\n", "min return: 0.8528; 
max return: 0.9996\n", "\n", "episode 163 in 8.17 sec\n", "mean reward: 0.9633\n", "return standard deviation: 0.0779\n", "min return: 0.4841; max return: 0.9996\n", "\n", "episode 164 in 8.14 sec\n", "mean reward: 0.9678\n", "return standard deviation: 0.0782\n", "min return: 0.3886; max return: 1.0000\n", "\n", "episode 165 in 8.14 sec\n", "mean reward: 0.9769\n", "return standard deviation: 0.0449\n", "min return: 0.6764; max return: 0.9998\n", "\n", "episode 166 in 8.12 sec\n", "mean reward: 0.9745\n", "return standard deviation: 0.0553\n", "min return: 0.5742; max return: 1.0000\n", "\n", "episode 167 in 8.22 sec\n", "mean reward: 0.9651\n", "return standard deviation: 0.0811\n", "min return: 0.3754; max return: 0.9999\n", "\n", "episode 168 in 8.15 sec\n", "mean reward: 0.9653\n", "return standard deviation: 0.0790\n", "min return: 0.4658; max return: 0.9999\n", "\n", "episode 169 in 8.17 sec\n", "mean reward: 0.9793\n", "return standard deviation: 0.0238\n", "min return: 0.9017; max return: 0.9998\n", "\n", "episode 170 in 8.15 sec\n", "mean reward: 0.9607\n", "return standard deviation: 0.0882\n", "min return: 0.4643; max return: 0.9996\n", "\n", "episode 171 in 8.16 sec\n", "mean reward: 0.9719\n", "return standard deviation: 0.0503\n", "min return: 0.6354; max return: 0.9998\n", "\n", "episode 172 in 8.15 sec\n", "mean reward: 0.9753\n", "return standard deviation: 0.0244\n", "min return: 0.8730; max return: 0.9992\n", "\n", "episode 173 in 8.14 sec\n", "mean reward: 0.9751\n", "return standard deviation: 0.0289\n", "min return: 0.8815; max return: 0.9995\n", "\n", "episode 174 in 8.14 sec\n", "mean reward: 0.9746\n", "return standard deviation: 0.0725\n", "min return: 0.4192; max return: 0.9997\n", "\n", "episode 175 in 8.15 sec\n", "mean reward: 0.9752\n", "return standard deviation: 0.0336\n", "min return: 0.8464; max return: 0.9996\n", "\n", "episode 176 in 8.12 sec\n", "mean reward: 0.9769\n", "return standard deviation: 0.0291\n", "min return: 
0.8393; max return: 0.9997\n", "\n", "episode 177 in 8.14 sec\n", "mean reward: 0.9810\n", "return standard deviation: 0.0224\n", "min return: 0.8959; max return: 1.0000\n", "\n", "episode 178 in 8.15 sec\n", "mean reward: 0.9842\n", "return standard deviation: 0.0177\n", "min return: 0.9100; max return: 0.9998\n", "\n", "episode 179 in 8.15 sec\n", "mean reward: 0.9833\n", "return standard deviation: 0.0210\n", "min return: 0.9133; max return: 0.9999\n", "\n", "episode 180 in 8.16 sec\n", "mean reward: 0.9871\n", "return standard deviation: 0.0184\n", "min return: 0.8865; max return: 0.9997\n", "\n", "episode 181 in 8.16 sec\n", "mean reward: 0.9891\n", "return standard deviation: 0.0114\n", "min return: 0.9421; max return: 1.0000\n", "\n", "episode 182 in 8.25 sec\n", "mean reward: 0.9845\n", "return standard deviation: 0.0163\n", "min return: 0.9198; max return: 0.9998\n", "\n", "episode 183 in 8.20 sec\n", "mean reward: 0.9901\n", "return standard deviation: 0.0115\n", "min return: 0.9499; max return: 1.0000\n", "\n", "episode 184 in 8.15 sec\n", "mean reward: 0.9832\n", "return standard deviation: 0.0188\n", "min return: 0.9049; max return: 0.9997\n", "\n", "episode 185 in 8.17 sec\n", "mean reward: 0.9858\n", "return standard deviation: 0.0168\n", "min return: 0.9231; max return: 0.9999\n", "\n", "episode 186 in 8.16 sec\n", "mean reward: 0.9849\n", "return standard deviation: 0.0179\n", "min return: 0.9244; max return: 0.9995\n", "\n", "episode 187 in 8.16 sec\n", "mean reward: 0.9906\n", "return standard deviation: 0.0108\n", "min return: 0.9386; max return: 0.9996\n", "\n", "episode 188 in 8.15 sec\n", "mean reward: 0.9878\n", "return standard deviation: 0.0175\n", "min return: 0.8778; max return: 0.9998\n", "\n", "episode 189 in 8.16 sec\n", "mean reward: 0.9866\n", "return standard deviation: 0.0163\n", "min return: 0.9165; max return: 0.9997\n", "\n", "episode 190 in 8.15 sec\n", "mean reward: 0.9884\n", "return standard deviation: 0.0148\n", "min 
return: 0.9264; max return: 0.9999\n", "\n", "episode 191 in 8.16 sec\n", "mean reward: 0.9861\n", "return standard deviation: 0.0172\n", "min return: 0.9188; max return: 0.9999\n", "\n", "episode 192 in 8.15 sec\n", "mean reward: 0.9863\n", "return standard deviation: 0.0138\n", "min return: 0.9350; max return: 0.9998\n", "\n", "episode 193 in 8.14 sec\n", "mean reward: 0.9887\n", "return standard deviation: 0.0153\n", "min return: 0.9008; max return: 1.0000\n", "\n", "episode 194 in 8.15 sec\n", "mean reward: 0.9891\n", "return standard deviation: 0.0157\n", "min return: 0.9090; max return: 0.9998\n", "\n", "episode 195 in 8.16 sec\n", "mean reward: 0.9887\n", "return standard deviation: 0.0140\n", "min return: 0.9323; max return: 1.0000\n", "\n", "episode 196 in 8.17 sec\n", "mean reward: 0.9880\n", "return standard deviation: 0.0172\n", "min return: 0.9048; max return: 0.9999\n", "\n", "episode 197 in 8.24 sec\n", "mean reward: 0.9873\n", "return standard deviation: 0.0189\n", "min return: 0.8821; max return: 0.9997\n", "\n", "episode 198 in 8.17 sec\n", "mean reward: 0.9903\n", "return standard deviation: 0.0118\n", "min return: 0.9272; max return: 0.9997\n", "\n", "episode 199 in 8.17 sec\n", "mean reward: 0.9881\n", "return standard deviation: 0.0186\n", "min return: 0.8778; max return: 0.9999\n", "\n", "episode 200 in 8.15 sec\n", "mean reward: 0.9904\n", "return standard deviation: 0.0134\n", "min return: 0.9127; max return: 0.9998\n", "\n" ] } ], "source": [ "### Train model\n", "\n", "import time\n", "\n", "# define number of training episodes\n", "N_episodes = 201\n", "# define number of MC-sampled trajectories per episode (batch size)\n", "N_MC = 64 #128\n", "\n", "\n", "# preallocate data using arrays initialized with zeros\n", "\n", "state = np.zeros((2,), dtype=np.float32)\n", "\n", "states = np.zeros((N_MC, env.n_time_steps, 2), dtype=np.float32)\n", "actions = np.zeros((N_MC, env.n_time_steps), dtype=int)\n", "returns = np.zeros((N_MC, env.n_time_steps), dtype=np.float32)\n", "\n", "# mean reward at the end of the 
episode\n", "mean_final_reward = np.zeros(N_episodes, dtype=np.float32)\n", "# standard deviation of the reward at the end of the episode\n", "std_final_reward = np.zeros_like(mean_final_reward)\n", "# batch minimum at the end of the episode\n", "min_final_reward = np.zeros_like(mean_final_reward)\n", "# batch maximum at the end of the episode\n", "max_final_reward = np.zeros_like(mean_final_reward)\n", "\n", "\n", "print(\"\\nStarting training...\\n\")\n", "\n", "# set the initial model parameters in the optimizer\n", "opt_state = opt_init(inital_params)\n", "\n", "# loop over the number of training episodes\n", "for episode in range(N_episodes):\n", "\n", "    ### record time\n", "    start_time = time.time()\n", "\n", "    # get current policy network params\n", "    current_params = get_params(opt_state)\n", "\n", "    # MC sampling: generate N_MC trajectories with the current policy\n", "    for j in range(N_MC):\n", "\n", "        # reset environment to a random initial state\n", "        #env.reset(random=False) # fixed initial state\n", "        env.reset(random=True) # Haar-random initial state (i.e. 
uniformly sampled on the sphere)\n", "\n", "        # zero rewards array (auxiliary array to store the rewards and help compute the returns)\n", "        rewards = np.zeros((env.n_time_steps, ), dtype=np.float32)\n", "\n", "        # loop over steps in an episode\n", "        for time_step in range(env.n_time_steps):\n", "\n", "            # select state\n", "            state[:] = env.state[:]\n", "            states[j,time_step,:] = state\n", "\n", "            # select an action according to the current policy\n", "            pi_s = np.exp(predict(current_params, state))\n", "            action = np.random.choice(env.actions, p=pi_s)\n", "            actions[j,time_step] = action\n", "\n", "            # take an environment step\n", "            state[:], reward, _ = env.step(action)\n", "\n", "            # store reward\n", "            rewards[time_step] = reward\n", "\n", "        # compute reward-to-go: returns[j,t] = sum of rewards[t'] over t' >= t\n", "        returns[j,:] = np.cumsum(rewards[::-1])[::-1]\n", "\n", "    # define batch of data\n", "    trajectory_batch = (states, actions, returns)\n", "\n", "    # update model\n", "    opt_state = update(episode, opt_state, trajectory_batch)\n", "\n", "    ### record time needed for a single episode\n", "    episode_time = time.time() - start_time\n", "\n", "    # check performance: statistics of the final reward returns[:,-1] over the MC batch\n", "    mean_final_reward[episode] = np.mean(returns[:,-1])\n", "    std_final_reward[episode] = np.std(returns[:,-1])\n", "    min_final_reward[episode], max_final_reward[episode] = np.min(returns[:,-1]), np.max(returns[:,-1])\n", "\n", "    # print results after every episode\n", "    #if episode % 5 == 0:\n", "    print(\"episode {} in {:0.2f} sec\".format(episode, episode_time))\n", "    print(\"mean reward: {:0.4f}\".format(mean_final_reward[episode]))\n", "    print(\"return standard deviation: {:0.4f}\".format(std_final_reward[episode]))\n", "    print(\"min return: {:0.4f}; max return: {:0.4f}\\n\".format(min_final_reward[episode], max_final_reward[episode]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the training curves\n", "\n", "Plot the mean final reward at each episode, and its standard deviation. What do you observe?" 
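, "\n", "\n", "For reference: since `returns[j,:]` stores the reward-to-go of trajectory $j$, its last entry `returns[j,-1]` equals the final reward $r^j_T$. Writing $\\bar{r}_T$ and $\\sigma_T$ (notation introduced here) for the quantities stored in `mean_final_reward` and `std_final_reward`, the training loop computes the batch mean and the population standard deviation (default `ddof=0`) over the $N_\\mathrm{MC}$ sampled trajectories:\n", "\n", "$$\n", "\\bar{r}_T = \\frac{1}{N_\\mathrm{MC}} \\sum_{j=1}^{N_\\mathrm{MC}} r^j_T, \\qquad \\sigma_T = \\sqrt{\\frac{1}{N_\\mathrm{MC}} \\sum_{j=1}^{N_\\mathrm{MC}} \\left(r^j_T - \\bar{r}_T\\right)^2}.\n", "$$"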
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "