RNN系列
00 min
2024-9-28
2024-9-28
type
status
date
slug
summary
tags
category
icon
password

RNN系列

4.1 RNN

notion image
计算公式如下:

4.2 LSTM

 
 
LSTM,GRU为什么可以缓解梯度消失问题?
本文主要参考李宏毅老师的视频介绍RNN相关知识,主要包括两个部分: 分别介绍Navie RNN,LSTM,GRU的结构 对比这三者的优缺点 结构图: 计算公式: 依赖每一个时刻的隐状态产生当前的输出,具体计算方式根据自己任务来定。 为什么naive RNN用tanh激活函数而不是relu? 这部分引用来自苏建林的文章: 通过对梯度的分析,知道梯度消失和梯度爆炸主要取决于 ,如果激活函数是tanh时, 的范围无法确定,因此 可能大于1,也可能小于1,梯度爆炸/梯度消失的风险是存在的。有趣的是。如果 很大,那么 就会很小, 反而更小。事实上,可以严格证明:如果固定 ,那么 是作为 的函数是有界的。也就是说无论 等于什么,它都不超过一个固定的常数。 使用tanh激活,梯度 是有界的,虽然未必是1,但一个有界的量不超过1的概率总高于无界的量。因此梯度爆炸的风险更低。相比之下,如果使用relu激活的话,他的正半轴导数恒为1,此时 是无界的,梯度爆炸的风险更高。 使用tanh只是缓解梯度爆炸的风险,使用tanh依然有梯度爆炸的可能性。处理梯度爆炸最根本的方法是梯度裁剪。 结构图: 计算公式: LSTM需注意以下两点: 除了包含隐状态h,也引入了记忆细胞C。 参数量大约是naive rnn的四倍(三个门控和一个候选记忆细胞)。 结构图: 计算公式: GRU需注意的两点: 舍弃LSTM中的记忆细胞单元C,只包含隐状态h。 参数量相当于naive rnn的三倍。 (1)RNN的梯度消失问题导致不能"长期依赖" RNN中的梯度消失不是指损失对参数的总梯度消失了,而是RNN中对较远时间步的梯度消失了。RNN中反向传播使用的是back propagation through time(BPTT)方法,损失loss对参数W的梯度等于loss在各时间步对w求导之和。用公式表示就是: 上式中 计算较复杂,根据复合函数求导法则连续求导得。 是当前隐状态对上一隐状态求偏导。 假设某一时间步j距离t时间步相差了(t-j)时刻。那么 如果t-j很大,也就是j距离t时间步很远, 时,会产生梯度消失问题。而当t-j很小时,也就是j是t的短期依赖,则不存在梯度消失/梯度爆炸的问题。一般会使用梯度裁剪解决梯度爆炸问题。所以主要分析梯度消失问题。 loss对时间步j的梯度值反映了时间步j对最终输出 的影响程度。j对最终输出 的影响程度越大,loss对时间步j的梯度值也就越大。如果loss对时间步j的梯度值趋于0,说明j对最终输出 没影响,就是常说的长期遗忘问题了。 综上:距离时间步t较远的j的梯度会消失,j对最终输出 没影响。就是RNN中不能长期依赖问题。 (2)LSTM如何解决梯度消失 LSTM设计的初衷就是让当前记忆单元对上一记忆单元的偏导为常数。如在1997年最初版本的LSTM,记忆细胞更新公式为: 后来为了避免记忆细胞无限增长,引入了"遗忘门"。更新公式为: 上文说过,梯度消失的主要原因在于递归导数 ,将更新公式改为加法公式,其导数具有更好的性质。 LSTM引入了记忆细胞,递归导数是 ,其中 都是 的函数。求偏导时,这四项都需要求。 对上式继续求解得到: 这个递归梯度公式与原本RNN的递归梯度有很大不同,naive RNN中 。在所有时间步t中, 的值要么始终大于1,要么始终在[0,1]范围内。这是导致梯度爆炸/梯度消失的罪魁祸首。但在LSTM中 在不同时间步可以采用大于1的值,也可以在下个时间步使用小于[0,1]区间的值。因此,时间步t扩展到无限时,数学上不能保证递归梯度收敛于0或无穷大。如果梯度开始收敛于0,那么可以设置 (以及其他门控)取值较高。来使得 接近于1,从而防止梯度消失太快。当前时刻梯度值中包含的 都是网络通过学习设置的。因此,这种门控的方式,让网络学会如何设置门控数值,来决定何时让梯度消失,何时保持梯度。 GRU的参数量少,减少过拟合的风险 LSTM的参数量是Navie RNN的4倍(看公式),参数量过多就会存在过拟合的风险,GRU只使用两个门控开关,达到了和LSTM接近的结果。其参数量是Navie RNN的三倍
LSTM,GRU为什么可以缓解梯度消失问题?
notion image
计算公式:
注意点:
1.除了包含隐状态h,也引入了记忆细胞C。
2.LSTM的参数量为RNN的四倍(三个门+一个记忆单元)。

4.3 GRU

 
notion image
计算公式如下:
注意点:
1.舍弃LSTM中的记忆细胞单元C,只包含隐状态h。
2.参数量相当于naive rnn的三倍。
 
上一篇
CNN系列
下一篇
Attention系列