Like stochastic gradient descent (SGD), Adam is an optimization algorithm for locating the minimum of a function. It can be understood as a combination of SGD with momentum and RMSProp. Momentum updates the weights with a moving average of the gradients instead of the current gradient alone. RMSProp assigns a learning rate to each individual parameter based on a moving average of the squared gradients. Adam combines the two.
This entry follows the same pattern as the earlier SGD entry: we implement each optimization algorithm by fitting a quadratic function to noisy data.
We prep the data first:
import torch
from functools import partial
# mean square error
def mse(preds, acts):
    return ((preds - acts) ** 2).mean()
def quad(a, b, c, x):
    return a * x**2 + b * x + c
def mk_quad(a, b, c):
    return partial(quad, a, b, c)
# target model
f = mk_quad(2, 3, 4)
f(2)
# assume some data points
x = torch.linspace(-2, 2, 20)[:, None]
torch.manual_seed(42)
# Generate a tensor of random numbers with the same shape as f(x)
# torch.rand_like(f(x)) generates random numbers between 0 and 1
# with the same shape as f(x). We scale and shift it to the desired range.
random_numbers = torch.rand_like(f(x)) * 10 - 5
# dataset
y = f(x) + random_numbers
# loss function
def quad_mse(params):
    f = mk_quad(*params)
    return mse(f(x), y)
# initial params
params = torch.tensor([4, 5.0, 7.0])
params.requires_grad_()
loss = quad_mse(params)
loss
loss.backward()
params.grad
tensor([20.1243, 6.9424, 9.1506])
Momentum
Momentum lets the weight update build up inertia from a moving average of past gradients, which helps it ride over small variations in the gradient. We introduce a parameter beta to denote how much momentum to use. If beta is 0, the moving average equals the current gradient and the update reduces to plain SGD. The algorithm for momentum is the following:
weight.avg = beta * weight.avg + (1-beta) * weight.grad
new_weight = weight - lr * weight.avg
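To see what the moving average does on its own, here is a minimal pure-Python sketch that applies the update above to a made-up, sign-flipping gradient sequence. The helper name `moving_average` and the sample gradients are illustrative assumptions, not part of the original code.

```python
def moving_average(grads, beta):
    """Return the running average after each gradient, starting from 0."""
    avg = 0.0
    history = []
    for g in grads:
        # Same rule as weight.avg in the algorithm above
        avg = beta * avg + (1 - beta) * g
        history.append(avg)
    return history

# Hypothetical gradients that flip sign on every step
grads = [1.0, -1.0, 1.0, -1.0, 1.0, -1.0]

for beta in (0.0, 0.5, 0.9):
    smoothed = [round(v, 3) for v in moving_average(grads, beta)]
    print(f"beta={beta}: {smoothed}")
```

With beta = 0 the average simply tracks each gradient, while larger values of beta damp the oscillation. This is the sense in which momentum overcomes small variations.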
beta = 0.1
lr = 0.05
params = torch.tensor([4, 5.0, 7.0])
params.requires_grad_()
weight_avg = torch.zeros(params.shape)
for _ in range(10):
    loss = quad_mse(params)
    print("loss", loss.item())
    loss.backward()
    weight_avg = beta * weight_avg + (1 - beta) * params.grad.data
    params.data -= lr * weight_avg
    params.grad = None
loss 41.75225067138672
loss 22.198631286621094
loss 13.591499328613281
loss 10.369937896728516
loss 8.952920913696289
loss 8.165651321411133
loss 7.653504848480225
loss 7.295266628265381
loss 7.037405490875244
loss 6.849569797515869
RMSProp
RMSProp gives each parameter its own learning rate, scaled from a global learning rate. We set the scale from a moving average of the squared gradients. We square the gradients rather than averaging them directly because we want to capture the magnitude of the change regardless of its sign. We introduce alpha, which serves the same purpose as beta in momentum. The eps term is for numerical stability.
For RMSProp, we implement the following algorithm:
w.square_avg = alpha * w.square_avg + (1-alpha) * (w.grad ** 2)
new_w = w - lr * w.grad / math.sqrt(w.square_avg + eps)
Here, the effective learning rate is lr / math.sqrt(w.square_avg + eps), which is a tensor with the same shape as the parameters.
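As a rough illustration of the per-parameter learning rate, the sketch below plugs the initial gradient printed earlier into the first RMSProp step, assuming square_avg starts at zero as in the loop that follows. It uses plain Python floats rather than tensors, and the function name `first_step` is hypothetical.

```python
import math

lr, alpha, eps = 0.05, 0.99, 1e-8
grads = [20.1243, 6.9424, 9.1506]  # initial params.grad printed earlier

def first_step(g):
    # square_avg starts at 0, so after one update it is (1 - alpha) * g**2
    square_avg = (1 - alpha) * g**2
    effective_lr = lr / math.sqrt(square_avg + eps)
    # Return the per-parameter learning rate and the resulting step size
    return effective_lr, effective_lr * g

for g in grads:
    effective_lr, step = first_step(g)
    print(f"grad={g:8.4f}  effective_lr={effective_lr:.4f}  step={step:.4f}")
```

The effective learning rate is smaller for the parameter with the larger gradient, so on this first step all three parameters move by roughly lr / sqrt(1 - alpha) = 0.5.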
alpha = 0.99
eps = 1e-8
lr = 0.05
params_rp = torch.tensor([4, 5.0, 7.0])
params_rp.requires_grad_()
sqr_avg = torch.zeros(params_rp.shape)
for _ in range(10):
    loss_rp = quad_mse(params_rp)
    print("loss", loss_rp.item())
    loss_rp.backward()
    sqr_avg = alpha * sqr_avg + (1 - alpha) * (params_rp.grad.data**2)
    params_rp.data -= lr * params_rp.grad.data / torch.sqrt(sqr_avg + eps)
    params_rp.grad = None
loss 41.75225067138672
loss 25.9729061126709
loss 18.755748748779297
loss 14.587450981140137
loss 11.957747459411621
loss 10.221823692321777
loss 9.04482650756836
loss 8.232916831970215
loss 7.666159152984619
loss 7.267086982727051
Adam
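We combine momentum and RMSProp to get Adam. Unlike plain momentum, however, Adam uses a bias-corrected (unbiased) moving average of the gradients. The implementation below also applies decoupled weight decay (wd) after each update. For Adam, we implement the following algorithm: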
w.avg = beta1 * w.avg + (1-beta1) * w.grad
unbias_avg = w.avg / (1 - (beta1**(i+1)))
w.sqr_avg = beta2 * w.sqr_avg + (1-beta2) * (w.grad ** 2)
new_w = w - lr * unbias_avg / sqrt(w.sqr_avg + eps)
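The reason for the unbias_avg line is that the average starts at zero, so its early values are pulled toward zero. The following sketch, in plain Python with a made-up constant gradient, shows the raw and corrected averages side by side. The function name `ema_with_correction` is a hypothetical helper, not part of the original code.

```python
def ema_with_correction(grads, beta):
    """Return (raw average, bias-corrected average) after each step."""
    avg = 0.0
    rows = []
    for i, g in enumerate(grads):
        avg = beta * avg + (1 - beta) * g
        # Same correction as unbias_avg in the algorithm above
        unbiased = avg / (1 - beta ** (i + 1))
        rows.append((avg, unbiased))
    return rows

# Hypothetical constant gradient of 2.0 for five steps
rows = ema_with_correction([2.0] * 5, beta=0.9)
for step, (raw, unbiased) in enumerate(rows, start=1):
    print(f"step {step}: raw={raw:.4f}  unbiased={unbiased:.4f}")
```

The raw average only creeps toward 2.0 over the first steps, while the corrected value recovers the constant gradient from the first step onward.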
beta = 0.1  # beta1 in the algorithm above
alpha = 0.99  # beta2 in the algorithm above
eps = 1e-8
lr = 0.05
wd = 0.01 # decoupled weight decay
params_ad = torch.tensor([4, 5.0, 7.0])
params_ad.requires_grad_()
sqr_avg = torch.zeros(params_ad.shape)
weight_avg = torch.zeros(params_ad.shape)
for i in range(10):
    loss_ad = quad_mse(params_ad)
    print("loss", loss_ad.item())
    loss_ad.backward()
    weight_avg = beta * weight_avg + (1 - beta) * params_ad.grad.data
    unbiased_avg = weight_avg / (1 - (beta ** (i + 1)))
    sqr_avg = alpha * sqr_avg + (1 - alpha) * (params_ad.grad.data**2)
    params_ad.data -= lr * unbiased_avg / torch.sqrt(sqr_avg + eps)
    params_ad.data -= lr * params_ad.data * wd
    params_ad.grad = None
loss 41.75225067138672
loss 25.913122177124023
loss 18.47113609313965
loss 14.244725227355957
loss 11.626660346984863
loss 9.926828384399414
loss 8.792333602905273
loss 8.021905899047852
loss 7.492585182189941
loss 7.125898838043213
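One difference worth noting: the loop above corrects the bias of the first moment only, whereas Adam as described by Kingma and Ba also divides the squared-gradient average by 1 - beta2**(i+1). The sketch below compares the size of the very first step under both variants for a single scalar gradient. The function names are hypothetical, plain floats stand in for tensors, and weight decay is left out.

```python
import math

lr, beta2, eps = 0.05, 0.99, 1e-8

def first_step_uncorrected(g):
    # Matches the loop above on step 1, where the unbiased first moment equals g
    sqr_avg = (1 - beta2) * g**2
    return lr * g / math.sqrt(sqr_avg + eps)

def first_step_corrected(g):
    # Additionally bias-correct the squared-gradient average
    sqr_avg = (1 - beta2) * g**2
    unbiased_sqr = sqr_avg / (1 - beta2**1)
    return lr * g / math.sqrt(unbiased_sqr + eps)

g = 20.1243  # first component of the initial gradient
print("uncorrected:", round(first_step_uncorrected(g), 4))
print("corrected:  ", round(first_step_corrected(g), 4))
```

Without the second-moment correction the first step is about lr / sqrt(1 - beta2), ten times lr for beta2 = 0.99. With the correction, the first step is about lr.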