{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient Descent and its Generalizations\n", "\n", "\n", "## Learning Goal\n", "\n", "The goal of this notebook is to gain intuition for various gradient descent methods by visualizing and applying these methods to some simple two-dimensional surfaces. Methods studied include ordinary gradient descent, gradient descent with momentum, Nesterov accelerated gradient (NAG), RMSProp, and ADAM. This notebook follows Notebook 2 and Section IV from the [ML Review](http://physics.bu.edu/~pankajm/MLnotebooks.html) by Mehta et al.\n", "\n", "\n", "## Overview\n", "\n", "In this notebook, we will visualize what different gradient descent methods are doing using some simple surfaces. From the outset, we emphasize that performing gradient descent on these surfaces is different from performing gradient descent on a loss function in Machine Learning (ML). The reason is that in ML we want not only to find good minima, but to find minima that generalize well to new data. 
Despite this crucial difference, we can still build intuition about gradient descent methods by applying them to simple surfaces (for a useful blog post, see [here](http://ruder.io/optimizing-gradient-descent/)).\n", "\n", "## Surfaces\n", "\n", "We will consider three simple surfaces: \n", "\n", "* a quadratic minimum of the form \n", "\n", " $$z(x,y)=ax^2+by^2,$$ \n", "\n", "* a saddle point of the form \n", "\n", " $$z(x,y)=ax^2-by^2,$$ \n", "\n", "* and [Beale's Function](https://en.wikipedia.org/wiki/Test_functions_for_optimization):\n", "\n", " $$z(x,y) = (1.5-x+xy)^2+(2.25-x+xy^2)^2+(2.625-x+xy^3)^2.$$\n", "\n", "Additionally, you may explore\n", "\n", "* [Rosenbrock's Function](https://en.wikipedia.org/wiki/Test_functions_for_optimization), which has a global minimum at (x,y) = (1,1):\n", "\n", "$$z(x,y) = (1-x)^2 + 100(y-x^2)^2,$$\n", "\n", "* and [Himmelblau's Function](https://en.wikipedia.org/wiki/Test_functions_for_optimization), which has four degenerate minima and a local maximum between them:\n", "\n", "$$z(x,y) = (x^2+y-11)^2 + (x+y^2-7)^2.$$\n", "\n", "The last three are non-convex functions often used to test optimization algorithms. These surfaces can be plotted using the cells below. 
\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "scrolled": true }, "outputs": [], "source": [ "#This cell sets up basic plotting functions we will use to visualize the gradient descent routines.\n", "\n", "#Make plots interactive\n", "%matplotlib notebook\n", "\n", "#Make plots static\n", "#%matplotlib inline\n", "\n", "#Make 3D plots\n", "from mpl_toolkits.mplot3d import Axes3D\n", "import matplotlib.pyplot as plt\n", "from matplotlib import cm\n", "from IPython.display import HTML\n", "from matplotlib.colors import LogNorm\n", "\n", "#Import Numpy\n", "import numpy as np\n", "\n", "#Define function for plotting \n", "\n", "def plot_surface(x, y, z, azim=-60, elev=40, dist=10, cmap=\"jet\"):\n", "\n", " fig = plt.figure()\n", " ax = fig.add_subplot(111, projection='3d')\n", " plot_args = {'rstride': 1, 'cstride': 1, 'cmap':cmap,\n", " 'linewidth': 20, 'antialiased': True,\n", " 'vmin': -2, 'vmax': 2}\n", " ax.plot_surface(x, y, z, **plot_args)\n", " ax.view_init(azim=azim, elev=elev)\n", " ax.dist=dist\n", " ax.set_xlim(-1, 1)\n", " ax.set_ylim(-1, 1)\n", " ax.set_zlim(-2, 2)\n", " \n", " plt.xticks([-1, -0.5, 0, 0.5, 1], [\"-1\", \"-1/2\", \"0\", \"1/2\", \"1\"])\n", " plt.yticks([-1, -0.5, 0, 0.5, 1], [\"-1\", \"-1/2\", \"0\", \"1/2\", \"1\"])\n", " ax.set_zticks([-2, -1, 0, 1, 2])\n", " ax.set_zticklabels([\"-2\", \"-1\", \"0\", \"1\", \"2\"])\n", " \n", " ax.set_xlabel(\"x\", fontsize=18)\n", " ax.set_ylabel(\"y\", fontsize=18)\n", " ax.set_zlabel(\"z\", fontsize=18)\n", " return fig, ax;\n", "\n", "\n", "def overlay_trajectory_quiver(ax,obj_func,trajectory, color='k'):\n", "\n", " xs=trajectory[:,0]\n", " ys=trajectory[:,1]\n", " zs=obj_func(xs,ys)\n", " ax.quiver(xs[:-1], ys[:-1], zs[:-1], xs[1:]-xs[:-1], ys[1:]-ys[:-1],zs[1:]-zs[:-1],color=color,arrow_length_ratio=0.3)\n", " \n", " return ax;\n", "\n", "def overlay_trajectory(ax,obj_func,trajectory,label,color='k'):\n", " xs=trajectory[:,0]\n", " ys=trajectory[:,1]\n", " 
zs=obj_func(xs,ys)\n", " ax.plot(xs,ys,zs, color, label=label)\n", " \n", " return ax\n", "\n", "\n", "def overlay_trajectory_contour(ax,trajectory, label,color='k',lw=2, plot_marker=False):\n", " xs=trajectory[:,0]\n", " ys=trajectory[:,1]\n", " ax.plot(xs,ys, color, label=label,lw=lw)\n", " if plot_marker:\n", "  ax.plot(xs[-1],ys[-1], color+'>', markersize=10)\n", " return ax" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": false }, "outputs": [], "source": [ "#DEFINE SURFACES WE WILL WORK WITH\n", "\n", "#Define monkey saddle and gradient\n", "def monkey_saddle(x,y):\n", " return x**3 - 3*x*y**2\n", "\n", "def grad_monkey_saddle(params):\n", " x=params[0]\n", " y=params[1]\n", " grad_x= 3*x**2-3*y**2\n", " grad_y= -6*x*y\n", " return [grad_x,grad_y]\n", "\n", "#Define saddle surface\n", "\n", "def saddle_surface(x,y,a=1,b=1):\n", " return a*x**2-b*y**2\n", "\n", "def grad_saddle_surface(params,a=1,b=1):\n", " #Gradient of a*x**2-b*y**2 is (2*a*x, -2*b*y)\n", " x=params[0]\n", " y=params[1]\n", " grad_x= 2*a*x\n", " grad_y= -2*b*y\n", " return [grad_x,grad_y]\n", "\n", "\n", "# Define minima_surface (quadratic minimum, shifted down by 1)\n", "\n", "def minima_surface(x,y,a=1,b=1):\n", " return a*x**2+b*y**2-1\n", "\n", "def grad_minima_surface(params,a=1,b=1):\n", " x=params[0]\n", " y=params[1]\n", " grad_x= 2*a*x\n", " grad_y= 2*b*y\n", " return [grad_x,grad_y]\n", "\n", "\n", "def beales_function(x,y):\n", " return (1.5-x+x*y)**2 + (2.25-x+x*y**2)**2 + (2.625-x+x*y**3)**2\n", " \n", "\n", "def grad_beales_function(params):\n", " x=params[0]\n", " y=params[1]\n", " grad_x=2*(1.5-x+x*y)*(-1+y)+2*(2.25-x+x*y**2)*(-1+y**2)+2*(2.625-x+x*y**3)*(-1+y**3)\n", " grad_y=2*(1.5-x+x*y)*x+4*(2.25-x+x*y**2)*x*y+6*(2.625-x+x*y**3)*x*y**2\n", " return [grad_x,grad_y]\n", "\n", "def contour_beales_function():\n", " #plot beales function\n", " x, y = np.meshgrid(np.arange(-4.5, 4.5, 0.1), np.arange(-4.5, 4.5, 0.1))\n", " fig, ax = plt.subplots(figsize=(10, 6))\n", " z=beales_function(x,y)\n", " cax = ax.contour(x, y, z, 
levels=np.logspace(0, 5, 35), norm=LogNorm(), cmap=\"RdYlBu_r\")\n", " ax.plot(3,0.5, 'r*', markersize=18)\n", "\n", " ax.set_xlabel('$x$')\n", " ax.set_ylabel('$y$')\n", "\n", " ax.set_xlim((-4.5, 4.5))\n", " ax.set_ylim((-4.5, 4.5))\n", " \n", " return fig,ax\n", " \n", " \n", " \n", " \n", "#Make plots of surfaces\n", "plt.close() # closes previous plots\n", "x, y = np.mgrid[-1:1:31j, -1:1:31j]" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "/* Put everything inside the global mpl namespace */\n", "window.mpl = {};\n", "\n", "\n", "mpl.get_websocket_type = function() {\n", " if (typeof(WebSocket) !== 'undefined') {\n", " return WebSocket;\n", " } else if (typeof(MozWebSocket) !== 'undefined') {\n", " return MozWebSocket;\n", " } else {\n", " alert('Your browser does not have WebSocket support.' +\n", " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", " 'Firefox 4 and 5 are also supported but you ' +\n", " 'have to enable WebSockets in about:config.');\n", " };\n", "}\n", "\n", "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", " this.id = figure_id;\n", "\n", " this.ws = websocket;\n", "\n", " this.supports_binary = (this.ws.binaryType != undefined);\n", "\n", " if (!this.supports_binary) {\n", " var warnings = document.getElementById(\"mpl-warnings\");\n", " if (warnings) {\n", " warnings.style.display = 'block';\n", " warnings.textContent = (\n", " \"This browser does not support binary websocket messages. \" +\n", " \"Performance may be slow.\");\n", " }\n", " }\n", "\n", " this.imageObj = new Image();\n", "\n", " this.context = undefined;\n", " this.message = undefined;\n", " this.canvas = undefined;\n", " this.rubberband_canvas = undefined;\n", " this.rubberband_context = undefined;\n", " this.format_dropdown = undefined;\n", "\n", " this.image_mode = 'full';\n", "\n", " this.root = $('
');\n", " this._root_extra_style(this.root)\n", " this.root.attr('style', 'display: inline-block');\n", "\n", " $(parent_element).append(this.root);\n", "\n", " this._init_header(this);\n", " this._init_canvas(this);\n", " this._init_toolbar(this);\n", "\n", " var fig = this;\n", "\n", " this.waiting = false;\n", "\n", " this.ws.onopen = function () {\n", " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", " fig.send_message(\"send_image_mode\", {});\n", " if (mpl.ratio != 1) {\n", " fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n", " }\n", " fig.send_message(\"refresh\", {});\n", " }\n", "\n", " this.imageObj.onload = function() {\n", " if (fig.image_mode == 'full') {\n", " // Full images could contain transparency (where diff images\n", " // almost always do), so we need to clear the canvas so that\n", " // there is no ghosting.\n", " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", " }\n", " fig.context.drawImage(fig.imageObj, 0, 0);\n", " };\n", "\n", " this.imageObj.onunload = function() {\n", " fig.ws.close();\n", " }\n", "\n", " this.ws.onmessage = this._make_on_message_function(this);\n", "\n", " this.ondownload = ondownload;\n", "}\n", "\n", "mpl.figure.prototype._init_header = function() {\n", " var titlebar = $(\n", " '');\n", " var titletext = $(\n", " '');\n", " titlebar.append(titletext)\n", " this.root.append(titlebar);\n", " this.header = titletext[0];\n", "}\n", "\n", "\n", "\n", "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "\n", "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", "\n", "}\n", "\n", "mpl.figure.prototype._init_canvas = function() {\n", " var fig = this;\n", "\n", " var canvas_div = $('');\n", "\n", " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", "\n", " function canvas_keyboard_event(event) {\n", " return fig.key_event(event, event['data']);\n", " }\n", "\n", " 
canvas_div.keydown('key_press', canvas_keyboard_event);\n", " canvas_div.keyup('key_release', canvas_keyboard_event);\n", " this.canvas_div = canvas_div\n", " this._canvas_extra_style(canvas_div)\n", " this.root.append(canvas_div);\n", "\n", " var canvas = $('');\n", " canvas.addClass('mpl-canvas');\n", " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", "\n", " this.canvas = canvas[0];\n", " this.context = canvas[0].getContext(\"2d\");\n", "\n", " var backingStore = this.context.backingStorePixelRatio ||\n", "\tthis.context.webkitBackingStorePixelRatio ||\n", "\tthis.context.mozBackingStorePixelRatio ||\n", "\tthis.context.msBackingStorePixelRatio ||\n", "\tthis.context.oBackingStorePixelRatio ||\n", "\tthis.context.backingStorePixelRatio || 1;\n", "\n", " mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n", "\n", " var rubberband = $('');\n", " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", "\n", " var pass_mouse_events = true;\n", "\n", " canvas_div.resizable({\n", " start: function(event, ui) {\n", " pass_mouse_events = false;\n", " },\n", " resize: function(event, ui) {\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " stop: function(event, ui) {\n", " pass_mouse_events = true;\n", " fig.request_resize(ui.size.width, ui.size.height);\n", " },\n", " });\n", "\n", " function mouse_event_fn(event) {\n", " if (pass_mouse_events)\n", " return fig.mouse_event(event, event['data']);\n", " }\n", "\n", " rubberband.mousedown('button_press', mouse_event_fn);\n", " rubberband.mouseup('button_release', mouse_event_fn);\n", " // Throttle sequential mouse events to 1 every 20ms.\n", " rubberband.mousemove('motion_notify', mouse_event_fn);\n", "\n", " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", " rubberband.mouseleave('figure_leave', mouse_event_fn);\n", "\n", " canvas_div.on(\"wheel\", function (event) {\n", " event = event.originalEvent;\n", " event['data'] = 'scroll'\n", 
" if (event.deltaY < 0) {\n", " event.step = 1;\n", " } else {\n", " event.step = -1;\n", " }\n", " mouse_event_fn(event);\n", " });\n", "\n", " canvas_div.append(canvas);\n", " canvas_div.append(rubberband);\n", "\n", " this.rubberband = rubberband;\n", " this.rubberband_canvas = rubberband[0];\n", " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", " this.rubberband_context.strokeStyle = \"#000000\";\n", "\n", " this._resize_canvas = function(width, height) {\n", " // Keep the size of the canvas, canvas container, and rubber band\n", " // canvas in synch.\n", " canvas_div.css('width', width)\n", " canvas_div.css('height', height)\n", "\n", " canvas.attr('width', width * mpl.ratio);\n", " canvas.attr('height', height * mpl.ratio);\n", " canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n", "\n", " rubberband.attr('width', width);\n", " rubberband.attr('height', height);\n", " }\n", "\n", " // Set the figure to an initial 600x600px, this will subsequently be updated\n", " // upon first draw.\n", " this._resize_canvas(600, 600);\n", "\n", " // Disable right mouse context menu.\n", " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", " return false;\n", " });\n", "\n", " function set_focus () {\n", " canvas.focus();\n", " canvas_div.focus();\n", " }\n", "\n", " window.setTimeout(set_focus, 100);\n", "}\n", "\n", "mpl.figure.prototype._init_toolbar = function() {\n", " var fig = this;\n", "\n", " var nav_element = $('')\n", " nav_element.attr('style', 'width: 100%');\n", " this.root.append(nav_element);\n", "\n", " // Define a callback function for later on.\n", " function toolbar_event(event) {\n", " return fig.toolbar_button_onclick(event['data']);\n", " }\n", " function toolbar_mouse_event(event) {\n", " return fig.toolbar_button_onmouseover(event['data']);\n", " }\n", "\n", " for(var toolbar_ind in mpl.toolbar_items) {\n", " var name = mpl.toolbar_items[toolbar_ind][0];\n", " var tooltip = 
mpl.toolbar_items[toolbar_ind][1];\n", " var image = mpl.toolbar_items[toolbar_ind][2];\n", " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", "\n", " if (!name) {\n", " // put a spacer in here.\n", " continue;\n", " }\n", " var button = $('');\n", " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n", " 'ui-button-icon-only');\n", " button.attr('role', 'button');\n", " button.attr('aria-disabled', 'false');\n", " button.click(method_name, toolbar_event);\n", " button.mouseover(tooltip, toolbar_mouse_event);\n", "\n", " var icon_img = $('');\n", " icon_img.addClass('ui-button-icon-primary ui-icon');\n", " icon_img.addClass(image);\n", " icon_img.addClass('ui-corner-all');\n", "\n", " var tooltip_span = $('');\n", " tooltip_span.addClass('ui-button-text');\n", " tooltip_span.html(tooltip);\n", "\n", " button.append(icon_img);\n", " button.append(tooltip_span);\n", "\n", " nav_element.append(button);\n", " }\n", "\n", " var fmt_picker_span = $('');\n", "\n", " var fmt_picker = $('');\n", " fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n", " fmt_picker_span.append(fmt_picker);\n", " nav_element.append(fmt_picker_span);\n", " this.format_dropdown = fmt_picker[0];\n", "\n", " for (var ind in mpl.extensions) {\n", " var fmt = mpl.extensions[ind];\n", " var option = $(\n", " '', {selected: fmt === mpl.default_extension}).html(fmt);\n", " fmt_picker.append(option)\n", " }\n", "\n", " // Add hover states to the ui-buttons\n", " $( \".ui-button\" ).hover(\n", " function() { $(this).addClass(\"ui-state-hover\");},\n", " function() { $(this).removeClass(\"ui-state-hover\");}\n", " );\n", "\n", " var status_bar = $('