首先分析一下adam的思想:

  • 获取梯度/导数/一阶动量的滑动平均——m1/滑动平均后的梯度g
  • 获取二阶导数/二阶动量的滑动平均——m2

其中,β1和β2就是控制一阶动量和二阶动量滑动平均的参数。通过滑动平均,可以使得m1和m2都变成前面所有时间步结果的一个调和平均
之后:

  • 根据β1和β2,以及当前的时间步t,计算一个系数,调整学习率,得到L_t
  • 梯度下降,公式变成了: 参数 = 参数 - L_t * m1 / (sqrt(m2) + eps)

最后梯度下降的公式,实际上,看着花里胡哨的,但最终,m1之外的其他部分都可以看成是对学习率的调整,本质上只是让学习率更好而已。

而在原始论文及其最初的衍生实现中,还将L2正则加入到了这个优化器中,具体表现为,在每一步t开始计算优化前,将当前时间步t梯度g_t变成:

1
g_t = g_t + weight_decay * L2项

然后后面就使用g_t进行m1、m2等的计算。
但是,这样的问题是,L2正则直接加在梯度g_t上面,后面的一阶动量和二阶动量都使用了g_t,最后又有 **m1 / (sqrt(m2) + eps)**这个东西,L2正则项被大大削减了。因而,L2正则就不生效了。

后来,这个问题被发现,adamW应运而生,同时adam的代码也被更改了,不再预先把L2加到梯度上。

adamW相对于adam的不同就是,其真正实现了weight decay/L2 Norm。
adam和adamW的大致算法步骤与公式参见下面这个博客:
【深度学习基础】第十九课:Adam优化算法
从公式中可以看出,只是在最后计算梯度更新的时候加了权重衰减。
实际实现中,也很简单,adam和adamW的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class my_opt():
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=1e-2):
self.params = params
self.lr = lr
self.b1 = betas[0]
self.b2 = betas[1]
self.eps = eps
self.wd = weight_decay
self.m = 0
self.v = 0
self.b1t = 1.0
self.b2t = 1.0

# 模拟AdamW的step
def stepW(self):
for name, param in self.params.named_parameters():
if param.grad is None:
continue
g = param.grad
self.m = self.b1 * self.m + (1 - self.b1) * g
self.v = self.b2 * self.v + (1 - self.b2) * g * g
self.b1t *= self.b1
self.b2t *= self.b2
m = self.m / (1 - self.b1t)
v = self.v / (1 - self.b2t)
n = 1.0
param.data -= n * (self.lr * (m / (v.sqrt() + self.eps) + self.wd * param.data))

# 模拟Adam的step
def step(self):
for name, param in self.params.named_parameters():
if param.grad is None:
continue
g = param.grad
self.m = self.b1 * self.m + (1 - self.b1) * g
self.v = self.b2 * self.v + (1 - self.b2) * g * g
self.b1t *= self.b1
self.b2t *= self.b2
m = self.m / (1 - self.b1t)
v = self.v / (1 - self.b2t)
n = 1.0
param.data -= n * (self.lr * m / (v.sqrt() + self.eps))

可以看到,代码里面adamW确实也只是在最后加了衰减。
那么,为什么加了这个参数本身的衰减,就是L2正则生效了呢?L2正则不是加在损失函数上面的吗?
这里,首先理解L2正则是什么,怎么加上去
那么,实际上,L2正则加在损失函数上,在反向传播的时候,经过公式推导,就是直接减去参数本身的形式。

L2正则反向传播公式推导或许可以参见:
L2 Normalization(L2归一化)反向传播推导