Weight decay, or L2 regularization, is a regularization method that adds an extra term, proportional to the size of the weights, to the loss function. The intuition behind weight decay is the following: because the added term grows with the weights, the gradient of the loss function becomes larger, so the parameters descend more quickly, at a speed proportional to the size of the weights. This pushes the parameters toward smaller values.
Formally, we add an extra term, the sum of all squared weights times a weight decay factor, to the loss function. We write the definition below, where $w_i$ are the weights and $\lambda$ is the weight decay factor:

$$L'(w) = L(w) + \lambda \sum_i w_i^2$$
When the weights are large, this term grows, leading to larger gradients and therefore quicker descent.
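To see this concretely, here is a minimal PyTorch sketch (the weight values and the factor lam below are illustrative choices, not taken from the example later in this section) showing that the gradient of the added term is proportional to the weights:

import torch

lam = 0.1  # weight decay factor (illustrative value)
for scale in (1.0, 10.0):
    # Larger weights -> proportionally larger penalty gradient (2 * lam * w)
    w = (torch.tensor([1.0, -2.0, 3.0]) * scale).requires_grad_(True)
    penalty = lam * (w ** 2).sum()
    penalty.backward()
    print(scale, w.grad)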
The gradient of the added term with respect to $w$ is $2\lambda w$, so the original SGD update becomes:

$$w \leftarrow w - \eta \left( \nabla L(w) + 2\lambda w \right)$$
Note
In practice, we ignore the coefficient 2 in front of the weight decay term, since it can be absorbed into the factor $\lambda$.
So the update rule becomes

$$w \leftarrow w - \eta \left( \nabla L(w) + \lambda w \right)$$

where $\eta$ is the learning rate.
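As a quick sanity check, the sketch below applies this update rule by hand to the loss $L(w) = w^2$, using the same values as the PyTorch example that follows (initial weight 1.0, learning rate 0.1, weight decay 0.1):

# One SGD step with L2 weight decay for L(w) = w^2, computed by hand.
# Values mirror the PyTorch example below.
w, lr, wd = 1.0, 0.1, 0.1
grad = 2 * w          # dL/dw for L(w) = w^2
grad += wd * w        # weight decay term lambda * w (coefficient 2 omitted)
w = w - lr * grad
print(w)              # 0.79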
In the following example, we illustrate how weight decay works for the function $L(w) = w^2$, showing side by side how SGD steps with and without a weight decay factor using two optimizers. As we can see, the resulting weight with the weight decay factor ends up smaller than the weight without weight decay. The difference is equal to $\eta \lambda w$. That is to say, in PyTorch, weight decay adds $\lambda w$ to the gradient in standard SGD (the coefficient 2 is omitted).
import torch
w = torch.tensor([1.0], requires_grad=True)
opt_no_wd = torch.optim.SGD([w], lr=0.1, weight_decay=0.0)
loss = w * w  # L(w) = w^2
print("Initial weight:", w.item())
# Step without weight decay
opt_no_wd.zero_grad()
loss.backward()
opt_no_wd.step()
print("After step without weight decay:", w.item())
# Reset weight
w = torch.tensor([1.0], requires_grad=True)
opt_wd = torch.optim.SGD([w], lr=0.1, weight_decay=0.1)
# Step with weight decay
opt_wd.zero_grad()
loss = w * w
loss.backward()
opt_wd.step()
print("After step with weight decay:", w.item())
Initial weight: 1.0
After step without weight decay: 0.800000011920929
After step with weight decay: 0.7900000214576721
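Plugging the numbers into the update rule above confirms the printed values: without weight decay the step is $1.0 - 0.1 \cdot 2 \cdot 1.0 = 0.8$, and with weight decay the gradient gains an extra $\lambda w = 0.1$, giving $1.0 - 0.1 \cdot 2.1 = 0.79$. The difference of $0.01$ is exactly $\eta \lambda w = 0.1 \cdot 0.1 \cdot 1.0$, as claimed.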