Skip to content

The Annotated Tnn

Toeplitz Neural Network for Sequence Modeling

博客由Doreamonzzz撰写。

更新日志: - 20230313,开始撰写博客; - 20230320,完成动机以及各个部件的实现部分; - 20230524,完成校阅以及引用; - 20230527,修复两处笔误;

Toeplitz Neural Network(TNN)是一种全新的网络结构,以一种完全不同的方式进行序列建模,在单向/双向语言模型,图像分类任务上和Transformer性能相近,并且在长序列建模LRA任务上取得和S4相当的性能。这篇博客的主要目的就是以The Annotated TransformerThe Annotated S4风格介绍TNN,在阅读完这篇博客后,您将得到如下收获: 1. 了解TNN的动机和设计理念; 2. 掌握TNN各个部件的实现;

总而言之,在阅读完本博客之后,您将成为TNN的专家,并且可以将TNN应用到您的项目中,让我们开始吧。

预备知识

Token mixing and channel mixing

让我们首先从Transformer开始。Transformer作为一个网络结构已经席卷了各个领域,其核心部分主要可以由如下两个计算公式描述: 其中$\mathbf X \in \mathbb R^{n\times d}$是输入(也可以称为token matrix,其中矩阵的每一行为一个token的向量表示),$n$是序列长度,$d$是特征维度。

既然现在有两个主要模块——$\mathrm {MHA}$和$\mathrm {FFN}$,那么他们的作用是否有所不同呢?在Metaformer一文中,研究者指出,$\mathrm {MHA}$的主要作用是Token mixing,而$\mathrm {FFN}$的主要作用是Channel mixing。

这是什么意思呢?我们可以从矩阵乘法的角度清晰的理解这点:给定输入(token matrix)$\mathbf X \in \mathbb R^{n\times d}$,考虑矩阵乘法$\mathbf A \mathbf X$和$\mathbf X \mathbf B$,那么: - $\mathbf A \mathbf X$表示矩阵$\mathbf X$行的线性组合,而每一行表示一个token,即token的线性组合,所以称为token mixing; - $\mathbf X \mathbf B$表示矩阵$\mathbf X$列的线性组合,而每一列表示一个channel,即channel的线性组合,所以称为channel mixing;

在Transformer中,矩阵$\mathbf A$即为$\mathrm{Softmax}(\mathbf Q \mathbf K^{\top} /\sqrt{d})$,矩阵$\mathbf B$即为$\mathrm {FFN}$中的全连接层。

大多数对Transformer的改进都是集中在token mixing:$\mathbf A \mathbf X$的计算上,以各种各样的方式降低其运算复杂度,TNN也是使用了类似的思路,最核心的一点就是利用了相对位置编码,或者说,Toeplitz矩阵。

相对位置编码

位置编码是Transformer中的重要组成部分,一开始广为使用的是绝对位置编码(APE),这种编码的方式可以用如下计算方式概括: 其中$\mathbf w_i$表示第$i$个词的word embedding,$\mathbf p_i$表示第$i$个位置的position embedding。

后来,有研究人员发现,在序列建模中,词的相对位置信息,可能比词的绝位置信息更加重要。

例如"我年纪比你大"的语意和"你年纪比我大"完全不同,但是这两句话只是交换了"你"和"我"的位置。

于是研究人员开始将相对位置编码引入,相对位置编码的使用和绝对位置编码有所不同,其作用在Attention计算的位置: 如果写成矩阵的形式则更加直观:

这里,矩阵$\mathbf T$有一个数学名称——Toeplitz矩阵,不难看出该矩阵有$2n-1$个独立元素。

TNN的动机

有了之前的准备工作,可以引入我们工作的两个动机: 1. 既然相对位置信息如此重要,那么有没有可能只依赖于相对位置信息(Toeplitz matrix)进行token mixing呢? 1. 直观上来说,就是将Attention Matrix替换为Toeplitz matrix。 2. 假设(1)成立,那么我们需要进行的主要操作是$\mathbf T \mathbf X$,既然矩阵$\mathbf T$是一个特殊结构的矩阵,那么有没有可能加速运算呢?

我们对两个问题都进行了肯定的答复: 1. 完全可以只依赖于相对位置信息进行token mixing; 2. 由于矩阵的特殊性,可以将运算复杂度由$O(n^2 d)$降低为$O(nd\log n)$;

可以看到,我们的动机极其简单和优雅,最核心的思路就是将$\mathrm{Softmax}(\mathbf Q \mathbf K^{\top} / \sqrt {d})$替换为$\mathbf T$,但是,这种简单的替换就可以拥有比各种花哨更改更好的性能,这就更加验证了相对位置信息在序列建模中的重要性。

TNN的实现

准备工作

接下来的问题就是如何实现TNN,在此之前,我们对之前的公式做一定的调整。

在之前的讨论中,我们提到了$\mathbf T \mathbf X$可以高效实现,其中$\mathbf T\in \mathbb R^{n\times n}, \mathbf X \in \mathbb R^{n\times d}$,这种情况相当于每个channel共享同一个Toeplitz matrix,但是注意到我们可以让不同的channel使用不同的Toeplitz matrix,我们经验上发现,这样一定程度上可以增大模型的表达性,所以在TNN中,每个channel使用了不同的Toeplitz matrix。注意到形状为$n\times n$的Toeplitz matrix实际上只有$2n-1$个独立元素,为了方便后续讨论,我们定义如下映射:$f: \mathbb R^{(2n-1)\times 1} \to \mathbb R^{n\times n}$: 该映射的作用是将维度为$(2n-1)\times 1$的向量填充为$n\times n$的Toeplitz matrix。

结合之前的记号,我们定义为Tno算子(Toeplitz neural operator)为:

备注:这里的记号$\mathbf T\in \mathbb R^{(2n-1)\times d}$和一开始含义有所不同,注意不要搞混。

在开始正式的实现之前,我们先引入一些必要的依赖库以及一些辅助函数:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

def get_activation_fn(activation):
    if activation == "gelu":
        return F.gelu
    elif activation == "relu":
        return F.relu
    elif activation == "elu":
        return F.elu
    elif activation == "sigmoid":
        return F.sigmoid
    elif activation == "exp":
        return torch.exp
    elif activation == "leak":
        return F.leaky_relu
    elif activation == "1+elu":
        def f(x):
            return 1 + F.elu(x)
        return f
    elif activation == "2+elu":
            def f(x):
                return 2 + F.elu(x)
            return f
    elif activation == "silu":
        return F.silu
    else:
        return lambda x: x

Tno的实现

Naive实现

最朴素的实现自然是利用定义进行实现,例如如下代码中,我们使用4重循环,外面两重循环遍历batch, channel维度,第三重循环遍历输出位置,最后一重循环遍历求和项,注意到我们的$\mathbf T[:, i]$输入形式为$t_{-n+1}, ... , t_{-1}, t_0, t_1, ... , t_{n - 1}$,第三重循环遍历到$i$时,涉及的$t$为$t_{i}, t_{i-1},\ldots, t_{i-n+1}$,而$n - 1 + i$是$t_{i}$在$\mathbf T[:, i]$的实际索引:

def tno_naive(x, t):
    # x: (b, n, d)
    # t: (2n - 1, d), t_(-(n - 1)), ... , t_(-1), t_0, t_1, ... , t_(n - 1) 
    b, n, d = x.shape
    o = torch.zeros_like(x).to(x)
    for b_ in range(b):
        for d_ in range(d):
            for i in range(n):
                for j in range(n):
                    o[b_][i][d_] += t[n - 1 + i - j][d_] * x[b_][j][d_]

    return o

这种实现显然太低效,但是至少我们有了一个正确的版本,这对我们后续改进算法也是有帮助的,不难看出这样计算的时间复杂度为$O(n^2d)$,空间复杂度为$O(nd)$(忽略batch维度)。

Matrix production实现

第二种实现是并行版本,其思路就是先构造Toeplitz matrix,然后利用矩阵乘法进行计算。最主要的部分是将映射$f$实现出来,代码基于此处,主要思路是先将输入改写为$t_0, t_{-1}, ... , t_{1-n}, t_{n - 1}, ... , t_1$,然后构造index $0, 1, \ldots,n -1, -(n - 1), ..., -1$,将输入映射到Toeplitz matrix,最后得到Toeplitz matrix进行矩阵乘法:

def tno_matrix(x, t):
    # x: (b, n, d)
    # t: (2n - 1, d), t_(-(n - 1)), ... , t_(-1), t_0, t_1, ... , t_(n - 1) 
    n = x.shape[1]
    t = t.unsqueeze(0)
    # c: t_0, t_1, ... , t_(n - 1)
    c = t[:, n - 1:]
    # r: t_0, t_(-1), ... , t_(-(n - 1))
    r = t[:, :n].flip(1)
    # vals: [t_0, t_(-1), ... , t_(-(n - 1)), t_(n - 1), ... , t_1]
    vals = torch.cat([r, c[:, 1:].flip(1)], dim=-2)
    i, j = torch.ones(n, n).nonzero().T
    t_matrix = vals[:, j - i].reshape(n, n, -1)
    o = torch.einsum("n m d, b m d -> b n d", t_matrix, x)

    return o

这种实现的好处是可以利用矩阵乘法,尽管复杂度依然为$O(n^2d)$,但实际效率会快很多;但是由于要构造Toeplitz matrix,所以空间复杂度为$O(n^2d)$,并且这部分还是一个很大的IO开销,所以实际中的速度并不会很快。

FFT实现

有了之前的铺垫,可以看出前两种方法无论是时间复杂度和空间复杂度相比Attention并没有什么优势,那么有没有办法解决这点呢?回答是肯定的,这就需要FFT这把利刃。后续的讨论涉及到一些数学知识,这里先高度概括一下思路: 1. 给出Circulant matrix的快速矩阵乘法算法; 2. 建立Toeplitz marix和Circulant matrix的关系;

Circulant matrix

定义

矩阵$\mathbf C\in \mathbb R^{n\times n}$是一个Circulant matrix当且仅当$\mathbf C_{ij}= c_{(i-j + n )\bmod n}$ ,即: 关于Circulant matrix,有如下重要性质:

Circulant matrix $\mathbf C\in \mathbb R^{n\times n}$正交相似于对角阵$\mathbf \Lambda$,特别地,相似矩阵$\mathbf F$是$n\times n$ DFT矩阵: 证明可以参考这里

快速矩阵乘法

现在考虑matrix-vector production操作$\mathbf M \mathbf x, \mathbf M\in \mathbb R^{n\times n}, \mathbf x\in \mathbb R^{n\times 1}$,那么:

  • 如果$\mathbf M$为一般的矩阵,那么该计算的时间复杂度为$O(n^2)$;
  • 如果$\mathbf M$为DFT矩阵,那么该计算的时间复杂度为$O(n \log n)$;

基于上述事实,考虑$\mathbf M=\mathbf C$为Circulant matrix的情形,那么: 该计算可以分解为几个步骤:

  • $\mathbf x_{\mathrm{fft}}=\mathbf{Fx}$;
  • $\mathbf c_{\mathrm{fft}}=\mathbf F[c_0,c_1,\ldots, c_{n-1}]^\top$;
  • $\mathbf o_{\mathrm{fft}}=\mathbf x_{\mathrm{fft}}\odot \mathbf c_{\mathrm{fft}}$;
  • $\mathbf o= \mathbf F^{\top} \mathbf o_{\mathrm{fft}}$;

其中$\odot$表示element-wise production,可以看出,算法的总时间复杂度为$O(n\log n)$,空间复杂度为$O(n)$,所以Circulant matrix对应的矩阵乘法是高效的。

实现

有了之前的说明,不难利用fft实现上述计算:

def circulant_fft(x, c):
    # x: (b, n, d)
    # c: (n, d), c_0, c_1, ... , c_(n - 1) 
    n = x.shape[1]
    c = c.unsqueeze(0)
    x_fft = torch.fft.rfft(x, n, dim=-2)
    c_fft = torch.fft.rfft(c, n, dim=-2)
    o_fft = x_fft * c_fft
    o = torch.fft.irfft(o_fft, n, dim=-2)

    return o
小结

现在我们已经有了一个关于Circulant matrix的高效矩阵乘法,那么下一个问题就是建立Toeplitz matrix和Circulant matrix的关系。

Toeplitz matrix

定义

矩阵$\mathbf T\in \mathbb R^{n\times n}$是一个Toeplitz matrix当且仅当$\mathbf T_{ij}= t_{i-j}$,即 从形式上来看,Toeplitz matrix和Circulant matrix非常像,唯一的区别在于前者的独立元素数量为$2n-1$,后者的独立元素数量为$n$,那么一个简单的思路就是将Toeplitz matrix嵌入到一个阶数大于等于$2n-1$矩阵中,而这个矩阵本生是一个Circulant matrix,下面来看下这是如何具体操作的。

可以将Toeplitz matrix $\mathbf T\in \mathbb R^{n\times }$嵌入到Circulant matrix $\mathbf C \in \mathbb R^{2n\times 2n}$中: 即, 使用分块矩阵的符号,我们可以定义: 有了上述准备工作,可以得到Toeplitz matrix-vector production的快速算法。

快速矩阵乘法

对于向量$\mathbf x\in \mathbb R^{n}$, 定义: 所以, 因此: 关于时间复杂度,注意到我们是将$n\times n$的Toeplitz matrix嵌入到一个$2n\times 2n$的Circulant matrix中,所以时间复杂度仍然为$O(n\log n)$。

实现

和Circulant matrix的情形类似,可以利用fft实现上述计算:

def tno_fft(x, t):
    # x: (b, n, d)
    # t: (2 * n, d), t0, t1, ..., t(n-1), t0, t_(-(n-1)), ... , t_(-1)
    n = x.shape[1]
    t = t.unsqueeze(0)
    x_fft = torch.fft.rfft(x, 2 * n, dim=-2)
    t_fft = torch.fft.rfft(t, 2 * n, dim=-2)
    o_fft = x_fft * t_fft
    o = torch.fft.irfft(o_fft, 2 * n, dim=-2)[:, :n]

    return o

验证实现

在之前的讨论中,我们给出了Tno的三种实现方式,在本节中,我们将验证这些实现的正确性。

b = 2
n = 16
d = 128

t_zero = torch.randn(1, d)
# t1, ..., t(n-1)
t_pos = torch.randn(n - 1, d)
# t-(n-1), ... , t-1
t_neg = torch.randn(n - 1, d)
t1 = torch.cat([t_neg, t_zero, t_pos], dim=0).cuda()
t2 = torch.cat([t_zero, t_pos, t_zero, t_neg], dim=0).cuda()
x = torch.randn(b, n, d).cuda()

o1 = tno_naive(x, t1)
o2 = tno_matrix(x, t1)
o3 = tno_fft(x, t2)

print(f"The output error between tno_naive and tno_matrix is {torch.norm(o1 - o2)}")
print(f"The output error between tno_naive and tno_matrix is {torch.norm(o1 - o3)}")
The output error between tno_naive and tno_matrix is 2.414959999441635e-05
The output error between tno_naive and tno_matrix is 5.38119456905406e-05

补充

现在我们已经完成了大部分内容,这里最后补充如何将Tno适配到Autoregressive Language Model(causal)的情形。和Attention类似,只要保证Toeplitz matrix的上三角部分为$0$即可,即: 在实现时,注意到fft是zero padding,所以只需要将输入:

t2 = torch.cat([t_zero, t_pos, t_zero, t_neg], dim=0).cuda()

修改为下式即可:

t2 = torch.cat([t_zero, t_pos, t_zero], dim=0).cuda()

小结

在本节中,我们从naive的算法开始,最终得到了一个基于FFT算法的高效实现,并且给出处理单向情形的方案。

Rpe的实现

注意到Tno的计算涉及到$x,t$,$x$是输入,$t$是相对位置系数,所以下一步就是如何计算$t$。对于序列长度为$n$,特征维度为$d$的模型,我们一共有$(2n-1)\times d$个系数,所以接下来的问题就是如何得到这些系数。

Naive实现

最简单的思路就是直接给模型增加$(2n-1)\times d$个参数,但是这样做有几个问题: 1. 当序列长度$n$比较大的时候,模型参数量会非常多; 2. 尽管我们有$(2n-1)\times d$个系数,但是对于每个channel的$2n-1$个系数,不能完全假设他们是独立的,例如$t_1$和$t_{-1}$必然有内在联系; 3. 无法处理任意长的序列; 1. 这点可以理解为,当超过最大序列长度时,没有对应的系数,所以模型也没有外推性

那么是否有办法解决这些问题呢?回答是肯定的。

Relative Position Encoder

对于问题1,2,我们利用某种方式参数化这$(2n-1)\times d$个参数即可,最简单方式就是使用神经网络,特别的,我们使用的是一个名为Relative Position Encoder(RPE)的网络,网络的输入是1维实数$-(n-1), \ldots, (n-1)$,输出是$d$维特征。在使用时,我们会输入$[-(n-1),\ldots, (n-1)]^{\top} \in \mathbb R^{2n-1}$,输出的形状是$(2n-1)\times d$。

对于问题3,我们现在可以一定程度上解决这个问题,现在只要将相对位置(超出训练时的最大训练长度也可)输入到RPE中,即可得到对应系数。但是这样还远远不够,因为这种方式只是让模型“强行”计算了一个值,为了使得性能正常,我们参考了Alibi的方案,使用了指数衰减的形式,即: 其中$\lambda$是一个超参,我们在$n=512$时选择$\lambda=0.99$。

实现Relative Position Encoder

有了之前的讨论,我们给出Relative Position Encoder的实现,本质是就是一个全连接网络,加上归一化和激活函数:

class Rpe(nn.Module):
    def __init__(
        self, 
        dim, 
        outdim, 
        residual, 
        act="relu", 
        bias=True, 
        layers=3, 
    ):
        super().__init__()

        self.residual = residual
        self.outdim = outdim
        self.pos_dim = dim
        self.act = act
        self.pos_proj = nn.Linear(1, self.pos_dim, bias=bias)
        self.layers = nn.ModuleList([])
        for i in range(layers):
            self.layers.append(
                nn.Sequential(
                    nn.LayerNorm(self.pos_dim),
                    self.get_act(),
                    nn.Linear(self.pos_dim, self.pos_dim, bias=bias),
                )
            )
        self.out = nn.Sequential(
            nn.LayerNorm(self.pos_dim),
            self.get_act(),
            nn.Linear(self.pos_dim, self.outdim, bias=bias),
        )

    def get_act(self):
        if self.act == "silu":
            return nn.SiLU(inplace=True)
        else:
            return nn.ReLU(inplace=True)

    def forward(self, biases):
        x = self.pos_proj(biases)
        if self.residual:
            for m in self.layers:
                x = m(x) + x
        else:
            for m in self.layers:
                x = m(x)
        x = self.out(x)

        return x

将Tno和Rpe合并

在我们的原始实现中,Rpe是和Tno合并在一起的,完整的实现如下:

class Tno(nn.Module):
    def __init__(
        self, 
        h, 
        dim, 
        rpe_dim, 
        causal=False, 
        use_decay=False, 
        residual=False, 
        act="relu", 
        par_type=1, 
        gamma=0.99,
        bias=True,
        layers=3,
    ):
        super().__init__()

        self.h = h
        self.dim = dim
        self.causal = causal
        self.par_type = par_type
        self.zero_value = 0
        self.use_decay = use_decay
        if self.use_decay:
            self.gamma = nn.Parameter(torch.ones(h, 1, dim) * gamma, requires_grad=False)

        self.rpe = Rpe(
            dim=rpe_dim, 
            outdim=h * dim, 
            residual=residual,
            act=act,
            bias=bias, 
            layers=layers,
        )

        if self.causal:
            self.forward = self.forward_causal
        else:
            self.forward = self.forward_non_causal

    def get_pos(self, n):
        if self.par_type == 1:
            index = torch.arange(1, 1 + n).reshape(n, -1) * 1.0
        elif self.par_type == 2:
            index = torch.arange(1, 1 + n).reshape(n, -1) * 1.0 / n
        elif self.par_type == 3:
            index = torch.exp(torch.arange(1, 1 + n).reshape(n, -1) * 1.0 / n)

        return index

    def get_zero(self):
        index = torch.zeros(1).reshape(1, -1) * 1.0
        if self.par_type == 3:
            index = torch.exp(index)

        return index

    def get_neg(self, n):
        if self.causal:
            index = torch.ones(self.h * n * self.dim).reshape(self.h, n, self.dim) * self.zero_value
        else:
            if self.par_type == 1:
                index = -torch.arange(1, 1 + n).flip(0).reshape(n, -1) * 1.0
            elif self.par_type == 2:
                index = -torch.arange(1, 1 + n).flip(0).reshape(n, -1) * 1.0 / n

        return index

    def rpe_transform(self, x):
        # n, 1 -> n, (d * h)
        res = self.rpe(x)
        # n, (d * h) -> h, n, d
        res = rearrange(res, 'n (h d) -> h n d', h=self.h)

        return res

    def forward_causal(self, x, dim=-2):
        # x: b, h, n, d
        n = x.shape[dim]
        # a0, a1, ... , a(n-1), a0, a(-(n-1)), ... , a(-1)
        ##### coef
        # 1, d, 1 -> h, 1, d
        zero = self.rpe_transform(self.get_zero().to(x))
        pos = self.rpe_transform(self.get_pos(n - 1).to(x))

        if self.use_decay:
            coef = torch.arange(1, n).reshape(1, -1, 1).to(x)
            gamma = self.gamma
            gamma = gamma ** coef
            pos = gamma * pos
        a = torch.cat([zero, pos, zero], dim=1)
        a = self.act_fun(a)

        # x: b, h, n, d
        # a: h, l, d
        output = self.compute(x, a, dim, n)

        return output

    def forward_non_causal(self, x, dim=-2):
        # x: b, h, n, d
        n = x.shape[dim]
        # a0, a1, ... , a(n-1), a0, a(-(n-1)), ... , a(-1)
        ##### coef
        # 1, d, 1 -> h, 1, d
        zero = self.rpe_transform(self.get_zero().to(x))
        pos = self.rpe_transform(self.get_pos(n - 1).to(x))
        neg_index = self.get_neg(n - 1).to(x)
        if self.causal:
            neg = neg_index
        else:
            neg = self.rpe_transform(neg_index)

        if self.use_decay:
            coef = torch.arange(1, n).reshape(1, -1, 1).to(x)
            gamma = self.gamma
            gamma = gamma ** coef
            pos = gamma * pos
            neg = torch.flip(gamma, dims=[1]) * neg
        a = torch.cat([zero, pos, zero, neg], dim=1)
        a = self.act_fun(a)
        # x: b, h, n, d
        # a: h, l, d
        output = self.compute(x, a, dim, n)

        return output

    def compute(self, x, a, dim, n):
        # x: b, h, n, d
        # a: h, n, d
        y = torch.fft.rfft(x, 2 * n, dim=dim)
        v = torch.fft.rfft(a, 2 * n, dim=dim).unsqueeze(0)
        u = v * y
        output = torch.fft.irfft(u, 2 * n, dim=dim)[:, :, :n, :]

        return output

Tnn layer的实现

有了之前的铺垫,我们可以介绍Tnn Layer,该模块包含一个Token mixer(GTU)以及一个Channel mixer(GLU),由于GLU和GTU非常相似,所以我们从GLU开始介绍。

GLU

GLU是利用Gate的形式达到Channel mixing的作用,写成数学公式为: 实现如下:

class GLU(nn.Module):
    def __init__(self, d1, d2, act_fun, fina_act="None", dropout=0.0, bias=True):
        super().__init__()

        self.l1 = nn.Linear(d1, d2, bias=bias)
        self.l2 = nn.Linear(d1, d2, bias=bias)
        self.l3 = nn.Linear(d2, d1, bias=bias)
        self.act_fun = get_activation_fn(act_fun)
        self.p = dropout
        if self.p > 0.0:
            self.dropout = nn.Dropout(p=dropout)
        self.fina_act = get_activation_fn(fina_act)

    def forward(self, x):
        o1 = self.l1(x)
        weight = self.act_fun(o1)
        if self.p > 0.0:
            weight = self.dropout(weight)
        o2 = self.l2(x)
        output = weight * o2
        output = self.l3(output)
        output = self.fina_act(output)

        return output

GTU

GTU参考了GLU的思路,唯一的不同是在其中一个分支上使用了Tno,并且增加一个激活函数,写成数学公式即为: 实现如下:

class Gtu(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        bias=True,
        act_fun="silu",
        causal=False,
        expand_ratio=3,
        use_norm=False,
        norm_type="layernorm",
        use_decay=False,
        rpe_layers=3,
        rpe_embedding=512,
        rpe_act="relu",
        normalize=False,
        par_type=1,
        residual=False,
        gamma=0.99,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.expand_ratio = expand_ratio
        self.num_heads = num_heads
        self.normalize = normalize

        d1 = int(self.expand_ratio * embed_dim)
        d1 = (d1 // self.num_heads) * self.num_heads
        self.head_dim = d1 // num_heads
        # linear projection
        self.v_proj = nn.Linear(embed_dim, d1, bias=bias)
        self.u_proj = nn.Linear(embed_dim, d1, bias=bias)
        self.o = nn.Linear(d1, embed_dim, bias=bias)
        self.act = get_activation_fn(act_fun)
        # tno
        self.toep = Tno(
            h=num_heads, 
            dim=self.head_dim,
            rpe_dim=rpe_embedding, 
            causal=causal, 
            use_decay=use_decay, 
            residual=residual,
            act=rpe_act,
            par_type=par_type,
            gamma=gamma,
            bias=bias,
            layers=rpe_layers,
        )
        # norm
        self.norm_type = norm_type
        self.use_norm = use_norm

    def forward(self, x):
        # x: b, n, d
        num_heads = self.num_heads

        u = self.act(self.u_proj(x))
        v = self.act(self.v_proj(x))
        # reshape
        v = rearrange(v, 'b n (h d) -> b h n d', h=num_heads)
        output = self.toep(v, dim=-2, normalize=self.normalize)
        output = rearrange(output, 'b h n d -> b n (h d)')
        output = u * output
        output = self.o(output)

        return output

TnnLayer

有了之前的准备工作,我们很容易实现出TnnLayer,因为这只不过是GTU和GLU的堆叠:

class TnnLayer(nn.Module):
    def __init__(
        self, 
        dim, 
        num_heads,
        rpe_embedding,
        glu_dim,
        # model params
        prenorm=True,
        norm_type="layernorm",
        # gtu params
        causal=False,
        gtu_act="silu",
        expand_ratio=3,
        use_decay=False,
        gamma=0.999,
        # rpe params
        rpe_act="relu",
        rpe_layers=3,
        # glu params
        glu_act="silu",
    ):
        super().__init__()
        self.token_mixer = Gtu(
            # gtu params
            embed_dim=dim,
            num_heads=num_heads,
            act_fun=gtu_act,
            norm_type=norm_type,
            causal=causal,
            expand_ratio=expand_ratio,
            use_decay=use_decay,
            gamma=gamma,
            # rpe params
            rpe_embedding=rpe_embedding,
            rpe_act=rpe_act,
            rpe_layers=rpe_layers,
        )

        self.token_norm = nn.LayerNorm(dim)
        self.feature_norm = nn.LayerNorm(dim)

        self.feature_mixer = GLU(
            d1=dim, 
            d2=glu_dim,
            act_fun=glu_act,
        )

    def forward(self, x):
        x = x + self.token_mixer(self.token_norm(x))
        x = x + self.feature_mixer(self.feature_norm(x))

        return x

在使用时,您只需要将TransformerLayer替换成TnnLayer即可。

小结

在本节中,我们完成了TnnLayer的实现,有了之前的铺垫工作,这一切并不困难。现在,您已经可以将Tnn应用到您的项目中了。

全文总结

通过之前的内容,您应该对TNN有所了解,这里,让我们对全文的核心进行总结:

  • Transformer可以分为Token mixing和Channel mixing;
  • Attention的作用是Token mixing,而相对位置信息对Attention很重要,我们提出使用相对位置信息(Toepltiz matrix)来代替Attention Matrix;
  • 使用Toeplitz matrix进行矩阵乘法可以加速,所以我们的方法理论上速度很快;
  • Toeplitz matrix的系数可以使用Rpe进行参数化,从而减少参数,结合指数衰减可以得到外推性;

当然,TNN还有很多问题存在,例如:

  • 为什么相对位置信息就足够进行序列建模?
  • TNN真的只使用了相对位置信息吗?
  • TNN能达到理论速度上界吗?
  • TNN不能做哪些任务?
  • TNN有哪些先验假设?

关于这些问题,我们将在后续的博客中回答,期待您的再次阅读。