参考代码:

参考论文:

概述:

  • 待补充

符号参考:

动机

Softmax Attention的时间复杂度为$O(n^2d)$,其中$n$为序列长度,$d$为特征维度。有很多工作来改进这点,其中Performer和RFA都是通过对Softmax中Exp函数的近似来做到这点,下面会通过原理,实现的角度进行分析。

原理

Softmax Attention(单头情形,忽略缩放因子$\sqrt d$)的计算公式为:

Performer和RFA的思路都是构造映射$\phi$,使得:

其中${\mathbf w}$为服从分布$\mathcal D$的随机变量。

根据上式,可得:

写成矩阵形式,即:

该形式即Linear Attention的形式,所以时间复杂度为$O(nd^2)$,当序列够长时,即可降低时间复杂度。

那么后续问题就是找到$\phi$,Performer和RFA在这里给出了不同的方式。但是在介绍之前,先做一些准备。

准备工作

假设存在$f_i$满足如下性质:

那么根据期望的线性性,可得:

这等价于:

所以可以给出如下$\phi $:

有了这些准备工作,可以开始介绍Performer和RFA。

Performer

Performer给出的$f_i$形式如下:

RFA

Performer给出的$f_i$形式如下:

降低期望方差

因为Performer和RFA都是基于期望的方法, 所以一个重要问题是降低方差,根据Orthogonal Random Features提供的采样方法,可以利用如下方式得到$\mathbf w_i, i=1,\ldots, m$:

  • 采样$d$个$d$维标准正态分布$\mathbf g \in \mathbb R^{d}$,得到矩阵$\mathbf G\in \mathbb R^{d\times d}$;
  • 做$\mathbf {QR}$分解:$\mathbf G=\mathbf {QR}$;
  • 采样$d$个degree为$d$的卡方分布随机变量$s$,构造为对角阵$\mathbf S\in \mathbb R^{d\times d}$;
  • 得到最终结果:$\mathbf W = \mathbf S \mathbf Q\in \mathbb R^{d\times d}$;

其余细节

注意到Attention中$\sqrt d$,即我们近似的目标项为:

为了近似该项,注意到有如下恒等式:

所以具体操作时,可以对$\mathbf {x,y}$进行scale操作。

实现

参考官方的实现,目前自己复现了一版: