Deep Learning Systems Lecture 20 Transformers and Attention
这里回顾dlsys第二十讲,本讲内容是Transformer和Attention。
课程主页:
大纲
- 时间序列建模的两种方法;
- Self-attention和Transformer;
- 超越时间序列的变形金刚;
时间序列建模的两种方法
时间序列预测
让我们回顾一下之前课程中的基本时间序列预测任务:
更根本的是,时间序列预测任务可以写成如下预测任务:
其中$y_t$只能依赖于$x_{1:t}$。
这样做有多种方法,可能涉及也可能不涉及RNN的隐状态表示。
RNN“隐状态”方法
我们已经看到了时间序列的RNN方法:保持“隐状态”$h_t$,它总结了到$t$时刻为止的所有信息。
- 优点:潜在的“无限”历史,紧凑的表示;
- 缺点:历史和当前时间之间的长“计算路径”⟹消失/爆炸梯度,难以学习;
“直接预测”方法
相比之下,也可以直接预测输出$y_t= f_{\theta}(x_{1:t})$。(只需要一个可以预测不同大小输入的函数)。
- 优点:通常可以使用更短的计算路径从过去映射到当前状态;
- 缺点:没有紧凑的状态表示,实际中的历史有限;
用于直接预测的CNN
指定函数$f_\theta$最直接的方法之一:(fully)卷积网络,又名时间卷积网络 (TCN)。主要约束是卷积是单向:$z_t^{(i+1)}$只能取决于$z_{t-k:t}^{(i+1)}$。许多成功的应用:例如用于语音生成的WaveNet。
CNN在密集预测方面的挑战
尽管它们很简单,但CNN在时间序列预测方面有一个明显的缺点:每个卷积的感受野通常相对较小 ⟹
需要深度网络来实际整合过去的信息。
潜在的解决方案:
- 增加内核大小:同时增加网络参数;
- 池化层:不太适合密集预测,例如我们想要预测所有$y_{1:T}$;
- Dilated卷积:“跳过”一些过去的状态/输入;
Self-attention和Transformer
深度学习中的“Attention”
深度网络中的“Attention”通常指的是对单个状态进行加权然后组合的任何机制:
最初在RNN中使用,当人们想要以一种比“仅仅”查看最后一个状态更通用的方式组合所有时间的潜在状态时。
Self-attention操作
Self-attention是一种特殊形式的attention机制。
给定输入$K, Q, V\in \mathbb R^{T\times d}$:
定义self-attention算子为:
Self-attention更多的细节
self-attention的性质:
- 关于$K, Q,V$矩阵的置换不变性;
- 允许所有时间的$k_t, q_t, v_t$相互影响;
- 计算成本为$O(T^2d)$(由于非线性应用于完整$T\times T$矩阵,因此无法轻易降低复杂度);
用于时间序列的Transformer
Transformer架构使用一系列注意力机制(和前馈层)来处理时间序列:
(在下一张幻灯片中详细描述)
所有时间步长(在实际中,在给定的时间片内)都是并行处理的,避免了像在 RNN中那样进行顺序处理的需要。
Transformer block
更细节的,Transformer block有如下形式:
有点复杂,但实际上只是自注意力,然后是一个线性层 + ReLU,带有一些任意的残差连接和归一化(并且经常调整它们的准确位置)。
应用于时间序列的Transformer
我们可以将Transformer block应用于时间序列的“直接”预测方法,而不是使用卷积块。
优点:
- 单层内的完整感受野(即可以立即使用过去的数据);
- 随着时间的推移混合不会增加参数数量(与卷积不同);
缺点:
- 所有输出都取决于所有输入(不好,例如,对于自回归任务);
- 无序;
Masked self-attention
为了解决“非causal”依赖性问题,我们可以mask softmax运算符以将零权重分配给任何“未来”时间步长:
其中:
请注意,尽管从技术上讲,这意味着我们可以“避免”在注意力矩阵中创建这些项,但在实际中,形成它们然后将它们屏蔽掉通常会更快。
位置编码
为了解决“位置不变性”,我们可以给输入添加位置编码,它将每个输入与其在序列中的位置相关联:
其中$\omega_i, i=1,\ldots, n$通常根据对数时间表选择。
超越时间序列的变形金刚
最近的工作已经观察到,transformer block非常强大,不仅仅是时间序列:
- Vision Transformers:将Transformer应用于图像(由一组补丁嵌入表示),对于大型数据集比CNN效果更好;
- Graph Transformers:在注意力矩阵中捕获图结构;
在所有情况下,主要挑战是:
- 如何表示数据使得$O(T^2)$操作可行;
- 如何形成位置编码;
- 如何形成mask矩阵;