Actor-Critic (AC) Methods

The goal of this Notebook is to become familiar with actor-critic (AC) methods. We will do this by coding up the AC algorithm to solve the Cart Pole problem in RL.

Recall from class that AC methods represent an extension of Policy Gradient methods designed to lower the variance of the policy gradient estimate. Moreover, they provide a natural way to apply policy gradient learning on-line, i.e. perform policy updates before the episode has come to an end.

Basic Theory

We have seen that the policy gradient remains invariant (in expectation) if a baseline $b$ is subtracted from the reward-to-go:

$$ \nabla_\theta J(\theta) \approx \frac{1}{N}\sum_{\{\tau_j\}} \sum_{t=1}^T \nabla_\theta\log\pi_\theta(a^j_t|s^j_t) \left(\sum_{t'=t}^T \gamma^{t'-t} r(s^j_{t'},a^j_{t'}) - b \right) $$

where $\pi_\theta$ is the policy, parametrized by the unknown parameters $\theta$, $r(s,a)$ is the reward function, and $\sum_{\{\tau_j\}}$ is the sum over the $N$ sampled trajectories. In particular, this invariance also holds true when the baseline is state-dependent, i.e. $b=b(s)$.

The idea behind actor-critic methods is to introduce a second estimator, parametrized by $\varphi$, which estimates the expected return in state $s$ following the policy $\pi_\theta$. The expected return is known as the value function $V_\varphi(s)$. Note that the parameters $\varphi$ are, in general, independent of the parameters $\theta$ of the policy (although some parameters can be shared, if $\pi_\theta$ and $V_\varphi$ are believed to depend on common features of the states $s$).

Using a single-sample estimate for the expected return under the transition probability, we showed in class that the policy gradient can be re-written with the help of the approximate advantage function

$$ A^\pi(s_t,a_t) \approx r(s_t,a_t) + \gamma V^\pi(s_{t+1}) - V^\pi(s_t). $$

where $\gamma$ is the discount factor.
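
To recall where this single-sample estimate comes from: by definition $A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s)$ with $Q^\pi(s,a) = r(s,a) + \gamma\,\mathbb{E}_{s'\sim p(s'|s,a)}\left[V^\pi(s')\right]$, and replacing the expectation over the next state by the single observed sample $s_{t+1}$ gives

$$ A^\pi(s_t,a_t) = r(s_t,a_t) + \gamma\,\mathbb{E}_{s_{t+1}}\left[V^\pi(s_{t+1})\right] - V^\pi(s_t) \approx r(s_t,a_t) + \gamma V^\pi(s_{t+1}) - V^\pi(s_t). $$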

The actor-critic updates thus take the form: $$ \varphi \leftarrow \varphi - \beta\; \nabla_\varphi \frac{1}{2}\sum_{j=1}^N \sum_{t=1}^T ||V^\pi_\varphi(s^j_t) - y^j_t ||^2 \\ \theta \leftarrow \theta + \alpha\; \sum_{\{\tau_j\}} \sum_{t=1}^T \nabla_\theta\log\pi_\theta(a^j_t|s^j_t) A^\pi(s^j_t,a^j_t) $$

with the step sizes $\alpha,\beta\in[0,1]$. We discussed two possible estimates for the critic target $y^j_t$ (a short code sketch of both follows the list):

  1. MC estimate: $y^j_t = \sum_{t'=t}^T r(s^j_{t'},a^j_{t'})$
  2. Bootstrap/Temporal Difference (TD) estimate: $y^j_t = r(s^j_t,a^j_t) + \gamma V^\pi_\varphi(s^j_{t+1})$
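
As a concrete illustration, here is a minimal sketch of how the two targets can be computed from a recorded trajectory. The function and array names (`mc_targets`, `td_targets`, `rewards`, `values`) are our own choices rather than part of the notebook scaffolding, and the value of the state after termination is taken to be zero; with `gamma=1` the MC target reduces to the undiscounted sum in item 1.

```python
import numpy as np

def mc_targets(rewards, gamma=0.99):
    """MC estimate: discounted reward-to-go y_t = sum_{t'>=t} gamma^(t'-t) r_t'."""
    y = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        y[t] = running
    return y

def td_targets(rewards, values, gamma=0.99):
    """Bootstrap/TD estimate: y_t = r_t + gamma * V(s_{t+1}), with V of the terminal state set to 0."""
    next_values = np.append(np.asarray(values)[1:], 0.0)
    return np.asarray(rewards) + gamma * next_values
```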

Actor-Critic Algorithms

In class, we derived two AC algorithms, which we now recap.

Offline Actor-Critic Algorithm

The offline AC algorithm, also known as Policy Gradient with Value Function Estimation, can be defined using either the MC or the Bootstrap/TD estimate. The pseudocode reads as follows.

  1. Sample $\{s^j,a^j\}$ from $\pi_\theta$, running each trajectory until the end of its episode (this is what makes the algorithm offline).

  2. Fit the value function $V^\pi_\varphi(s)$ to the sampled data, using the mean-square loss $\mathcal{L}_\mathrm{critic}(\varphi)$ and either the MC or the Bootstrap/TD estimate:

$$ \mathcal{L}_\mathrm{critic}(\varphi) = \frac{1}{2}\sum_{j=1}^N \sum_{t=1}^T ||V^\pi_\varphi(s^j_t) - y^j_t ||^2. $$
  3. Evaluate the advantage function on the sample:
$$ A^\pi(s^j_t,a^j_t) \approx r(s^j_t,a^j_t) + \gamma V^\pi(s^j_{t+1}) - V^\pi(s^j_t). $$
  4. Compute the policy gradient on the sample:
$$ \nabla_\theta J(\theta) \approx \frac{1}{N}\sum_{j=1}^N \sum_{t=1}^T \nabla_\theta\log\pi_\theta(a^j_t|s^j_t) A^\pi(s^j_t,a^j_t). $$
  5. Update the policy:
$$ \theta \leftarrow \theta + \alpha\; \nabla_\theta J(\theta). $$

Online Actor-Critic Algorithm

The online AC algorithm can be defined using only the Bootstrap/TD estimate. The pseudocode reads as follows.

  1. Take action $a\sim\pi_\theta(a|s)$ following policy $\pi_\theta$, and obtain the transition $(s,a,r,s')$.

  2. Update $V^\pi_\varphi(s)$ using the Bootstrap/TD target $y(s) = r(s,a) + \gamma V^\pi_\varphi(s')$, and the cost function $\mathcal{L}_\mathrm{critic}(\varphi)$:

$$ \mathcal{L}_\mathrm{critic}(\varphi) = \frac{1}{2} ||V^\pi_\varphi(s) - y(s)||^2. $$
  3. Compute the advantage function for the transition:
$$ A^\pi(s,a) \approx r(s,a) + \gamma V^\pi(s') - V^\pi(s). $$
  4. Compute the policy gradient for the transition (no sums over trajectories and time steps; this is what makes the algorithm online):
$$ \nabla_\theta J(\theta) \approx \nabla_\theta\log\pi_\theta(a|s) A^\pi(s,a). $$
  5. Update the policy:
$$ \theta \leftarrow \theta + \alpha\; \nabla_\theta J(\theta). $$

Cart Pole Environment

We will apply AC methods to the Cart Pole problem, which defines a discounted, non-episodic task.

A pole is attached by an un-actuated joint to a cart, which moves along a frictionless track. The system is controlled by applying a force of +1 or -1 to the cart. The pendulum starts upright, and the goal is to prevent it from falling over. A reward of +1 is provided for every timestep that the pole remains upright. The episode ends when the pole is more than 15 degrees from vertical, or the cart moves more than 2.4 units from the center.

Let us instantiate and visualize the Cart Pole environment.

State (or Observation) and Action spaces for the Cartpole problem

State: 
    Type: Box(4)
    Num State                    Min            Max
    0   Cart Position             -4.8            4.8
    1   Cart Velocity             -Inf            Inf
    2   Pole Angle                 -24°           24°
    3   Pole Velocity At Tip      -Inf            Inf

Action:
    Type: Discrete(2)
    Num Action
    0   Push cart to the left
    1   Push cart to the right

Rewards

As noted above, the reward is +1 at every timestep that the pole remains upright. Since our goal is to find a policy which prevents the pendulum from tipping over, we are presented with a non-episodic task. Therefore, we need an extra condition to define when the task is considered solved.

We consider the task solved if the total return within an episode, running-averaged over previous episodes, exceeds a certain cutoff (see the variable return_solved). The running average is defined by the formula:

running_return = 0.05 * episode_return + (1 - 0.05) * running_return

Additionally, we impose a large cutoff on the maximal number of steps per episode, see the variable max_steps_per_episode above.

Actor-Critic Network

Since the state space is continuous, we use a deep neural network as a function approximator. We will learn from physical quantities (such as positions, velocities, and angles) rather than images, so we focus on an architecture consisting of fully-connected layers.

In order to enable the value function $V_\varphi$ and policy $\pi_\theta$ networks to share common features, we adopt the following architecture, discussed in class:

  1. One common base layer with parameters shared by both $V_\varphi$ and $\pi_\theta$, followed by
  2. Two independent head layers, consisting of a $V_\varphi$ head and a $\pi_\theta$ head, which do not share parameters.

Thus, whenever the value function is updated, gradients are pushed through the $V_\varphi$-head and the common layer. Similarly, a policy update changes the $\pi_\theta$-head and the common layer.

The output of the neural network should be a list: the zeroth entry of the list contains the log-probabilities for the policy, and the first entry contains the value function estimate.

This architecture can be implemented in JAX using the stax.serial and stax.parallel combinators. stax.serial stacks neural and activation layers on top of each other; stax.parallel puts layers next to each other. To split the pipeline into parallel heads, we use the stax.FanOut layer (this plays a similar plumbing role to stax.Flatten, which we used to flatten the output of convolutional layers so it can be fed into a fully-connected layer).

To construct the network, note that the base and the heads appear in series, because the base is shared. The base layer should have 128 neurons, followed by a ReLU activation function. The head block itself contains the two heads in parallel. While the $V_\varphi$-head has a single number as an output and does not contain any activation functions, the $\pi_\theta$-head has as many outputs as there are actions to take, followed by the LogSoftmax activation; thus, to build the $\pi_\theta$-head, one has to stack in series a Dense layer followed by the LogSoftmax activation.

  1. Construct the deep neural network, and test it on a sample dataset (a minimal sketch is given after this list).
  2. Make sure you understand the output of the network, including the meaning of the shapes/sizes of the output.
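
Below is a minimal sketch of such a network, assuming the stax combinators live in jax.example_libraries.stax (in older JAX versions the import path is jax.experimental.stax); the variable names are our own choices.

```python
import jax.numpy as jnp
from jax import random
from jax.example_libraries import stax  # older JAX: from jax.experimental import stax

n_states = 4    # cart position, cart velocity, pole angle, pole tip velocity
n_actions = 2   # push cart to the left / to the right

init_fun, apply_fun = stax.serial(
    stax.Dense(128), stax.Relu,            # common base layer, shared by actor and critic
    stax.FanOut(2),                        # split the pipeline into two parallel branches
    stax.parallel(
        stax.serial(stax.Dense(n_actions), stax.LogSoftmax),  # pi_theta head: log-probabilities
        stax.Dense(1),                                        # V_phi head: scalar value estimate
    ),
)

# test the network on a batch of fake states
rng = random.PRNGKey(0)
output_shape, net_params = init_fun(rng, (-1, n_states))
fake_states = jnp.ones((5, n_states))
log_probs, values = apply_fun(net_params, fake_states)
print(log_probs.shape, values.shape)       # (5, 2) and (5, 1)
```

The zeroth output is the policy head (log-probabilities) and the first output is the value head, matching the ordering described above.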

(Pseudo-) Loss Function

Let us denote the parameters of the common base layer by $\eta$, those of the policy head by $\theta$, and those of the value function head by $\varphi$.

To appreciate the variance reduction offered by AC algorithms, we will implement the offline AC method using a single trajectory to estimate the network gradients. Because the trajectory length can vary (non-episodic task), we use an average over the timesteps within the trajectory.

For the critic loss $\mathcal{L}_\mathrm{critic}(\eta,\varphi)$, we use discounted MC estimates $y_t = \sum_{t'=t}^T \gamma^{t'-t} r(s_{t'},a_{t'})$, and a Huber loss to cut off excessively large gradients:

$$ \mathcal{L}_\mathrm{critic}(\eta,\varphi) = \frac{1}{T}\sum_{t=1}^T \mathrm{Huber}\left(V^\pi_{\eta,\varphi}(s_t), y_t\right). $$

Further, the policy pseudo-loss function is given by (note the overall minus sign: the optimizer descends the loss, while we want to ascend $J(\theta)$)

$$ \mathcal{L}_\mathrm{actor}(\eta,\theta) = - \frac{1}{T}\sum_{t=1}^T \log\pi_{\eta,\theta}(a_t|s_t) \left(\sum_{t'=t}^T \gamma^{t'-t} r(s_{t'},a_{t'}) - V^\pi_{\eta,\varphi}(s_t) \right). $$

Note that no gradient should be pushed through the critic, $V^\pi_{\eta,\varphi}$, here, and hence $\mathcal{L}_\mathrm{actor}$ is not considered a function of $\varphi$.

Finally, we also use an L2 regularizer on all network parameters

$$ L^2_\mathrm{reg}(\eta,\theta,\varphi) = \lambda\left(\sum_{l} ||\eta_l||^2 + \sum_{m} ||\theta_m||^2 + \sum_{n} ||\varphi_n||^2\right), $$

with $\lambda=0.001$ the regularization strength.

For simplicity, we perform steps 2, 4, and 5 of the offline AC algorithm together. This is enabled by JAX, which can push the gradients through the parameters $(\eta,\theta,\varphi)$ at once, using the total cost function

$$ \mathcal{L}_\mathrm{AC}(\eta,\theta,\varphi) = \mathcal{L}_\mathrm{actor}(\eta,\theta) + \mathcal{L}_\mathrm{critic}(\eta,\varphi) + L^2_\mathrm{reg}(\eta,\theta,\varphi; \lambda). $$

Let us implement the above instructions:

  1. Define the Huber Loss function

  2. Define the L2 Regularizer

  3. Define the total Actor-Critic loss function for a single trajectory. The function body should contain the calculation of the different loss contributions (a sketch of all three pieces follows below).
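
Here is a minimal sketch of the three pieces, under the assumption that apply_fun is the network from the sketch above and that a trajectory is passed in as arrays of states, actions, and discounted MC returns; the function names (huber_loss, l2_regularizer, ac_loss) are our own choices.

```python
import jax
import jax.numpy as jnp
from jax import lax

def huber_loss(predictions, targets, delta=1.0):
    """Quadratic for small errors, linear for large ones; caps excessively large gradients."""
    abs_err = jnp.abs(predictions - targets)
    quadratic = jnp.minimum(abs_err, delta)
    linear = abs_err - quadratic
    return 0.5 * quadratic ** 2 + delta * linear

def l2_regularizer(params, lmbda=0.001):
    """lambda times the sum of squared entries of every weight and bias array in the network."""
    return lmbda * sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(params))

def ac_loss(params, states, actions, returns):
    """Total AC loss on a single trajectory: actor pseudo-loss + critic Huber loss + L2 term."""
    log_probs, values = apply_fun(params, states)       # shapes (T, n_actions) and (T, 1)
    values = values.squeeze(-1)
    # advantage estimate: discounted MC return minus the baseline V(s_t);
    # stop_gradient ensures no gradient is pushed through the critic in the actor term
    advantages = returns - lax.stop_gradient(values)
    log_pi_taken = log_probs[jnp.arange(actions.shape[0]), actions]
    actor_loss = -jnp.mean(log_pi_taken * advantages)
    critic_loss = jnp.mean(huber_loss(values, returns))
    return actor_loss + critic_loss + l2_regularizer(params)
```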

Define generalized gradient descent optimizer

Define the optimizer and the update function which computes the gradient of the pseudo-loss function and performs the update.

We use the Adam optimizer here with step_size = 0.01, leaving the remaining parameters at their default values. Since both the actor and the critic are encoded using the same network, we can use a single step size.
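
The following sketch uses the Adam implementation shipped with JAX (jax.example_libraries.optimizers, or jax.experimental.optimizers in older versions) together with the ac_loss sketch above; the name update is our own choice.

```python
from jax import jit, value_and_grad
from jax.example_libraries import optimizers  # older JAX: from jax.experimental import optimizers

step_size = 0.01
opt_init, opt_update, get_params = optimizers.adam(step_size)  # remaining Adam parameters at defaults

@jit
def update(step, opt_state, states, actions, returns):
    """One gradient step on the total AC pseudo-loss for a single trajectory."""
    params = get_params(opt_state)
    loss, grads = value_and_grad(ac_loss)(params, states, actions, returns)
    return loss, opt_update(step, grads, opt_state)

# initialization: opt_state = opt_init(net_params)
```

Note that jit retraces update for every new trajectory length; removing the decorator or padding trajectories to a fixed length avoids these recompilations.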

Offline Actor-Critic Algorithm

Finally, write down the offline AC algorithm.

Recall that we want to use single-trajectory estimates of the neural network gradients.

Moreover, keep in mind that trajectories do not have a fixed length here, so consider using lists instead of arrays.
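
For orientation, here is a compact sketch of the loop. It assumes the network, loss, optimizer, and mc_targets sketches above, the standard OpenAI Gym CartPole-v1 environment with the older gym API (reset returning only the observation, step returning a 4-tuple), and hypothetical values for return_solved and max_steps_per_episode, which the notebook defines elsewhere.

```python
import gym
import numpy as np
import jax.numpy as jnp
from jax import random

env = gym.make("CartPole-v1")
gamma = 0.99
return_solved = 195.0            # hypothetical cutoff; use the notebook's value
max_steps_per_episode = 10000    # hypothetical cutoff; use the notebook's value

rng = random.PRNGKey(0)
rng, init_rng = random.split(rng)
_, net_params = init_fun(init_rng, (-1, n_states))
opt_state = opt_init(net_params)

running_return = 0.0
for episode in range(2000):
    # step 1: sample a single trajectory until the episode ends (lists, since lengths vary)
    states, actions, rewards = [], [], []
    state = env.reset()
    for _ in range(max_steps_per_episode):
        rng, key = random.split(rng)
        log_probs, _ = apply_fun(get_params(opt_state), state[None, :])
        action = int(random.categorical(key, log_probs[0]))
        next_state, reward, done, _ = env.step(action)
        states.append(state)
        actions.append(action)
        rewards.append(float(reward))
        state = next_state
        if done:
            break

    # steps 2-5: fit the critic, evaluate the advantage, and update the policy in one gradient step
    returns = jnp.asarray(mc_targets(rewards, gamma))
    loss, opt_state = update(episode, opt_state,
                             jnp.asarray(np.array(states)), jnp.asarray(actions), returns)

    running_return = 0.05 * sum(rewards) + (1 - 0.05) * running_return
    if running_return > return_solved:
        print(f"solved after {episode + 1} episodes, running return {running_return:.1f}")
        break
```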

Questions

  1. Plot the training curve: running return vs episode number.

  2. Check the learned policy: does it make sense physically?

  3. Modify the network architecture to use two completely independent networks for the policy and the value function. Note that this allows us to use two optimizers, i.e. two independent learning rates. Compare the performance.

  4. Modify the code to implement the online AC algorithm.

  5. Try solving the Cart Pole problem using the raw images as states, instead of the physical quantities.