{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Deep Learning in JAX\n", "\n", "This notebook provides an introduction to basic Deep Supervised Learning using [JAX](https://jax.readthedocs.io/en/latest/). In particular, we will cover the MNIST classification problem -- one of the first big success stories in ML.\n", "\n", "Our goal is to become familiar with the typical pipelines of Deep Learning in JAX. This includes:\n", "\n", "* building models: shallow and deep networks with nonlinear activations, fully-connected and convolutional layers, etc.;\n", "* defining cost functions and figure-of-merit functions which test the model accuracy;\n", "* computing derivatives of the cost function with respect to the model parameters;\n", "* understanding the optimizer pipeline and how to update the model parameters." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classification Problems\n", "\n", "The \"harmonic oscillator\" of Deep Learning is the MNIST problem.\n", "\n", "### The MNIST Dataset\n", "\n", "The MNIST classification problem is one of the classical ML problems for learning classification on high-dimensional data with a fairly sizeable number of examples. Yann LeCun and collaborators collected and processed 70 000 handwritten digits (60 000 are used for training and 10 000 for testing) to produce what became known as one of the most widely used datasets in ML: the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). Each handwritten digit comes as a grayscale square image on a 28×28 pixel grid. Every pixel takes an integer value in the range [0, 255], representing 256 shades of gray from black to white. 
The problem of image classification finds applications in a wide range of fields and is important for numerous industry applications of ML; there exist a number of much more challenging datasets, such as [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) or [ImageNet](http://www.image-net.org/).\n", "\n", "\n", "### Data Preprocessing\n", "\n", "The first two codeblocks below download the MNIST dataset from the web, and preprocess the MNIST data. In particular, the data are separated into a training set and a test set, and the labels are encoded in one-hot form." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "##### download MNIST data and store it under the directory _DATA\n", "\n", "import array\n", "import gzip\n", "import os\n", "from os import path\n", "import struct\n", "import urllib.request\n", "\n", "import numpy as np\n", "\n", "\n", "# path to data directory\n", "_DATA = \"/tmp/jax_example_data/\"\n", "\n", "\n", "def _download(url, filename):\n", "    \"\"\"Download a url to a file in the JAX data temp directory.\"\"\"\n", "    if not path.exists(_DATA):\n", "        os.makedirs(_DATA)\n", "    out_file = path.join(_DATA, filename)\n", "    if not path.isfile(out_file):\n", "        urllib.request.urlretrieve(url, out_file)\n", "        print(\"downloaded {} to {}\".format(url, _DATA))\n", "\n", "\n", "def _partial_flatten(x):\n", "    \"\"\"Flatten all but the first dimension of an ndarray.\"\"\"\n", "    return np.reshape(x, (x.shape[0], -1))\n", "\n", "\n", "def _one_hot(x, k, dtype=np.float32):\n", "    \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", "    return np.array(x[:, None] == np.arange(k), dtype)\n", "\n", "\n", "def mnist_raw():\n", "    \"\"\"Download and parse the raw MNIST dataset.\"\"\"\n", "    # CVDF mirror of http://yann.lecun.com/exdb/mnist/\n", "    base_url = \"https://storage.googleapis.com/cvdf-datasets/mnist/\"\n", "\n", "    def parse_labels(filename):\n", "        with gzip.open(filename, \"rb\") as fh:\n", "            _ 
= struct.unpack(\">II\", fh.read(8))\n", "            return np.array(array.array(\"B\", fh.read()), dtype=np.uint8)\n", "\n", "    def parse_images(filename):\n", "        with gzip.open(filename, \"rb\") as fh:\n", "            _, num_data, rows, cols = struct.unpack(\">IIII\", fh.read(16))\n", "            return np.array(array.array(\"B\", fh.read()),\n", "                            dtype=np.uint8).reshape(num_data, rows, cols)\n", "\n", "    for filename in [\"train-images-idx3-ubyte.gz\", \"train-labels-idx1-ubyte.gz\",\n", "                     \"t10k-images-idx3-ubyte.gz\", \"t10k-labels-idx1-ubyte.gz\"]:\n", "        _download(base_url + filename, filename)\n", "\n", "    train_images = parse_images(path.join(_DATA, \"train-images-idx3-ubyte.gz\"))\n", "    train_labels = parse_labels(path.join(_DATA, \"train-labels-idx1-ubyte.gz\"))\n", "    test_images = parse_images(path.join(_DATA, \"t10k-images-idx3-ubyte.gz\"))\n", "    test_labels = parse_labels(path.join(_DATA, \"t10k-labels-idx1-ubyte.gz\"))\n", "\n", "    return train_images, train_labels, test_images, test_labels\n", "\n", "\n", "def mnist(permute_train=False):\n", "    \"\"\"Download, parse and process MNIST data to unit scale and one-hot labels.\"\"\"\n", "    train_images, train_labels, test_images, test_labels = mnist_raw()\n", "\n", "    train_images = _partial_flatten(train_images) / np.float32(255.)\n", "    test_images = _partial_flatten(test_images) / np.float32(255.)\n", "    train_labels = _one_hot(train_labels, 10)\n", "    test_labels = _one_hot(test_labels, 10)\n", "\n", "    if permute_train:\n", "        perm = np.random.RandomState(0).permutation(train_images.shape[0])\n", "        train_images = train_images[perm]\n", "        train_labels = train_labels[perm]\n", "\n", "    return train_images, train_labels, test_images, test_labels\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Minibatches\n", "\n", "Now that we've written the functions to download and preprocess the MNIST dataset, let's prepare it for training. 
Since we'll be using some variant of SGD (e.g., Adam), we need to feed data into our machine learning model (e.g., a DNN) in random minibatches. \n", "\n", "The function `data_stream()` defines a Python generator which yields one minibatch of the randomly permuted training set at a time; once all training datapoints have been used, it draws a new permutation and starts over. Generators are functions which `yield` their results one at a time, typically from inside a loop, rather than `return` the entire output at once. \n", "\n", "If you're not familiar with generators, first explore the code below carefully. Make sure to explore the effect of the `while`-loop statement which is currently commented out. What is the purpose of having generators? What are generators useful for? Explore the `data_stream()` generator below by printing a few small minibatches. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "first loop:\n", "\n", "0 0\n", "1 1\n", "2 2\n", "3 3\n", "\n", "second loop:\n", "\n", "0 4\n", "1 5\n", "2 6\n", "3 7\n", "4 8\n", "5 9\n" ] }, { "ename": "StopIteration", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mStopIteration\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgen\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mStopIteration\u001b[0m: " ] } ], "source": [ "def my_generator():\n", "    #while True:  # uncomment to restart the inner loop indefinitely\n", "        for j in range(10):\n", "            yield j\n", "\n", "gen = my_generator()\n", "print(gen) # shows a generator object\n", "\n", "print('\\nfirst loop:\\n') \n", "\n", "# call generator\n", "for i in range(4):\n", "    print(i, next(gen) )\n", "\n", "print('\\nsecond loop:\\n') \n", "\n", "for i in range(10):\n", "    print(i, next(gen) )" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "train data: image shape: (60000, 784), label shape: (60000, 10).\n", "test data : image shape: (10000, 784), label shape: (10000, 10).\n", "\n" ] } ], "source": [ "### define minibatches\n", "\n", "# fix seed\n", "seed=0\n", "np.random.seed(seed)\n", "\n", "##### define data variables and the minibatch generator\n", "# load MNIST data\n", "train_images, train_labels, test_images, test_labels = mnist()\n", "\n", "print('\\ntrain data: image shape: {}, label shape: {}.'.format(train_images.shape, train_labels.shape ))\n", "print('test data : image shape: {}, label shape: {}.\\n'.format(test_images.shape, test_labels.shape ))\n", "\n", "\n", "# size of a single minibatch\n", "batch_size=128 \n", "# size of the training set\n", "num_train = train_images.shape[0] \n", "# define number of complete minibatches (data size need not be a multiple of batch_size)\n", "num_complete_batches, leftover = divmod(num_train, batch_size)\n", "# total number of minibatches: the complete ones, plus one smaller minibatch if there are leftover datapoints\n", "num_batches = num_complete_batches + bool(leftover)\n", "\n", "\n", "def data_stream():\n", "    \"\"\"\n", "    This function defines a generator which produces random batches of data, one at a time.\n", "    \n", "    \"\"\"\n", 
" rng = np.random.RandomState(0)\n", " while True:\n", " perm = rng.permutation(num_train) # compute a random permutation\n", " for i in range(num_batches):\n", " batch_idx = perm[i * batch_size:(i + 1) * batch_size]\n", " yield train_images[batch_idx], train_labels[batch_idx]\n", "\n", "# define the batches generator\n", "batches = data_stream()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Plot the Data\n", "\n", "The codeblock below plots one of the datapoints. Pay attention to the title of the plot where the corresponding label in one-hot form is given. Make sure you familiarize yourself with the data size: shapes, dimensions, etc. We said earlier that each pixel takes on the values $[0,1,\\dots,255]$, but below it seems like they are squeeze in the interval $[0,1]$: Where in the code did we do that squeezing? Why did we do it?" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAARcElEQVR4nO3dfYwc9X3H8fcHQgwCp8XxhRjjcuFJ1KkKRCuHyk1IBKUOUgVulci0si6qVYNK1OZBCpZRsVs3KSl5qKM2UUxxw0MwiSBWrJa0UJOSRsFpjuD4oRcTxzLE4NjnEIzdOI6Nv/1j59rl2J19mNmH8+/zkla3O9/57Xxv5z43uzO7O4oIzOzkd0q/GzCz3nDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIGOuySdku6psV5Q9JFHS6n5bFZT0ck3dfJssyakXSNpMOSTrT699+KgQ77APu9iFg8cUPSsKRvSPq5pB+0s4IkTZO0VtLLkn4i6cPtNCLpQ9m4g9n9TGtj7NVZvz/P+j+/jbGrJG2VdFzSyjZ7nqq/c0/Wc0T8e0ScBTzX6v23wmEvxzrgaeCNwG3AQ5KGWhy7ErgYOB94N/BRSQtaGSjpd4FlwNXAMHAB8Jctjp0JfBX4C2AGMAp8ucWeAXYCHwX+pY0xE1YyNX/nvqzn0kTEwF6A3cA12fV5wJPAS8Be4O+B19fMG8CfAbuAA8CdwCk19T8GxoCfAf8GnD9p7EXt9pTdvgQ4CkyvmfafwM0t3t/zwLU1t1cBD7Y49gHg4zW3rwZ+0uLYpcC3a26fCRwBLm1zHd0PrGxzzJT7nfuxnif/rRW9TKUt+yvAh4CZwG9RXcl/OmmehUAFeBtwPdWAI+kGYDnw+8AQ1ZW0rt5CJP2hpC1t9PVWYFdEHKqZ9v1sei5JZwPnZvO3NbZm2ZPHniPpje2OjYj/AX7UxrI7MoV/536u51JMmbBHxFMRsSkijkfEbuALwFWTZvtERLwYEc8BfwfcmE2/CfibiBiLiOPAx4HL671ei
4gHIuI322jtLODgpGkHgektjp2Yv92x9ZY9cb3VZXfadxFT9Xfu53ouxZQJu6RLJP1ztnPjZaqBnTlpth/XXH+W6n9TqL5OWi3pJUkvAS8CAmaX0Nph4A2Tpr0BOFRn3npjJ+Zvd2y9ZU9cb3XZnfZdxFT9nfu5nksxZcIOfB74AXBxRLyB6tNyTZpnTs31XwNeyK7/GLgpIn615nJGRHy7hL62AxdIqv0vfVk2PVdE/Izq/ofL2h1bs+zJY/dFxE/bHSvpTODCNpbdkSn8O/dzPZejrBf/3bjw6h10/wXcTjXglwI7gG/VzBvARuBsqqH/AbA0qy0EtgFvzW7/CvDeSWM72kGXTdsEfBI4PVvWS8BQi/d3B/BE1velVP8oFrQ4dgHwE2BuNv5x4I4Wxw5RfSr5B1nfnwA2tbFuTsvGPQD8dXb91JP8d+7peq73t1YoT2XdUTcuk8L+zizAh6nuYPurOmGf2Bv/U+BTtX98wGJgK/Ay1S392kljL8qu/xGwvZ0VQPUQ0H9Q3bO7g1fvrX8HcDjn/qYBa7O+9gEfnlQ/DLwjZ/yHs3EvA/8ETKupfR1YnjP2muwxPZL1P1xTWw58PWfsF7PHrfby/pP8d+7peq73t1bkouxOrUWSdgCzgPURMdLvfuzkI+lq4GGq/yCui4hvlHK/DrtZGqbSDjozK8BhN0vE63q5sJkzZ8bw8HAvF2mWlN27d3PgwIHJh6SBgmHP3si/GjgV+MeIuCNv/uHhYUZHR4ss0sxyVCqVhrWOn8ZLOhX4B+A9VI953ihpbqf3Z2bdVeQ1+zxgZ0TsiohfAg9S/fCJmQ2gImGfzavfi76HOu81l7RU0qik0fHx8QKLM7MiioS93k6A1xy0j4g1EVGJiMrQUKuf8zezshUJ+x5e/cGT8/j/D56Y2YApEvbvAhdLeouk1wOLgA3ltGVmZev40FtEHJf0Aapf8XQq1Q+W9PYje2bWskLH2SPiEeCRknoxsy7y22XNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIgqdslnSbuAQ8ApwPCIqZTRlZuUrFPbMuyPiQAn3Y2Zd5KfxZokoGvYAHpX0lKSl9WaQtFTSqKTR8fHxgoszs04VDfv8iHgb8B7gFknvnDxDRKyJiEpEVIaGhgouzsw6VSjsEfFC9nM/sB6YV0ZTZla+jsMu6UxJ0yeuA9cC28pqzMzKVWRv/DnAekkT9/NARPxrKV3Zqxw7diy3ftdddzWsLVmyJHfstGnTOuppEKxbty63fttttzWs7dq1q+x2Bl7HYY+IXcBlJfZiZl3kQ29miXDYzRLhsJslwmE3S4TDbpaIMj4IY122bNmy3Prq1asb1mbMmJE7dtGiRR31NAj27NmTW88OC1vGW3azRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBE+zj4FjI2N9buFgeTj6O3xlt0sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4SPs/fA0aNHc+uf/exnc+tPPPFEme0MjEOHDuXW77zzztz6jh07ymznpOctu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCB9n74GDBw/m1pt9L3wzV111VcPaIH8v/GOPPZZb/9jHPlbo/kdGRgqNP9k03bJLWitpv6RtNdNmSHpM0g+zn2d3t00zK6qVp/FfBBZMmrYM2BgRFwMbs9tmNsCahj0ivgm8OGny9cA92fV7gBvKbcvMytbpDrpzImIvQPbzTY1mlLRU0qik0fHx8Q4XZ2ZFdX1vfESsiYhKRFSGhoa6vTgza6DTsO+TNAsg+7m/vJbMrBs6DfsGYOK4xgjwtXLaMbNuaXqcXdI64F3ATEl7gBXAHcBXJC0Bn
gPe280mp7pVq1bl1ot+//n8+fMLje+X559/Prde9HE577zzCo0/2TQNe0Tc2KB0dcm9mFkX+e2yZolw2M0S4bCbJcJhN0uEw26WCH/EtQc+97nP5dabHWIaHh7OrS9fvrzdlgZCt09FPXfu3K7e/1TjLbtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulggfZ58Cbr/99tz6GWec0aNO2vf00083rN1///097MS8ZTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuHj7CU4evRobv3EiRO59WanLs47JTPAL37xi4a1008/PXfskSNHcuvHjh3LrT/zzDO59Xnz5jWsFf2q6Dlz5uTWp+pXbHeLt+xmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSJ8nL1FBw4caFi79dZbc8eeckr+/9TFixfn1q+44orc+t69exvWrrzyytyxjz76aG792Wefza03k3csvdlx9gULFuTW169fn1s/7bTTcuupabpll7RW0n5J22qmrZT0vKTN2eW67rZpZkW18jT+i0C9f7GfiYjLs8sj5bZlZmVrGvaI+CbwYg96MbMuKrKD7gOStmRP889uNJOkpZJGJY2Oj48XWJyZFdFp2D8PXAhcDuwFPtVoxohYExGViKgMDQ11uDgzK6qjsEfEvoh4JSJOAHcBjT/aZGYDoaOwS5pVc3MhsK3RvGY2GJoeZ5e0DngXMFPSHmAF8C5JlwMB7AZu6l6Lg2HmzJkNa6tXr84d2+wz5XnH8AEefvjh3HqeLVu25NaLfqa8iBUrVuTWFy1alFv3cfT2NA17RNxYZ/LdXejFzLrIb5c1S4TDbpYIh90sEQ67WSIcdrNEKCJ6trBKpRKjo6M9W95Ucfz48dz6zp07c+sbNmxoWGu2fpsdehsbG8ut33fffbn16dOnN6xt3749d+y5556bW7fXqlQqjI6O1l2p3rKbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZonwV0kPgNe9Ln81XHrppYXqRTz55JO59XvvvTe3PjIy0rDm4+i95S27WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIH2e3XA899FBuvdnn4WfPnl1mO1aAt+xmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSJaOWXzHOBe4M3ACWBNRKyWNAP4MjBM9bTN74uIn3WvVeuGI0eO5NYff/zxQvfv8wQMjla27MeBj0TErwNXArdImgssAzZGxMXAxuy2mQ2opmGPiL0R8b3s+iFgDJgNXA/ck812D3BDl3o0sxK09Zpd0jBwBfAd4JyI2AvVfwjAm0rvzsxK03LYJZ0FPAx8MCJebmPcUkmjkkbHx8c76dHMStBS2CWdRjXoX4qIr2aT90maldVnAfvrjY2INRFRiYjK0NBQGT2bWQeahl3VjzXdDYxFxKdrShuAia8OHQG+Vn57ZlaWVj7iOh9YDGyVtDmbthy4A/iKpCXAc8B7u9KhddWKFSty61u3bi10/6tWrSo03srTNOwR8S2g0YeWry63HTPrFr+DziwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCXyWduEOHDuXWIyK3vnDhwtz6JZdc0nZP1h3espslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmifBx9sQ1O+Vys/q8efPKbMe6yFt2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0Q47GaJcNjNEuGwmyXCYTdLhMNulgiH3SwRTT/PLmkOcC/wZuAEsCYiVktaCfwJMJ7NujwiHulWo9Ydb3/723PrmzZtyq3ffPPNZbZjXdTKl1ccBz4SEd+TNB14StJjWe0zEfHJ7rVnZmVpGvaI2Avsza4fkjQGzO52Y2ZWrrZes0saBq4AvpNN+oCkLZLWSjq7wZilkkYljY6Pj9ebxcx6oOWwSzoLeBj4YES8DHweuBC4nOqW/1P1xkXEmoioRERlaGioeMdm1pGWw
i7pNKpB/1JEfBUgIvZFxCsRcQK4C/A3D5oNsKZhV/XrRe8GxiLi0zXTZ9XMthDYVn57ZlaWVvbGzwcWA1slbc6mLQdulHQ5EMBu4KYu9GddNjIyUqhuU0cre+O/BdT78nAfUzebQvwOOrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIRUTvFiaNA8/WTJoJHOhZA+0Z1N4GtS9wb50qs7fzI6Lu97/1NOyvWbg0GhGVvjWQY1B7G9S+wL11qle9+Wm8WSIcdrNE9Dvsa/q8/DyD2tug9gXurVM96a2vr9nNrHf6vWU3sx5x2M0S0ZewS1ogaYeknZKW9aOHRiTtlrRV0mZJo33uZa2k/ZK21UybIekxST/MftY9x16felsp6fnssdss6bo+9TZH0jckjUnaLunPs+l9fexy+urJ49bz1+ySTgWeAX4H2AN8F7gxIv67p400IGk3UImIvr8BQ9I7gcPAvRHxG9m0vwVejIg7sn+UZ0fErQPS20rgcL9P452drWhW7WnGgRuA99PHxy6nr/fRg8etH1v2ecDOiNgVEb8EHgSu70MfAy8ivgm8OGny9cA92fV7qP6x9FyD3gZCROyNiO9l1w8BE6cZ7+tjl9NXT/Qj7LOBH9fc3sNgne89gEclPSVpab+bqeOciNgL1T8e4E197meypqfx7qVJpxkfmMeuk9OfF9WPsNc7ldQgHf+bHxFvA94D3JI9XbXWtHQa716pc5rxgdDp6c+L6kfY9wBzam6fB7zQhz7qiogXsp/7gfUM3qmo902cQTf7ub/P/fyfQTqNd73TjDMAj10/T3/ej7B/F7hY0lskvR5YBGzoQx+vIenMbMcJks4ErmXwTkW9AZg4teoI8LU+9vIqg3Ia70anGafPj13fT38eET2/ANdR3SP/I+C2fvTQoK8LgO9nl+397g1YR/Vp3TGqz4iWAG8ENgI/zH7OGKDe7gO2AluoBmtWn3r7baovDbcAm7PLdf1+7HL66snj5rfLmiXC76AzS4TDbpYIh90sEQ67WSIcdrNEOOxmiXDYzRLxv29bfwpe9en4AAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib\n", "from matplotlib import pyplot as plt\n", "# static plots\n", "%matplotlib inline \n", "\n", "### show one test data point as an example\n", "n=1111 # test data point number\n", "\n", "plt.imshow(test_images[n].reshape(28,28),cmap='Greys')\n", "plt.title('label: {}'.format(test_labels[n]) )\n", "\n", "plt.show()\n", "\n", "# print the array which generates the image above and explore it!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SoftMax Regression\n", "\n", "Our goal is to train an ML model that classifies the images in the test set, by only using the data in the training set. We will explore shallow and deep networks, as well as fully-connected and convolutional layers.\n", "\n", "Instead of directly learning the category an image falls into, learning is more stable if we model the probability for the image to be in a certain category $i$. To do this, we will use a SoftMax output activation, which can be thought of as a statistical model that assigns a probability to each of the 10 handwritten digits for a given input image. This layer is a multi-categorical generalization of logistic regression and reads as:\n", "\n", "$$ \n", "p\\left(y=i|x,\\{\\boldsymbol{\\theta}_k\\}_{k=0}^9 \\right) = \\frac{\\exp(x^T\\cdot \\boldsymbol{\\theta}_i)}{\\sum_{j=0}^9 \\exp(x^T\\cdot \\boldsymbol{\\theta}_j) }\n", "$$\n", "\n", "Here, $p\\left(y=i|x,\\{\\boldsymbol{\\theta}_k\\}_{k=0}^9 \\right)$ is the probability that input $x$ is the $i$-th digit, $i\\in\\{0,1,2,\\dots,9\\}$. The model parameters (sometimes called trainables or learnables) are denoted by $\\boldsymbol{\\theta}$: in the simplest model, there's one parameter vector $\\boldsymbol{\\theta}_j$ per category. 
One can use this model for prediction by taking the value of $y$ for which the probability is maximized:\n", "\n", "$$\n", "y_\\mathrm{pred} = \\arg\\max_{j} p\\left(y=j|x,\\{\\boldsymbol{\\theta}_k\\}_{k=0}^9 \\right).\n", "$$ \n", "\n", "In practice, it is often more convenient to use the logarithm of the softmax function to learn probabilities. Using a log scale allows the model to more easily capture probabilities which differ by several orders of magnitude. ML packages often provide a dedicated, numerically stable function for this (e.g. `LogSoftmax` in `stax`). \n", "\n", "\n", "### Three ML Models\n", "\n", "Below, you have to use the JAX library to build three different models:\n", "\n", "A. SoftMax Logistic Regression: no correlations between pixels, no info about image dimensionality\n", "\n", "B. DNN: a fully-connected (fc) deep neural network: no info about image dimensionality\n", "\n", "C. CNN: a convolutional (conv) neural network with a fully-connected head" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SoftMax Logistic Regression\n", "\n", "We begin with the SoftMax Logistic Regression. The code structure will later allow us to easily generalize the functions we write and re-use them for the DNN and CNN models. We will proceed in the following steps:\n", "\n", "1. Define the ML model. \n", "2. Define the `loss` function, and a function which measures the `accuracy` of the model predictions. \n", "3. Define a generalized gradient descent optimizer. \n", "4. Define the training loop and train the model. \n", "\n", "Fill in the code snippet below. To do this, you will have to explore and read the JAX documentation:\n", "* start with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) and make yourself familiar with the functions in there. Why is there a need for a JAX version of numpy, i.e. why not use the ordinary numpy library? 
You may also check out `jax.scipy`.\n", "* look up [`jax.random`](https://jax.readthedocs.io/en/latest/jax.random.html): random numbers work a bit differently in JAX, compared to numpy, but this is not that important for understanding how to do deep learning in JAX.\n", "* neural network architectures, like the fc and conv layers we will need, are described in [`jax.experimental.stax`](https://jax.readthedocs.io/en/latest/jax.experimental.stax.html); make sure you understand the difference between required/compulsory and optional arguments. The [`jax.nn`](https://jax.readthedocs.io/en/latest/jax.nn.html) package contains the nonlinear activation functions. To apply the activations elementwise, use either the [`elementwise` function](https://jax.readthedocs.io/en/latest/_modules/jax/experimental/stax.html#elementwise), or the capitalized activation layers of `stax` (e.g. `Relu`) instead. \n", "\n", "\n", "#### Define the ML model\n", "\n", "We will first construct the SoftMax model. To do so, we think of it as a single-layer NN. To construct the layer, use the `init_fun, apply_fun = stax.serial()` function. This function returns two other functions: `init_fun, apply_fun` (you may give them whatever names you want; the code below calls them `initialize_params` and `predict`). \n", "* `output_shape, params = init_fun(rng, input_shape)` is used to initialize the parameters of the model: it returns the shape of the output at the topmost layer, and a tuple of nested tuples `params` which contains the model parameters. \n", "* `predictions = apply_fun(params, data)` applies the model defined by `params` to `data` and returns the model output. \n", "\n", "Use a subset of the training data which contains, say, three data points. We will use it to debug and explore the model we have defined. \n", "\n", "1.0. build a model consisting of a single `Dense` layer, followed by the `LogSoftmax` activation. \n", "\n", "1.1. compute the model predictions on that toy subset\n", "\n", "1.2. check the shape of the output. 
Does it agree with the `output_shape` tuple returned by `init_fun`?\n", "\n", "1.3. print out the prediction values themselves: how many values does each of the three toyset datapoints have? Do these values represent a well-defined probability distribution? Check the conservation of probability, i.e. that the exponentiated outputs sum to one for each datapoint. \n", "\n", "1.4. print the `params` variable and explore it carefully. It defines a so-called JAX pytree. Manipulating this nested structure of tuples by hand can be very tedious. Instead, explore and use the functions of the [`jax.tree_util`](https://jax.readthedocs.io/en/latest/jax.tree_util.html) package. Understanding this is especially important if you want to access the parameters of a specific layer in the model. Extract the weights and biases from the single-layer model, and check their shapes. What is their data type (use `type(variable)`)? Why is this new data type needed, and how can you convert data back and forth between it and an ordinary numpy array?" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "output shape of the model is (-1, 10).\n", "\n", "actual output shape is: (3, 10)\n", "log(softmax) values: [-2.2292223 -2.4820766 -2.7947464 -2.1780825 -1.9830028 -2.2069645\n", " -1.9345262 -2.6357243 -2.065387 -3.1278682]\n", "conservation of probability [0.9999999 1.0000001 1.0000001]\n" ] } ], "source": [ "import jax.numpy as jnp # jax's numpy version with GPU support\n", "from jax import random # used to define a RNG key to control the random input in JAX\n", "from jax.experimental import stax # neural network library (in newer JAX releases: jax.example_libraries.stax)\n", "from jax.experimental.stax import Dense, Relu, LogSoftmax # neural network layers\n", "\n", "# set key for the RNG (see JAX docs)\n", "rng = random.PRNGKey(seed)\n", "\n", "# cast data into 1D image format suitable for fc layers: the shape should be (N_datapoints, 28*28)\n", "train_images = train_images.reshape(-1,28*28) # -1: number of data points, (28*28): (height*width) dimensions of image\n", "test_images = test_images.reshape(-1,28*28)\n", "\n", "# define functions which initialize the parameters and evaluate the model\n", "initialize_params, predict = stax.serial(\n", "    ### SoftMax Regression\n", "    Dense(10), # 10 output neurons\n", "    LogSoftmax # NB: computes the log-probability\n", "\n", "#     ### fully connected DNN\n", "#     Dense(512), # 512 hidden neurons\n", "#     Relu,\n", "#     Dense(256), # 256 hidden neurons\n", "#     Relu,\n", "#     Dense(10), # 10 output neurons\n", "#     LogSoftmax # NB: computes the log-probability\n", "    )\n", "\n", "# initialize the model parameters\n", "output_shape, inital_params = initialize_params(rng, (-1, 28 * 28)) # fc layer: 28x28 pixels in each image\n", "\n", "print('\\noutput shape of the model is {}.\\n'.format(output_shape))\n", "\n", "# check how the network works on 3 examples\n", "predictions = predict(inital_params, test_images[0:3])\n", "\n", "# print shape of output\n", "print(\"actual 
output shape is:\", predictions.shape)\n", "\n", "# check if probability is conserved\n", "print('log(softmax) values:', predictions[0])\n", "print('conservation of probability', jnp.sum(jnp.exp(predictions), axis=1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define the loss/cost function\n", "\n", "Next, we define the loss/cost function and the accuracy function to measure the performance of the model. Defining these functions in JAX works the same way as in ordinary Python. The only difference is that one has to use `jax.numpy` instead of ordinary `numpy` (ordinary `numpy` functions run only on the CPU and cannot be traced by JAX transformations such as `grad` and `jit`). \n", "\n", "2.0. Complete the `loss(params, batch)` function which computes the cross entropy, given the model `params` and the data of a minibatch `batch`. Using `jnp.sum` over the categories and `jnp.mean` over the datapoints makes the `loss` agnostic to the size of the `batch`. \n", "\n", "2.1. Complete the `mean_accuracy(params, batch)` function. It computes the fraction of datapoints for which the model produces a correct prediction.\n", "\n", "2.2. Test the `loss` and `mean_accuracy` functions on the toy dataset. This helps to spot and debug errors.\n", "\n", "2.3. (optional): use the `tree_flatten` function from the `jax.tree_util` package and add an `L2` regularizer to the `loss` function. To do so, define another function `l2_regularizer(params, lmbda)` which computes the squared L2 norm of the model parameters and weighs it by the regularization strength `lmbda`. Is there a JAX function that computes the L2 norm?\n", "\n", "2.4. explore the [`jax.grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad) function. Test it on the loss function by computing the gradient of the loss with respect to the model parameters for the datapoints in the toyset. Check the form of the gradient output. Check the shapes/dimensions of the gradients themselves. 
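\n", "\n", "As a sketch of the quantity that `loss` is expected to compute (assuming that `predict` returns log-probabilities, that the labels $y_{n,i}$ are one-hot, that $N$ denotes the batch size, and that $\\lambda$ stands for the `lmbda` of step 2.3):\n", "\n", "$$\n", "L(\\boldsymbol{\\theta}) = -\\frac{1}{N}\\sum_{n=1}^{N}\\sum_{i=0}^{9} y_{n,i}\\,\\log p\\left(y=i|x_n,\\{\\boldsymbol{\\theta}_k\\}_{k=0}^9 \\right) + \\lambda\\sum_{j}\\lVert\\boldsymbol{\\theta}_j\\rVert^2\n", "$$\n", "\n", "The first term is the cross entropy averaged over the minibatch; the second term is the optional L2 penalty, where the sum over $j$ is taken here to run over all parameter arrays of the model (weights and biases). 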
" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "### define loss and accuracy functions\n", "\n", "from jax import grad\n", "from jax.tree_util import tree_flatten # jax params are stored as nested tuples; use this to manipulate tuples\n", "\n", "\n", "def l2_regularizer(params, lmbda):\n", "    \"\"\"\n", "    Define L2 regularizer: lmbda * sum_j ||theta_j||^2, summed over all parameter arrays theta_j of the model.\n", "    \n", "    \"\"\"\n", "    return lmbda*jnp.sum(jnp.array([jnp.sum(jnp.abs(theta)**2) for theta in tree_flatten(params)[0] ]))\n", "\n", "\n", "def loss(params, batch):\n", "    \"\"\"\n", "    Define cost (or loss) function for softmax classification: \n", "    cross entropy averaged over the batch, plus an L2 penalty. \n", "    \n", "    \"\"\"\n", "    inputs, targets = batch\n", "    preds = predict(params, inputs)  # log-probabilities\n", "    return -jnp.mean(jnp.sum(preds * targets, axis=1)) + l2_regularizer(params, 0.001)\n", "\n", "\n", "def mean_accuracy(params, batch):\n", "    \"\"\"\n", "    Define accuracy function: the fraction of datapoints which have correct predictions. \n", "    This function is not used for training; only to test the performance. \n", "    \n", "    \"\"\"\n", "    inputs, targets = batch\n", "    target_class = jnp.argmax(targets, axis=1)\n", "    predicted_class = jnp.argmax(predict(params, inputs), axis=1)\n", "    return jnp.mean(predicted_class == target_class)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Define the optimizer\n", "\n", "Next, we define the optimizer. Make sure to read the documentation of [`jax.experimental.optimizers`](https://jax.readthedocs.io/en/latest/jax.experimental.optimizers.html). 
In particular, you may want to check out the source code for \n", "[SGD](https://jax.readthedocs.io/en/latest/_modules/jax/experimental/optimizers.html#sgd), \n", "[SGD with momentum](https://jax.readthedocs.io/en/latest/_modules/jax/experimental/optimizers.html#momentum), and \n", "[ADAM](https://jax.readthedocs.io/en/latest/_modules/jax/experimental/optimizers.html#adam).\n", "\n", "3.0. define the optimizer hyperparameters (step size/learning rate, etc.)\n", "\n", "3.1. call the `optimizers.momentum` constructor and obtain the `opt_init, opt_update, get_params` functions. Make sure you read the documentation to understand what they do, and how they are used.\n", "\n", "3.2. Complete the `update(i, opt_state, batch)` function. The only nontrivial step is the computation of the gradient of the loss function, which we explored in part 2.4 above. In which line does the actual update of the model parameters (e.g. the SGD step) take place?\n", "\n", "3.3. add a `@jit` decorator (Just-In-Time compiler) to the update function. This will make JAX compile the `update()` function to give you speed (even on the CPU!). Explore the documentation for [`jax.jit`](https://jax.readthedocs.io/en/latest/jax.html#jax.jit). One caveat is that any functions (and subroutines) used under the `jit` decorator must use `jax.numpy` or `jax.scipy`; applying ordinary `numpy` functions to the traced arguments of a `jit`-compiled function will typically throw an error (try that out by modifying the `loss` function!). 
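\n", "\n", "For orientation, SGD with momentum is commonly written as follows. This is a sketch which assumes the standard heavy-ball form with learning rate $\\eta$, momentum parameter $\\gamma$ and initial velocity $v_0=0$; the `momentum` source code linked above remains the authoritative definition:\n", "\n", "$$\n", "v_{t+1} = \\gamma\\, v_t + \\nabla_{\\boldsymbol{\\theta}} L(\\boldsymbol{\\theta}_t), \\qquad \\boldsymbol{\\theta}_{t+1} = \\boldsymbol{\\theta}_t - \\eta\\, v_{t+1}.\n", "$$\n", "\n", "Setting $\\gamma=0$ recovers plain SGD. 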
" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "### define generalized gradient descent optimizer and a function to update model parameters\n", "\n", "from jax.experimental import optimizers # gradient descent optimizers (in newer JAX releases: jax.example_libraries.optimizers)\n", "from jax import jit\n", "\n", "step_size = 0.001 # step size or learning rate \n", "momentum_mass = 0.9 # \"gamma\" parameter in GD+momentum\n", "\n", "# compute optimizer functions\n", "opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)\n", "\n", "\n", "# define function which updates the parameters using the change computed by the optimizer\n", "@jit # Just In Time compilation speeds up the code; requires using jnp everywhere; remove when debugging\n", "def update(i, opt_state, batch):\n", "    \"\"\"\n", "    i: int,\n", "        counter to count how many update steps we have performed\n", "    opt_state: object,\n", "        the state of the optimizer\n", "    batch: tuple of arrays (images, labels),\n", "        batch containing the data used to update the model\n", "    \n", "    Returns: \n", "    opt_state: object,\n", "        the new state of the optimizer\n", "    \n", "    \"\"\"\n", "    # get current parameters of the model\n", "    current_params = get_params(opt_state)\n", "    # compute gradients of the loss with respect to the parameters\n", "    grad_params = grad(loss)(current_params, batch)\n", "    # use the optimizer to perform the update using opt_update\n", "    return opt_update(i, grad_params, opt_state)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Train Model\n", "\n", "At last, we have built all ingredients and we can start training our model. We train the model in epochs. In every epoch, we loop over the minibatches to exhaust the training set. We update the model parameters for each minibatch (hence the number of epochs is not the same as the number of updates). Therefore, we use a variable `itercount` to count the number of updates; in fact, `itercount` will be a simple iterator, similar to the generators discussed above. 
Once we've done the update, we can read off the model parameters and check the current loss and model accuracy ***on the test set*** for the given epoch. \n", "\n", "Then we move to the next epoch and repeat the procedure. The model learns if the loss on the test data goes down, and the accuracy on the test data goes up. We can monitor these quantities during training.\n", "\n", "4.0. define placeholders for the `train_accuracy` and `test_accuracy`. \n", "\n", "4.1. initialize the optimizer state using the `opt_init` function. \n", "\n", "4.2. loop over the epochs. \n", "\n", "4.2.1. For each epoch, loop over all minibatches and use the `update()` function to compute the gradients with respect to the `params` and update the model. Updating the model happens automatically upon calling `update()`. How does `update` know about the current value of `params`? Check if `params` is changing after each call of `update`. \n", "\n", "4.2.2. Compute the mean accuracy on the training and test data, and store it in `train_accuracy` and `test_accuracy`, respectively. Print these values for reference. 
" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Starting training...\n", "\n", "Epoch 0 in 0.52 sec\n", "Training set accuracy 0.8258500099182129\n", "Test set accuracy 0.8352000117301941\n", "\n", "Epoch 10 in 0.10 sec\n", "Training set accuracy 0.8929499983787537\n", "Test set accuracy 0.9003000259399414\n", "\n", "Epoch 20 in 0.12 sec\n", "Training set accuracy 0.9025333523750305\n", "Test set accuracy 0.9090999960899353\n", "\n", "Epoch 30 in 0.16 sec\n", "Training set accuracy 0.9068499803543091\n", "Test set accuracy 0.9121000170707703\n", "\n", "Epoch 40 in 0.13 sec\n", "Training set accuracy 0.9095333218574524\n", "Test set accuracy 0.9143999814987183\n", "\n", "Epoch 50 in 0.14 sec\n", "Training set accuracy 0.911466658115387\n", "Test set accuracy 0.9151999950408936\n", "\n", "Epoch 60 in 0.10 sec\n", "Training set accuracy 0.912933349609375\n", "Test set accuracy 0.9157999753952026\n", "\n", "Epoch 70 in 0.10 sec\n", "Training set accuracy 0.9145166873931885\n", "Test set accuracy 0.916100025177002\n", "\n", "Epoch 80 in 0.13 sec\n", "Training set accuracy 0.914900004863739\n", "Test set accuracy 0.9158999919891357\n", "\n", "Epoch 90 in 0.10 sec\n", "Training set accuracy 0.9155666828155518\n", "Test set accuracy 0.916700005531311\n", "\n", "Epoch 100 in 0.13 sec\n", "Training set accuracy 0.9156666398048401\n", "Test set accuracy 0.9176999926567078\n", "\n" ] } ], "source": [ "### Train model\n", "\n", "import time\n", "import itertools\n", "\n", "# define geenrator to count the number of updates\n", "itercount = itertools.count()\n", "\n", "# define number of training epochs\n", "num_epochs = 101\n", "\n", "# define figures of merit\n", "train_accuracy=np.zeros(num_epochs)\n", "test_accuracy=np.zeros_like(train_accuracy)\n", "\n", "print(\"\\nStarting training...\\n\")\n", "\n", "# set the initial model parameters in the optimizer\n", "opt_state = 
opt_init(inital_params)\n", "\n", "# loop over the number of training epochs\n", "for epoch in range(num_epochs):\n", "\n", "    ### record time\n", "    start_time = time.time()\n", "\n", "    ### train in minibatches until the entire dataset is exhausted:\n", "    # the entire dataset is divided into _random_ minibatches;\n", "    # all minibatches are shown to the model before going to next epoch\n", "    for _ in range(num_batches):\n", "        # use the data to update the model parameters\n", "        opt_state = update(next(itercount), opt_state, next(batches))\n", "\n", "    ### record time needed for a single epoch\n", "    epoch_time = time.time() - start_time\n", "\n", "    ### evaluate performance of the model at each fixed epoch\n", "\n", "    # retrieve current model parameters\n", "    params = get_params(opt_state)\n", "\n", "    # measure the accuracy on the training and test datasets\n", "    train_accuracy[epoch] = mean_accuracy(params, (train_images, train_labels))\n", "    test_accuracy[epoch] = mean_accuracy(params, (test_images, test_labels))\n", "\n", "    # print results every 10 epochs\n", "    if epoch % 10 == 0:\n", "        print(\"Epoch {} in {:0.2f} sec\".format(epoch, epoch_time))\n", "        print(\"Training set accuracy {}\".format(train_accuracy[epoch]))\n", "        print(\"Test set accuracy {}\\n\".format(test_accuracy[epoch]))\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Explore the training properties and behavior\n", "\n", "Below, we compare the average training and test accuracy curves. " ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "### plot and examine learning curves\n", "\n", "epochs=list(range(num_epochs))\n", "\n", "plt.plot(epochs, train_accuracy, '-b', label='training data' )\n", "plt.plot(epochs, test_accuracy, '-r', label='test data' )\n", "\n", "plt.xlabel('epoch')\n", "plt.ylabel('accuracy')\n", "\n", "plt.grid()\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Examine the trained Weights\n", "\n", "Let's examine the trained weights and check how they look like. This is a first attempt to answer the question what the ML model learns and how it does the classification. " ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "### plot weights vs the pixel position (works only for a single layer)\n", "### code works only SoftMax regression (!)\n", "\n", "from jax. tree_util import tree_flatten # jax params are stored as nested tuples; use this to manipulate tuples\n", "\n", "# extract weights and biases using tree_flatten\n", "params = get_params(opt_state)\n", "weights, biases = tree_flatten(params)[0]\n", "\n", "# print the weights (biases are not so interesting)\n", "plt.figure(figsize=(15, 7)) # figure size \n", "\n", "scale = np.abs(weights).max() # define overall scale\n", "\n", "for i in range(10): # loop over the number of weights\n", " \n", " plot = plt.subplot(2, 5, i + 1)\n", " plot.imshow(weights[:,i].reshape(28, 28), interpolation='nearest', cmap=plt.cm.Greys, vmin=-scale, vmax=scale)\n", " plot.set_xticks(())\n", " plot.set_yticks(())\n", " plot.set_xlabel('Class %i' % i)\n", " \n", "plt.suptitle('classification weights vector $w_j$ for digit class $j$')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DNN model\n", "\n", "We now want to train a DNN. To do that, all we need to do is add more layers to the model in the `stax.serial()` function, and re-run the training loop. Build a fully-connected DNN model with layer sizes (784, 512, 256, 10), where 784 corresponds to the input layer, and 10 -- to the output layer; i.e. we have two hidden layers. Use a `Relu` nonlinearity after each layer, except for the output where you need the `LogSoftMax`. \n", "\n", "\n", "### CNN model\n", "\n", "The CNN model is basically the same as the DNN model, except is also uses convolutional layers. Convolutional layers know about the dimensionality of the input. In our case, the data is two-domensional, and thus needs to be reshaped accordingly. 
Suppose the data had color: in color images each pixel has three values in the range [0, 255], one for each of the red, green, and blue channels. Hence, the dataset is a four-dimensional array with shape `(N_points, N_Channels, Height, Width)`. For black and white images, we just set `N_Channels=1`.\n", "\n", "C.1 reshape the `train_images` and `test_images` to 4-dimensional arrays. We use the convention `dim_numbers='NCHW'`: (N data points, Channels, Height, Width). \n", "\n", "C.2 use `GeneralConv(dim_numbers, output_channels, filter_size, strides)` layers to add conv layers to the neural net, followed by `Relu` nonlinearities. Add two layers with:\n", "* `output_channels=16, filter_size=(4,4), strides=(4,4)`\n", "* `output_channels=32, filter_size=(3,3), strides=(1,1)`\n", "\n", "C.3. Next, we want to attach two dense layers. To be able to do that, we take the output of the last conv layer, which has the shape `(N, C, H, W)`, and flatten it to a 2-dimensional array of shape `(N, C*H*W)`. Then we can stack the dense layers:\n", "* `256` hidden neurons, followed by a `Relu` nonlinearity\n", "* `10` output neurons, corresponding to the 10 digit categories.\n", "\n", "C.4. Finally, we add the `LogSoftmax` layer. \n", "\n", "C.5. Play with the output of the CNN on the toy dataset to convince yourself you have implemented everything properly. \n", "\n", "\n", "The rest of the code we constructed above can be applied without further modification. Why is that so? Do you now appreciate the usefulness of Deep ML packages?" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "output shape of the model is (-1, 10).\n", "\n", "actual output shape is: (3, 10)\n", "log(softmax) values: [-2.264744 -2.4081538 -2.2366066 -2.302778 -2.4543526 -2.3551548\n", " -2.2510576 -2.1797056 -2.2681572 -2.336163 ]\n", "conservation of probability [0.9999999 1. 
1.0000001]\n" ] } ], "source": [ "### Convolutional Neural network\n", "\n", "# NB: in newer JAX releases this module lives in jax.example_libraries\n", "from jax.experimental.stax import GeneralConv, Flatten # neural network layers\n", "\n", "\n", "# cast data into 2D image format\n", "train_images = train_images.reshape(-1,1,28,28) # -1: number of data points, 1: input channels, (28,28) = (height, width) dimensions of image\n", "test_images = test_images.reshape(-1,1,28,28)\n", "\n", "# conv net convention\n", "dim_nums=('NCHW', 'OIHW', 'NCHW') # default for (input, filters, output)\n", "\n", "# define functions which initialize the parameters and evaluate the model\n", "initialize_params, predict = stax.serial(\n", "    ### convolutional NN (CNN)\n", "    GeneralConv(dim_nums, 16, (4,4), strides=(4,4) ), # 16 output channels, (4,4) filter\n", "    Relu,\n", "    GeneralConv(dim_nums, 32, (3,3), strides=(1,1) ), # 32 output channels, (3,3) filter\n", "    Relu,\n", "    Flatten, # flatten output\n", "    Dense(256), # 256 hidden neurons\n", "    Relu,\n", "    Dense(10), # 10 output neurons\n", "    LogSoftmax # NB: computes the log-probability\n", ")\n", "\n", "# initialize the model parameters\n", "output_shape, inital_params = initialize_params(rng, (-1, 1, 28, 28)) # conv layer, 1 input channel, 28x28 pixels in each image\n", "\n", "print('\\noutput shape of the model is {}.\\n'.format(output_shape))\n", "\n", "# check how network works on 3 examples\n", "predictions = predict(inital_params, test_images[0:3])\n", "\n", "# print shape of output\n", "print(\"actual output shape is:\", predictions.shape)\n", "\n", "# check if probability is conserved\n", "print('log(softmax) values:', predictions[0])\n", "print('conservation of probability', np.sum(jnp.exp(predictions), axis=1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Further questions\n", "\n", "* do you see an advantage of JAX when compared to other ML packages, such as TensorFlow or PyTorch? 
\n", "* compare the weight plots for the SoftMax regression with and without an L2 regularizer added to the cost function. You can use a regularization strength of `lmbda=0.001`; what happens to the score if you increase/decrease this number? What if you implement an L1 regularization instead?\n", "* try to look for the patterns of the imprinted digits in the output weights of the DNN and CNN layers. Explain your findings, and compare them to the SoftMax regression." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "RL_class", "language": "python", "name": "rl_class" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" }, "latex_metadata": { "affiliation": "Faculty of Physics, Sofia University, 5 James Bourchier Blvd., 1164 Sofia, Bulgaria", "author": "Marin Bukov", "title": "Reinforcement Learning Course: WiSe 2020/21" } }, "nbformat": 4, "nbformat_minor": 4 }