首先分析一下adam的思想:
- 获取梯度/导数/一阶动量的滑动平均——m1/滑动平均后的梯度g
- 获取二阶导数/二阶动量的滑动平均——m2
其中,β1和β2就是控制一阶动量和二阶动量滑动平均的参数。通过滑动平均,可以使得m1和m2都变成前面所有时间步结果的一个调和平均
之后:
- 根据β1和β2,以及当前的时间步t,计算一个系数,调整学习率,得到L_t
- 梯度下降,公式变成了: 参数 = 参数 - L_t * m1 / (sqrt(m2) + eps)
最后梯度下降的公式,实际上,看着花里胡哨的,但最终,m1之外的其他部分都可以看成是对学习率的调整,本质上只是让学习率更好而已。
而在原始论文及其最初的衍生实现中,还将L2正则加入到了这个优化器中,具体表现为,在每一步t开始计算优化前,将当前时间步t梯度g_t变成:
1
| g_t = g_t + weight_decay * L2项
|
然后后面就使用g_t进行m1、m2等的计算。
但是,这样的问题是,L2正则直接加在梯度g_t上面,后面的一阶动量和二阶动量都使用了g_t,最后又有 **m1 / (sqrt(m2) + eps)**这个东西,L2正则项被大大削减了。因而,L2正则就不生效了。
后来,这个问题被发现,adamW应运而生,同时adam的代码也被更改了,不再预先把L2加到梯度上。
adamW相对于adam的不同就是,其真正实现了weight decay/L2 Norm。
adam和adamW的大致算法步骤与公式参见下面这个博客:
【深度学习基础】第十九课:Adam优化算法
从公式中可以看出,只是在最后计算梯度更新的时候加了权重衰减。
实际实现中,也很简单,adam和adamW的代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
| class my_opt(): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2): self.params = params self.lr = lr self.b1 = betas[0] self.b2 = betas[1] self.eps = eps self.wd = weight_decay self.m = 0 self.v = 0 self.b1t = 1.0 self.b2t = 1.0
def stepW(self): for name, param in self.params.named_parameters(): if param.grad is None: continue g = param.grad self.m = self.b1 * self.m + (1 - self.b1) * g self.v = self.b2 * self.v + (1 - self.b2) * g * g self.b1t *= self.b1 self.b2t *= self.b2 m = self.m / (1 - self.b1t) v = self.v / (1 - self.b2t) n = 1.0 param.data -= n * (self.lr * (m / (v.sqrt() + self.eps) + self.wd * param.data))
def step(self): for name, param in self.params.named_parameters(): if param.grad is None: continue g = param.grad self.m = self.b1 * self.m + (1 - self.b1) * g self.v = self.b2 * self.v + (1 - self.b2) * g * g self.b1t *= self.b1 self.b2t *= self.b2 m = self.m / (1 - self.b1t) v = self.v / (1 - self.b2t) n = 1.0 param.data -= n * (self.lr * m / (v.sqrt() + self.eps))
|
可以看到,代码里面adamW确实也只是在最后加了衰减。
那么,为什么加了这个参数本身的衰减,就是L2正则生效了呢?L2正则不是加在损失函数上面的吗?
这里,首先理解L2正则是什么,怎么加上去。
那么,实际上,L2正则加在损失函数上,在反向传播的时候,经过公式推导,就是直接减去参数本身的形式。
L2正则反向传播公式推导或许可以参见:
L2 Normalization(L2归一化)反向传播推导