%matplotlib inline
import ipywidgets as widgets
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML, display
# How many optimizer iterations are computed (and the Play widget's last frame).
num_steps = 32

####################################################
# Control panel: shared layouts and styling
####################################################
# Layout shared by every row caption: a 15em-wide flex box whose content is
# pushed to the right edge.
style_label = widgets.Layout(
    display='flex',
    justify_content='flex-end',
    width='15em',
)
# Layout shared by every show/hide checkbox: a fixed 13em width.
style_checkbox = widgets.Layout(width='13em')
# Colors the text of each optimizer checkbox (selected through its CSS class)
# to match the line color that optimizer is drawn with:
# regular -> blue, momentum -> red, Nesterov -> black.
custom_css = widgets.HTML("""<style>
.ctrl_regular > label > span {color:blue;}
.ctrl_momentum > label > span {color:red;}
.ctrl_nesterov > label > span {color:black;}
</style>""")
# Row captions; each one uses the shared right-justified style_label layout.
def _caption(text):
    """Return a Label showing *text* with the shared caption layout."""
    return widgets.Label(text, layout=style_label)


label_show = _caption('Show/Hide:')
label_func_select = _caption('Select function:')
label_b = _caption('$b : f(x,y)= x^2+b y^2$')
label_learning_rate = _caption('Learning rate')
label_momentum_coeff = _caption('Momentum coefficient')
label_animate = _caption('Click play to animate')
label_nesterov_coeff = _caption('Nesterov coefficient')


# Show/hide toggles, one per optimization algorithm. The CSS class attached to
# each checkbox is what custom_css targets to color its text.
def _algorithm_toggle(description, css_class):
    """Return a checked, non-indented Checkbox tagged with *css_class*."""
    toggle = widgets.Checkbox(value=True, description=description,
                              layout=style_checkbox, indent=False)
    toggle.add_class(css_class)
    return toggle


ctrl_regular = _algorithm_toggle('Regular Grad. Descent', 'ctrl_regular')
ctrl_momentum = _algorithm_toggle('Momentum Grad. Descent', 'ctrl_momentum')
ctrl_nesterov = _algorithm_toggle('Nesterov Accel.', 'ctrl_nesterov')

# Objective selector; each option's value (1, 2 or 3) is the id that the
# function/gradient dispatch compares against.
ctrl_func_select = widgets.RadioButtons(
    options=[
        ("f(x,y)=x^2+b*y^2", 1),
        ("f(x,y)=x^6+b*y^4", 2),
        ("f(x,y)=x^2+b*(x^2-y)^2", 3),
    ],
    description='',
)

# Slider for the parameter b of the selected function.
ctrl_b = widgets.FloatSlider(value=0.15, min=0.1, max=1.0, step=0.05,
                             description='')
# Sliders for the optimizer hyper-parameters.
ctrl_learning_rate = widgets.FloatSlider(value=0.9, min=0.1, max=2.0,
                                         step=0.05, description='')
ctrl_momentum_coeff = widgets.FloatSlider(value=0.4, min=0.1, max=2.0,
                                          step=0.05, description='')
# NOTE(review): with min=0.01 and step=0.02 the preset values 0.1 and 0.7 are
# not on the step grid; confirm the frontend does not snap them.
ctrl_nesterov_coeff = widgets.FloatSlider(value=0.1, min=0.01, max=1.0,
                                          step=0.02, description='')
# Animation driver: steps the frame index from 0 to num_steps every 300 ms.
ctrl_animate = widgets.Play(value=0, min=0, max=num_steps, step=1,
                            interval=300, disabled=False)
# Bordered 16cm x 16cm area the figure is rendered into.
output = widgets.Output(
    layout=widgets.Layout(border='1px solid blue', width='16cm', height='16cm'))

# One horizontal row per caption/control group, top to bottom.
_panel_rows = [
    (label_show, ctrl_regular, ctrl_momentum, ctrl_nesterov),
    (label_func_select, ctrl_func_select),
    (label_b, ctrl_b),
    (label_learning_rate, ctrl_learning_rate),
    (label_momentum_coeff, ctrl_momentum_coeff),
    (label_nesterov_coeff, ctrl_nesterov_coeff),
    (label_animate, ctrl_animate),
]
# The full panel: CSS first, then the control rows, then the plot output.
# NOTE(review): `all` shadows the builtin all(); the name is kept because the
# panel is displayed through it elsewhere in this file.
all = widgets.VBox(
    [custom_css]
    + [widgets.HBox(list(row)) for row in _panel_rows]
    + [output]
)
# Turn matplotlib interactive mode off; presumably so the figure is only shown
# through the explicit display() calls in the drawing code.
plt.ioff()
####################################################
# Define function f, the gradient, and the calculation
# of data points of various optimization algorithms.
####################################################
def f1(x, y, b):
    """Evaluate f1(x, y) = x**2 + b * y**2.

    Works element-wise when x and y are NumPy arrays (e.g. a meshgrid).
    """
    x_part = x**2
    y_part = b * (y**2)
    return x_part + y_part
def grad_f1(pt, b):
    """Return the gradient (2x, 2b*y) of f1 at pt = (x, y) as a NumPy array."""
    x, y = pt[0], pt[1]
    return np.array([2. * x, 2. * b * y])
def f2(x, y, b):
    """Evaluate f2(x, y) = x**6 + b * y**4 (element-wise for NumPy arrays)."""
    quartic_part = b * (y**4)
    return x**6 + quartic_part
def grad_f2(pt, b):
    """Return the gradient (6x**5, 4b*y**3) of f2 at pt = (x, y) as a NumPy array."""
    x, y = pt[0], pt[1]
    d_dx = 6. * (x**5)
    d_dy = 4. * b * (y**3)
    return np.array([d_dx, d_dy])
def f3(x, y, b):
    """Evaluate f3(x, y) = 0.2 * (x**2 + b * (x**2 - y)**2).

    A parabola-shaped valley along y = x**2. The overall 0.2 factor is not shown
    in the UI captions for this function.
    """
    x_squared = x**2
    return 0.2 * (x_squared + b * (x_squared - y)**2)
def grad_f3(pt, b):
    """Return the gradient of f3 = 0.2 * (x**2 + b*(x**2 - y)**2) at pt = (x, y).

    d/dx = 0.2 * (2x + 4b*x*(x**2 - y)),  d/dy = 0.2 * 2b*(y - x**2).
    """
    x, y = pt[0], pt[1]
    valley_offset = x**2 - y
    d_dx = 2*x + 4*b*x*valley_offset
    d_dy = 2*b*(y - x**2)
    return 0.2 * np.array([d_dx, d_dy])
def f(x, y, b):
    """Evaluate the function currently chosen in the radio buttons at (x, y).

    Selection 1 uses f1 and selection 2 uses f2; any other selection uses f3.
    """
    chosen = {1: f1, 2: f2}.get(ctrl_func_select.value, f3)
    return chosen(x, y, b)
def grad_f(pt, b):
    """Return the gradient at pt of the function chosen in the radio buttons.

    Selection 1 uses grad_f1 and selection 2 uses grad_f2; any other selection
    uses grad_f3.
    """
    chosen = {1: grad_f1, 2: grad_f2}.get(ctrl_func_select.value, grad_f3)
    return chosen(pt, b)
def regular_grad_descend(data, i, b, learning_rate):
    """Fill data[i] with one plain gradient-descent step taken from data[i-1].

    data is modified in place; nothing is returned.
    """
    previous = data[i-1]
    step = learning_rate * grad_f(previous, b)
    data[i] = previous - step
def momentum_grad_descend(data, momentum_list, i, b, learning_rate, momentum_coeff):
    """Fill data[i] and momentum_list[i] with one momentum gradient-descent step.

    The new momentum is the gradient at data[i-1] plus momentum_coeff times
    the previous momentum. The point then moves against that momentum, scaled
    by learning_rate. Both arrays are modified in place.
    """
    accumulated = grad_f(data[i-1], b) + momentum_coeff * momentum_list[i-1]
    momentum_list[i] = accumulated
    # Read the stored row back so the update uses momentum_list's own dtype.
    data[i] = data[i-1] - learning_rate * momentum_list[i]
def nesterov_acceleration(data, i, b, learning_rate, momentum_coeff, nesterov_coeff):
    """Fill data[i] with one Nesterov-accelerated step (modifies data in place).

    For i > 1 the step adds momentum_coeff times the last displacement
    (data[i-1] - data[i-2]). It then subtracts learning_rate times the gradient
    taken at a look-ahead point, data[i-1] + nesterov_coeff * displacement.
    For i <= 1 there is no previous displacement yet, so a regular
    gradient-descent step is used.
    """
    if i <= 1:
        regular_grad_descend(data, i, b, learning_rate)
        return
    current = data[i-1]
    displacement = current - data[i-2]
    lookahead = current + nesterov_coeff * displacement
    data[i] = current + momentum_coeff * displacement \
        - learning_rate * grad_f(lookahead, b)
def draw_function_field():
    """Clear the output area and draw contour lines of the selected function.

    The function is sampled on a 200x200 grid over [-1.5, 1.5] x [-1.5, 1.5]
    using the current b slider value. 20 contour levels are drawn, evenly
    spaced from 0 to 0.5.
    """
    axis_samples = np.linspace(-1.5, 1.5, 200)
    grid_x, grid_y = np.meshgrid(axis_samples, axis_samples)
    heights = f(grid_x, grid_y, ctrl_b.value)
    contour_levels = np.linspace(0, 0.5, 20)
    with output:
        output.clear_output(wait=True)
        figure = plt.gcf()
        figure.set_size_inches(9, 9)
        # Wipe the figure; plt.cla() then leaves a fresh, empty current axes.
        plt.clf()
        plt.cla()
        plt.gca().contour(grid_x, grid_y, heights, contour_levels)
def _is_playing(play):
    """Return whether the Play widget *play* is currently animating.

    ipywidgets 7 keeps this state in the private trait ``_playing``; ipywidgets 8
    exposes it as the public trait ``playing``. Prefer the public name and fall
    back to the private one, so the demo works on either version.
    """
    return bool(getattr(play, 'playing', getattr(play, '_playing', False)))


def _compute_paths(start_point, b, learning_rate, momentum_coeff, nesterov_coeff):
    """Run all three optimizers for num_steps iterations from start_point.

    Returns (data_regular, data_momentum, data_nesterov): arrays with one row
    per step, where row k is the point reached after k steps.
    """
    data_regular = np.array([start_point] * num_steps)
    for i in range(1, num_steps):
        regular_grad_descend(data_regular, i, b, learning_rate)

    data_momentum = np.array([start_point] * num_steps)
    momentum_list = np.array([[0., 0.]] * num_steps)
    for i in range(1, num_steps):
        momentum_grad_descend(data_momentum, momentum_list, i, b,
                              learning_rate, momentum_coeff)

    data_nesterov = np.array([start_point] * num_steps)
    for i in range(1, num_steps):
        nesterov_acceleration(data_nesterov, i, b, learning_rate,
                              momentum_coeff, nesterov_coeff)
    return data_regular, data_momentum, data_nesterov


def draw_optimization(_):
    """Redraw the contour plot and the path of each enabled optimizer.

    Used as the observer callback for every control, so the argument (a change
    notification, or 0 for the initial draw) is ignored. While the animation
    is playing, only a short trailing window of each path is drawn. Otherwise
    the full path is drawn.
    """
    # First draw the contour of the function.
    draw_function_field()

    # Function 3's valley is better illustrated from a different start.
    start_point = [0.5, 1.] if ctrl_func_select.value == 3 else [1., 1.]

    # Aggressive settings can make a run diverge to inf/nan. These runs are
    # intentional in the demo, so silence NumPy's floating-point warnings.
    with np.errstate(over='ignore', invalid='ignore'):
        data_regular, data_momentum, data_nesterov = _compute_paths(
            start_point,
            ctrl_b.value,
            ctrl_learning_rate.value,
            ctrl_momentum_coeff.value,
            ctrl_nesterov_coeff.value,
        )

    animate_step = ctrl_animate.value
    if _is_playing(ctrl_animate):
        # While playing, draw only points animate_step-3 .. animate_step-1,
        # i.e. the last two steps.
        start = max(0, animate_step - 3)
        end = min(num_steps, animate_step)
    else:
        # Not playing: draw every step.
        start, end = 0, num_steps

    # (show/hide checkbox, path, legend label, color) for each algorithm.
    paths = (
        (ctrl_regular, data_regular, "Regular Gradient Descent", 'blue'),
        (ctrl_momentum, data_momentum, "Momentum Gradient Descent", 'red'),
        (ctrl_nesterov, data_nesterov, "Nesterov Acceleration", 'black'),
    )
    with output:
        ax = plt.gca()
        for toggle, data, label, color in paths:
            if toggle.value:
                ax.plot(data[start:end, 0],
                        data[start:end, 1],
                        label=label,
                        color=color,
                        marker='o')
        ax.axis('scaled')
        # legend() warns when no labelled artist exists (every algorithm hidden).
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc="upper left")
        display(ax.figure)
def change_func(_):
    """Apply the caption and slider presets for the newly selected function.

    Used as the observer callback of the function selector; the argument is
    ignored. Every assignment can fire the sliders' own observers, so the
    assignments are applied one at a time, in the listed order.
    """
    selected = ctrl_func_select.value
    # Functions 2 and 3 use the same optimizer presets.
    small_step_presets = (
        (ctrl_learning_rate, (('min', 0.01), ('max', 0.21), ('value', 0.03), ('step', 0.01))),
        (ctrl_momentum_coeff, (('min', 0.5), ('max', 1.0), ('value', 0.85))),
        (ctrl_nesterov_coeff, (('value', 0.7),)),
    )
    if selected == 1:
        caption = '$b : f(x,y)= x^2+b y^2$'
        presets = (
            (ctrl_b, (('min', 0.1), ('max', 1.0), ('value', 0.15))),
            (ctrl_learning_rate, (('min', 0.1), ('max', 2.0), ('value', 0.9), ('step', 0.05))),
            (ctrl_momentum_coeff, (('min', 0.1), ('max', 2.0), ('value', 0.4))),
            (ctrl_nesterov_coeff, (('value', 0.1),)),
        )
    elif selected == 2:
        caption = '$b : f(x,y)= x^6+b y^4$'
        presets = (
            (ctrl_b, (('min', 0.1), ('max', 1.0), ('value', 0.15))),
        ) + small_step_presets
    else:
        caption = '$b : f(x,y)= x^2+b (x^2-y)^2$'
        # b's new range [1, 20] lies above the others; max is raised first.
        presets = (
            (ctrl_b, (('max', 20.), ('min', 1.), ('value', 10.))),
        ) + small_step_presets

    label_b.value = caption
    for widget, assignments in presets:
        for trait_name, trait_value in assignments:
            setattr(widget, trait_name, trait_value)
# Make sure a change to any control triggers a redraw.
for _redraw_source in (ctrl_regular, ctrl_momentum, ctrl_nesterov,
                       ctrl_b, ctrl_learning_rate, ctrl_momentum_coeff,
                       ctrl_nesterov_coeff):
    _redraw_source.observe(draw_optimization, names='value')
# The Play widget must also redraw when playback starts or stops, so the full
# path reappears once the animation ends. That state trait is `_playing` in
# ipywidgets 7 and `playing` in ipywidgets 8, so listen for both names.
# NOTE(review): this relies on traitlets' observe() accepting a name the
# installed widget version lacks (it registers the handler without
# validating the name) — confirm on the versions you target.
ctrl_animate.observe(draw_optimization, names=['value', 'playing', '_playing'])
ctrl_func_select.observe(change_func, names='value')
# Show all controls and the output area.
display(all)
# Perform the initial drawing.
draw_optimization(0)