In this notebook, our goal is to implement the REINFORCE algorithm for policy gradient using JAX. We will apply this RL algorithm to control a single quantum bit of information (qubit).
The reinforcement learning objective $J$ is the expected total return, following the policy $\pi$. If the transition probability is denoted by $p(s'|s,a)$, and the initial state distribution is $p(s_0)$, the probability for a trajectory $\tau = (s_0,a_0,r_1,s_1,a_1,\dots,s_{T-1},a_{T-1},r_T,s_T)$ to occur can be written as
$$ P_\pi(\tau) = p(s_0)\prod_{t=0}^{T-1} \pi(a_t|s_t)\,p(s_{t+1}|s_t,a_t). $$The RL objective then takes the form
$$ J = \mathrm{E}_{\tau\sim P_\pi} \left[ G(\tau) \,|\, S_{t=0}=s_0 \right],\quad G(\tau)=\sum_{t=1}^T r(s_t,a_t). $$Policy gradient methods in RL directly approximate the policy, $\pi\approx\pi_\theta$, using a variational ansatz parametrized by the unknown parameters $\theta$. The goal is then to find the optimal parameters $\theta$ which maximize the RL objective $J(\theta)$. To define an update rule for $\theta$, we can use gradient ascent, which requires us to evaluate the gradient of the RL objective w.r.t. the parameters $\theta$:
$$ \nabla_\theta J(\theta) = \nabla_\theta \mathrm{E}_{\tau\sim P_\pi} \left[ \sum_{t=1}^T r(s_t,a_t) \,|\, S_{t=0}=s_0 \right] = \int\mathrm{d}\tau\, \nabla_\theta P_{\pi_\theta}(\tau)\, G(\tau). $$In a model-free setting, we do not have access to the transition probabilities $p(s'|s,a)$, so we must be able to estimate the gradient from samples. This can be accomplished by noticing that $\nabla_\theta P_{\pi_\theta} = P_{\pi_\theta} \nabla_\theta \log P_{\pi_\theta}$ (almost everywhere, i.e. up to a set of measure zero):
$$ \nabla_\theta J(\theta) = \int\mathrm{d}\tau\, \nabla_\theta P_{\pi_\theta}(\tau)\, G(\tau) = \int\mathrm{d}\tau\, P_{\pi_\theta}(\tau)\, \nabla_\theta \log P_{\pi_\theta}(\tau)\, G(\tau) = \mathrm{E}_{\tau\sim P_\pi} \left[\nabla_\theta \log P_{\pi_\theta}(\tau)\, G(\tau)\right]. $$Since the initial state distribution and the transition probabilities are independent of $\theta$, using the definition of $P_{\pi_\theta}$, we see that $\nabla_\theta \log P_{\pi_\theta}(\tau) = \nabla_\theta \log \pi_\theta(\tau)$, where $\pi_\theta(\tau) = \prod_{t=0}^{T-1} \pi_\theta(a_t|s_t)$.
We can now use Monte Carlo (MC) to estimate the gradient directly from a sample of trajectories $\{\tau_j\}_{j=1}^N$: $$ \nabla_\theta J(\theta) = \mathrm{E}_{\tau\sim P_\pi} \left[\nabla_\theta \log P_{\pi_\theta}(\tau)\, G(\tau)\right] \approx \frac{1}{N}\sum_{j=1}^N \nabla_\theta \log \pi_\theta(\tau_j)\, G(\tau_j). $$
In class, we discussed problems that arise due to the large variance of the gradient estimate. In particular, we showed that one can use causality and a baseline to reduce the variance. The PG update then takes the form
$$ \nabla_\theta J(\theta) \approx \frac{1}{N}\sum_{j=1}^N \sum_{t=1}^T \nabla_\theta \log \pi_\theta(a^j_t|s^j_t) \left[\sum_{t'=t}^T r(s^j_{t'},a^j_{t'}) - b\right]. $$The corresponding gradient ascent update rule reads
$$ \theta \leftarrow \theta + \alpha \nabla_\theta J(\theta), $$for some step size (or learning rate) $\alpha$.
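As a sanity check of the log-derivative trick, here is a minimal toy sketch, unrelated to the qubit problem: a categorical "policy" over three actions with a made-up reward per action, where the exact policy gradient can be compared against the REINFORCE (score-function) estimate. All numbers below are illustrative.
import jax.numpy as jnp
from jax import grad, vmap, random
# toy categorical policy over 3 actions, log-softmax parametrization, and a fixed made-up reward per action
r_a = jnp.array([1.0, 2.0, 0.5])
log_pi = lambda theta: theta - jnp.log(jnp.sum(jnp.exp(theta)))
# exact gradient of J(theta) = E_{a ~ pi_theta}[ r(a) ] = sum_a pi_theta(a) r(a)
exact_grad = grad(lambda theta: jnp.sum(jnp.exp(log_pi(theta)) * r_a))
# REINFORCE estimate: (1/N) sum_j grad log pi_theta(a_j) r(a_j), with a_j ~ pi_theta
def reinforce_grad(theta, key, N=20000):
    a = random.categorical(key, log_pi(theta), shape=(N,))                # sample N actions
    score = vmap(grad(lambda th, aj: log_pi(th)[aj]), in_axes=(None, 0))  # grad of log pi(a_j|theta)
    return jnp.mean(score(theta, a) * r_a[a][:, None], axis=0)
theta = jnp.array([0.1, -0.3, 0.2])
print(exact_grad(theta))                          # exact policy gradient
print(reinforce_grad(theta, random.PRNGKey(0)))   # should agree up to MC noise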
Let us recall the qubit environment we defined in Notebook 2.
The state of a qubit $|\psi\rangle\in\mathbb{C}^2$ is modeled by a two-dimensional complex-valued vector with unit norm: $\langle\psi|\psi\rangle:=|\psi_1|^2+|\psi_2|^2=1$. Every qubit state is uniquely described by two angles $\theta\in[0,\pi]$ and $\varphi\in[0,2\pi)$:
\begin{eqnarray} |\psi\rangle= \begin{pmatrix} \psi_1 \\ \psi_2 \end{pmatrix}= \mathrm{e}^{i\alpha} \begin{pmatrix} \cos\frac{\theta}{2} \\ \mathrm{e}^{i\varphi}\sin\frac{\theta}{2} \end{pmatrix} \end{eqnarray}The overall phase $\alpha$ of a single quantum state has no physical meaning. Thus, any qubit state can be pictured as an arrow on the unit sphere (called the Bloch sphere) with coordinates $(\theta,\varphi)$.
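To make the parametrization concrete, here is a small standalone helper (not part of the environment, purely an illustration) that strips the global phase and extracts the Bloch angles $(\theta,\varphi)$ from a normalized state vector.
import numpy as np
def bloch_angles(psi):
    """Return the Bloch-sphere angles (theta, phi) of a normalized qubit state psi = (psi_1, psi_2)."""
    psi = np.asarray(psi, dtype=np.complex128)
    psi = psi * np.exp(-1j * np.angle(psi[0]))   # remove the global phase alpha
    theta = 2.0 * np.arccos(np.clip(np.abs(psi[0]), 0.0, 1.0))
    phi = np.angle(psi[1]) % (2.0 * np.pi) if np.abs(psi[1]) > 1e-12 else 0.0
    return theta, phi
# example: |psi> = (1, i)/sqrt(2) lies on the equator: theta = pi/2, phi = pi/2
print(bloch_angles([1/np.sqrt(2), 1j/np.sqrt(2)]))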
To operate on qubits, we use quantum gates. Quantum gates are represented as unitary transformations $U\in \mathrm{U(2)}$, where $\mathrm{U(2)}$ is the unitary group. Gates act on qubit states by matrix multiplication to transform an input state $|\psi\rangle$ to the output state $|\psi'\rangle$: $|\psi'\rangle=U|\psi\rangle$. For this problem, we consider four gates
\begin{equation} U_0=\boldsymbol{1},\qquad U_x=\mathrm{exp}(-i\delta t \sigma^x/2),\qquad U_y=\mathrm{exp}(-i\delta t \sigma^y/2),\qquad U_z=\mathrm{exp}(-i\delta t \sigma^z/2), \end{equation}where $\delta t$ is a fixed time step, $\mathrm{exp}(\cdot)$ is the matrix exponential, $\boldsymbol{1}$ is the identity, and the Pauli matrices are defined as
\begin{equation} \boldsymbol{1}=\begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix} ,\qquad \sigma^x=\begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix} ,\qquad \sigma^y=\begin{pmatrix} 0 & -i \\ i & 0 \end{pmatrix} ,\qquad \sigma^z=\begin{pmatrix} 1 & 0 \\ 0 & -1 \end{pmatrix} \end{equation}To determine if a qubit, described by the state $|\psi\rangle$, is in a desired target state $|\psi_\mathrm{target}\rangle$, we compute the fidelity
\begin{eqnarray} F=|\langle\psi_\mathrm{target}|\psi\rangle|^2 = |(\psi_\mathrm{target})^\ast_1 \psi_1 + (\psi_\mathrm{target})^\ast_2 \psi_2|^2,\qquad F\in[0,1] \end{eqnarray}where $\ast$ stands for complex conjugation. Physically, the fidelity is a monotonically decreasing function of the angle $\gamma$ between the arrows representing the two states on the Bloch sphere, $F=\cos^2(\gamma/2)$: maximizing the fidelity corresponds to minimizing that angle.
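To fix ideas, the following standalone snippet (independent of QubitEnv2; the value of $\delta t$ below is arbitrary, the environment fixes its own) builds the four gates via the matrix exponential and evaluates the fidelity.
import numpy as np
from scipy.linalg import expm  # matrix exponential
dt = 2 * np.pi / 60   # illustrative time step only
# Pauli matrices and the four gates U_0, U_x, U_y, U_z
s_x = np.array([[0, 1], [1, 0]], dtype=np.complex128)
s_y = np.array([[0, -1j], [1j, 0]], dtype=np.complex128)
s_z = np.array([[1, 0], [0, -1]], dtype=np.complex128)
gates = [np.eye(2, dtype=np.complex128)] + [expm(-0.5j * dt * s) for s in (s_x, s_y, s_z)]
def fidelity(psi_target, psi):
    """F = |<psi_target|psi>|^2 (np.vdot conjugates its first argument)."""
    return np.abs(np.vdot(psi_target, psi)) ** 2
psi = np.array([0.0, 1.0], dtype=np.complex128)          # start in |1>
psi_target = np.array([1.0, 0.0], dtype=np.complex128)   # target |0>
print(fidelity(psi_target, gates[1] @ psi))              # one U_x step increases F only slightly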
Now, let us define an episodic RL environment, which contains the laws of physics that govern the dynamics of the qubit (i.e. the application of the gate operations to the qubit state). Our RL agent will later interact with this environment to learn how to control the qubit to bring it from an initial state to a prescribed target state.
We define the RL states $s=(\theta,\varphi)$ as an array containing the Bloch sphere angles of the quantum state. At each step within an episode, the agent chooses one out of four actions, corresponding to the four gates $(\boldsymbol{1},U_x,U_y,U_z)$. We use the instantaneous fidelity w.r.t. the target state as a reward: $r_t=F=|\langle\psi_\mathrm{target}|\psi(t)\rangle|^2$:
state space: $\mathcal{S} = \{(\theta,\varphi)\,|\,\theta\in[0,\pi],\varphi\in[0,2\pi)\}$. Unlike in Notebook 2, there are no terminal states here. Instead, we consider a fixed number of time steps, after which the episode terminates deterministically. The target state (i.e. the qubit state we want to prepare) is $|\psi_\mathrm{target}\rangle=(1,0)^T$; it has the Bloch sphere coordinates $s_\mathrm{target}=(0,0)$.
action space: $\mathcal{A} = \{\boldsymbol{1},U_x,U_y,U_z\}$. Actions act on RL states as follows: applying the gate $U_a$ to the quantum state $|\psi(s)\rangle$ produces the new quantum state $|\psi(s')\rangle=U_a|\psi(s)\rangle$, whose Bloch sphere angles define the next RL state $s'$.
reward space: $\mathcal{R}=[0,1]$. We use the fidelity between the next state $s'$ and the target state $s_\mathrm{target}$ as the reward at every episode step:
$$r(s,s',a)= F = |\langle\psi_\mathrm{target}|U_a|\psi(s)\rangle|^2=|\langle\psi_\mathrm{target}|\psi(s')\rangle|^2$$for all states $s,s'\in\mathcal{S}$ and actions $a\in\mathcal{A}$.
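The actual environment class QubitEnv2 is defined in Notebook 2 and imported below. Purely to understand the data flow, here is a heavily simplified stand-in that follows the same state/action/reward conventions and reuses the gates, bloch_angles and fidelity helpers from the illustration cells above; the attribute and method names mirror how the class is used later in this notebook, but the real implementation may differ in its details.
import numpy as np
class ToyQubitEnv:
    """Illustrative stand-in for QubitEnv2 (the real class lives in Notebook 2)."""
    def __init__(self, n_time_steps, seed=0):
        self.n_time_steps = n_time_steps
        self.n_actions = 4
        self.actions = np.arange(4)
        self.rng = np.random.default_rng(seed)
        self.psi_target = np.array([1.0, 0.0], dtype=np.complex128)
        self.reset(random=True)
    def reset(self, random=True):
        if random:  # Haar-random pure state (uniform on the Bloch sphere)
            psi = self.rng.normal(size=2) + 1j * self.rng.normal(size=2)
            self.psi = psi / np.linalg.norm(psi)
        else:       # fixed initial state |1>
            self.psi = np.array([0.0, 1.0], dtype=np.complex128)
        self.state = np.array(bloch_angles(self.psi), dtype=np.float32)
        return self.state
    def step(self, action):
        self.psi = gates[action] @ self.psi                   # apply the chosen gate
        self.state = np.array(bloch_angles(self.psi), dtype=np.float32)
        reward = fidelity(self.psi_target, self.psi)          # instantaneous fidelity reward
        return self.state, reward, False                      # third entry plays the role of a done flag (unused below)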
import numpy as np
import import_ipynb
from Notebook_2_RL_environments import QubitEnv2 # import environment, notebooks must be in same directory
# number of time steps per episode
n_time_steps = 60
# set seed of rng (for reproducibility of the results)
seed=0
np.random.seed(seed)
# create environment class
env=QubitEnv2(n_time_steps, seed=seed)
The implementation of PG follows steps similar to the MNIST problem from Notebook 7:
Use JAX to construct a feed-forward fully-connected deep neural network with neuron architecture $(M_s, 512, 256, |\mathcal{A}|)$, where there are $512$ ($256$) neurons in the first (second) hidden layer, respectively, and $M_s$ and $|\mathcal{A}|$ define the input and output sizes.
The input data into the neural network should have the shape input_shape = (-1, n_time_steps, M_s), where M_s is the number of features/components in the RL state $s=(\theta,\varphi)$. The output data should have the shape output_shape = (-1, n_time_steps, abs_A), where abs_A $=|\mathcal{A}|$. In this way, we can use the neural network to process simultaneously all time steps and MC samples generated in a single training iteration.
Check explicitly the output shape and test that the network runs on some fake data (e.g. a small batch of vectors of ones with the appropriate shape).
import jax.numpy as jnp # jax's numpy version with GPU support
from jax import random # used to define a RNG key to control the random input in JAX
from jax.experimental import stax # neural network library
from jax.experimental.stax import Dense, Relu, LogSoftmax # neural network layers
# set key for the RNG (see JAX docs)
rng = random.PRNGKey(seed)
# define functions which initialize the parameters and evaluate the model
initialize_params, predict = stax.serial(
### fully connected DNN
Dense(512), # 512 hidden neurons
Relu,
Dense(256), # 256 hidden neurons
Relu,
Dense(env.n_actions), # 4 output neurons
LogSoftmax # NB: computes the log-probability
)
# initialize the model parameters
input_shape = (-1,env.n_time_steps,2) # -1: number of MC points, number of time steps, size of state vector
output_shape, inital_params = initialize_params(rng, input_shape) # initialize the parameters of the policy network
print('\noutput shape of the policy network is {}.\n'.format(output_shape))
# test network
states=np.ones((3,env.n_time_steps,2), dtype=np.float32)
predictions = predict(inital_params, states)
# check the output shape
print(predictions.shape)
In class we defined a scalar pseudo-loss function, whose gradients give $\nabla_\theta J(\theta)$. Note that this pseudo-loss does NOT correspond to the RL objective $J(\theta)$: the difference stems from the fact that the two operations of taking the derivative and performing the MC approximation are not interchangeable (do you see why?).
$$ J_\mathrm{pseudo}(\theta) = \frac{1}{N}\sum_{j=1}^N \sum_{t=1}^T \log \pi_\theta(a^j_t|s^j_t) \left[\sum_{t'=t}^T r(s^j_{t'},a^j_{t'}) - b_t\right],\qquad b_t = \frac{1}{N}\sum_{j=1}^N G_t(\tau_j). $$The baseline $b_t$ is a sample average of the reward-to-go (return) from time step $t$ onwards: $G_t(\tau_j) = \sum_{t'=t}^T r(s^j_{t'},a^j_{t'})$.
Because the optimizer performs gradient descent while we want to do gradient ascent, do NOT forget to add an extra minus sign to the output of the pseudo-loss (or else your agent will end up minimizing the return).
Below, we also add an L2 regularizer to the pseudoloss function to prevent overfitting.
### define loss and accuracy functions
from jax import grad
from jax.tree_util import tree_flatten # jax params are stored as nested tuples; use this to manipulate tuples
def l2_regularizer(params, lmbda):
"""
Define l2 regularizer: $\lambda \ sum_j ||theta_j||^2 $ for every parameter in the model $\theta_j$
"""
return lmbda*jnp.sum(jnp.array([jnp.sum(jnp.abs(theta)**2) for theta in tree_flatten(params)[0] ]))
def pseudo_loss(params, trajectory_batch):
"""
Define the pseudo loss function for policy gradient.
params: object(jax pytree):
parameters of the deep policy network.
trajectory_batch: tuple (states, actions, returns) containing the RL states, actions and returns (not the rewards!):
states: np.array of size (N_MC, env.n_time_steps,2)
actions: np.array of size (N_MC, env.n_time_steps)
returns: np.array of size (N_MC, env.n_time_steps)
Returns:
-J_{pseudo}(\theta)
"""
# extract data from the batch
states, actions, returns = trajectory_batch
# compute policy predictions
preds = predict(params, states)
    # compute the baseline as the batch mean of the returns at each time step
    baseline = jnp.mean(returns, axis=0)
# select those values of the policy along the action trajectory
preds_select = jnp.take_along_axis(preds, jnp.expand_dims(actions, axis=2), axis=2).squeeze()
    # return the negative pseudo-loss (the optimizer minimizes, so minimizing -J_pseudo maximizes the return)
    return -jnp.mean(jnp.sum(preds_select * (returns - baseline), axis=1)) + l2_regularizer(params, 0.001)
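Before training, it is a good idea to check that the pseudo-loss and its gradient evaluate on fake trajectory data of the correct shapes; the batch size below is arbitrary.
# sanity check: evaluate the pseudo-loss and its gradient on fake trajectory data
N_fake = 5
fake_states  = jnp.ones((N_fake, env.n_time_steps, 2), dtype=jnp.float32)
fake_actions = jnp.zeros((N_fake, env.n_time_steps), dtype=jnp.int32)
fake_returns = jnp.ones((N_fake, env.n_time_steps), dtype=jnp.float32)
print(pseudo_loss(inital_params, (fake_states, fake_actions, fake_returns)))
# the gradient is a pytree with one array per layer parameter
fake_grads = grad(pseudo_loss)(inital_params, (fake_states, fake_actions, fake_returns))
print(len(tree_flatten(fake_grads)[0]))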
Define the optimizer and the update function, which computes the gradient of the pseudo-loss function and performs the parameter update. We use the Adam optimizer here with step_size = 0.001; the rest of the parameters take their default values.
### define generalized gradient descent optimizer and a function to update model parameters
from jax.experimental import optimizers # gradient descent optimizers
from jax import jit
step_size = 0.001 # step size or learning rate
# compute optimizer functions
opt_init, opt_update, get_params = optimizers.adam(step_size)
# define function which updates the parameters using the change computed by the optimizer
@jit # Just In Time compilation speeds up the code; requires to use jnp everywhere; remove when debugging
def update(i, opt_state, batch):
"""
i: int,
counter to count how many update steps we have performed
opt_state: object,
the state of the optimizer
    batch: tuple (states, actions, returns),
        the trajectory batch used to update the model
Returns:
opt_state: object,
the new state of the optimizer
"""
# get current parameters of the model
current_params = get_params(opt_state)
# compute gradients
grad_params = grad(pseudo_loss)(current_params, batch)
# use the optimizer to perform the update using opt_update
return opt_update(i, grad_params, opt_state)
Finally, we implement the REINFORCE algorithm for policy gradient. Follow the steps below:
1. Define the number of training episodes N_episodes, and the batch size N_MC.
2. Preallocate arrays for the instantaneous state, and for the states, actions, returns triple which defines the trajectory batch. Also preallocate arrays for the batch statistics of the reward at the end of each episode: mean_final_reward, std_final_reward, min_final_reward, and max_final_reward.
3. Set the initial model parameters in the optimizer using the opt_init function. Loop over the episodes; for every episode:
3.1 get the current network parameters
3.2 loop to collect MC samples
3.2.1 reset the env and roll out the policy until the episode is over; collect the trajectory data
3.2.2 compute the returns (rewards to go); see the short illustration after this list
3.3 compile the PG data into a trajectory batch
3.4 use the update function to update the network parameters
3.5 print the instantaneous performance
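Step 3.2.2 can be done in a single line using a reversed cumulative sum; here is a tiny standalone check with made-up reward values.
import numpy as np
r = np.array([0.1, 0.2, 0.7], dtype=np.float32)   # made-up rewards r_1, r_2, r_3
rewards_to_go = np.cumsum(r[::-1])[::-1]          # G_t = sum_{t' >= t} r_{t'}
print(rewards_to_go)                              # approximately [1.0, 0.9, 0.7]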
### Train model
import time
# define number of training episodes
N_episodes = 201
N_MC = 64 #128
# preallocate data using arrays initialized with zeros
state=np.zeros((2,), dtype=np.float32)
states = np.zeros((N_MC, env.n_time_steps,2), dtype=np.float32)
actions = np.zeros((N_MC, env.n_time_steps), dtype=np.int32) # integer-valued action indices
returns = np.zeros((N_MC, env.n_time_steps), dtype=np.float32)
# mean reward at the end of the episode
mean_final_reward = np.zeros(N_episodes, dtype=np.float32)
# standard deviation of the reward at the end of the episode
std_final_reward = np.zeros_like(mean_final_reward)
# batch minimum at the end of the episode
min_final_reward = np.zeros_like(mean_final_reward)
# batch maximum at the end of the episode
max_final_reward = np.zeros_like(mean_final_reward)
print("\nStarting training...\n")
# set the initial model parameters in the optimizer
opt_state = opt_init(inital_params)
# loop over the number of training episodes
for episode in range(N_episodes):
### record time
start_time = time.time()
# get current policy network params
current_params = get_params(opt_state)
# MC sample
for j in range(N_MC):
# reset environment to a random initial state
#env.reset(random=False) # fixed initial state
env.reset(random=True) # Haar-random initial state (i.e. uniformly sampled on the sphere)
# zero rewards array (auxiliary array to store the rewards, and help compute the returns)
rewards = np.zeros((env.n_time_steps, ), dtype=np.float32)
# loop over steps in an episode
for time_step in range(env.n_time_steps):
# select state
state[:] = env.state[:]
states[j,time_step,:] = state
# select an action according to current policy
pi_s = np.exp( predict(current_params, state) )
action = np.random.choice(env.actions, p = pi_s)
actions[j,time_step] = action
# take an environment step
state[:], reward, _ = env.step(action)
# store reward
rewards[time_step] = reward
        # compute the rewards-to-go (returns) as a reversed cumulative sum of the rewards
        returns[j,:] = np.cumsum(rewards[::-1])[::-1]
# define batch of data
trajectory_batch = (states, actions, returns)
# update model
opt_state = update(episode, opt_state, trajectory_batch)
### record time needed for a single epoch
episode_time = time.time() - start_time
# check performance
mean_final_reward[episode]=jnp.mean(returns[:,-1])
std_final_reward[episode] =jnp.std(returns[:,-1])
min_final_reward[episode], max_final_reward[episode] = np.min(returns[:,-1]), np.max(returns[:,-1])
    # print results (uncomment the if statement below to print only every few episodes)
    #if episode % 5 == 0:
    print("episode {} in {:0.2f} sec".format(episode, episode_time))
    print("mean final reward: {:0.4f}".format(mean_final_reward[episode]) )
    print("final reward standard deviation: {:0.4f}".format(std_final_reward[episode]) )
    print("min final reward: {:0.4f}; max final reward: {:0.4f}\n".format(min_final_reward[episode], max_final_reward[episode]) )
Plot the mean final reward at each episode, and its variance. What do you observe?
import matplotlib
from matplotlib import pyplot as plt
# static plots
%matplotlib inline
### plot and examine learning curves
episodes=list(range(N_episodes))
plt.plot(episodes, mean_final_reward, '-k', label='mean final reward' )
plt.fill_between(episodes,
mean_final_reward-0.5*std_final_reward,
mean_final_reward+0.5*std_final_reward,
color='k',
alpha=0.25)
plt.plot(episodes, min_final_reward, '--b' , label='min final reward' )
plt.plot(episodes, max_final_reward, '--r' , label='max final reward' )
plt.xlabel('episode')
plt.ylabel('final reward')
plt.legend(loc='lower right')
plt.grid()
plt.show()
Try out different batch sizes and hyperparameters (including different network architectures). Can you improve the performance?
Explore the final batch of trajectories. Check the sequence of actions. Can you make sense of the solution found by the agent? Hint: think of the dynamics on the Bloch sphere and try to visualize the trajectory there (see the plotting sketch at the end of these exercises).
Compare the Policy Gradient method to conventional optimal control: can optimal control give you a control protocol that works for all states? Why or why not?
Take one of the high-reward trajectories in the final batch of data. Now perturb it manually at a few time steps in the first half of the protocol so that it no longer produces an optimal reward (you would have to add a function to the environment which evaluates a given trajectory). Finally, use the policy to see how it would react to those perturbations in real time. Will it correct on-line for the introduced mistakes (i.e. before the episode is over)?
Find ways to visualize the policy. What is a meaningful way to do that?
What is the initial state distribution $p(s_0)$ in the implementation above? Check the performance of the PG algorithm if $p(s_0)$ is (i) concentrated on a single fixed initial state, or (ii) Haar-random, i.e. uniform on the Bloch sphere (cf. the two env.reset options in the training loop).
Introduce small Gaussian noise to the rewards, e.g. $r(s,a) \to r(s,a) + \delta r$ where $\delta r \sim \mathcal{N}(0,\delta)$ for some noise strength $\delta$. Does this lead to a serious performance drop as you vary $\delta\in[0,0.5]$? Why or why not?
The loop over the $N_{MC}$ trajectories slows down the algorithm significantly. Consider ways to speed up the evaluation of a single PG iteration. This may include a modification of the environment QubitEnv2, or the use of parallelization tools (see JAX's functions vmap and pmap).
Change the environment QubitEnv2 to define a non-episodic task. Additionally, introduce a "stop" action so that, when the agent brings the RL state close to $s_\mathrm{target}$, the episode comes to an end and the environment is reset. This would require you to also modify the Policy Gradient implementation above, because episodes can now have different lengths.
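For the Bloch-sphere visualization suggested above, one possible starting point (a sketch, not a prescribed solution) is to map the Bloch angles of a single trajectory from the final states batch to Cartesian coordinates and plot the path in 3D.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (enables the 3D projection)
# pick one trajectory from the final batch and map (theta, phi) -> (x, y, z)
theta_traj, phi_traj = states[0, :, 0], states[0, :, 1]
x = np.sin(theta_traj) * np.cos(phi_traj)
y = np.sin(theta_traj) * np.sin(phi_traj)
z = np.cos(theta_traj)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(x, y, z, '.-', label='visited states')
ax.scatter([0.0], [0.0], [1.0], color='r', s=50, label='target state')  # north pole
ax.set_xlabel('x'); ax.set_ylabel('y'); ax.set_zlabel('z')
ax.legend()
plt.show()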