参考资料:

https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py

函数接口

torch中MultiheadAttention是调用torch.nn.functional中的multi_head_attention_forward函数,首先看下该函数接口:

def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if is ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is an binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
            the batches while a 3D mask allows to specify a different mask for the entries of each batch.
        use_separate_proj_weight: the function accept the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.
    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """

主要输入:

  • query: (L, N, E)
  • key: (S, N, E)
  • value: (S, N, E)
  • num_heads: attention的头数

输出:

  • attn_output: (L, N, E)
  • attn_output_weights: (N, L, S)

说明:

  • query, key, value是输入,注意论文中的q, k, v是query, key, value进行线性变换后的结果;对于selfattention,三者相同;
  • attn_output,即$\text{softmax}(QK^T)VW_{o}$;attn_output_weights是权重,即$\text{softmax}(QK^T)$。
    • 这里忽略了原始论文中的伸缩因子。
  • L是目标序列长度,N是batch size, E是embedding的维度;
  • 一般情况下,会给定一个输出维度d,然后将d / num_heads作为每个多头注意力的输出维度,最后将num_heads个结果拼接得到一个维度为d的结果;

符号

为了叙述方便,后续讨论做出以下限制:

  • 简化问题,不关注batch size $N$;
  • 暂时只考虑selfattention,即query = key = value;

这里给出全部符号:

  • $X\in \mathbb R^{L\times E_1}$:对应query = key = value
  • $n$: num_heads
  • $W_Q^i, W_K^i\in \mathbb R^{E_1\times d_2}, W_V^i\in \mathbb R^{E_1\times d_3},i=1,\ldots, n$:对应$Q,K,V$的线性变换矩阵($n$个头);
  • $d_2\times n = E_2,d_3\times n =E_3$:一般$E_2,E_3$是给定的值;
  • $W_{o}\in \mathbb R^{E_3\times E_4}$:out_projection矩阵;

在Selfattention中,一般来说都有

时间以及空间复杂度讨论

这部分讨论原始的selfattention中的时间以及空间复杂度,为了简化起见,这里的时间复杂度不考虑多机以及并行等等。

前置知识

矩阵乘法的时间复杂度

在讨论之前,首先给出矩阵乘法的时间复杂度:

  • 对于$X\in \mathbb R^{d_1\times d_2}, Y\in \mathbb R^{d_2\times d_3}$,矩阵乘法$XY\in \mathbb R^{d_1\times d_3}$的时间复杂度为$\Theta(d_1d_2d_3)$

说明:

  • $XY$一共有$d_1d_3$个元素,每个元素需要做$d_2$维内积得到,即每个元素需要做$d_2$次乘法以及$d_2-1$次加法,所以每个元素的时间复杂度为$\Theta(d_2)$,因此总时间复杂度为$\Theta(d_1d_2d_3)$;

矩阵乘法以及SoftMax的时间复杂度比较

20210815更新:

  • Softmax的时间测试有点问题,后续讨论中会忽略Softmax的计算时间;

原文:

假设$Q, K\in \mathbb R^{L\times d}$,考虑计算Attention时的如下操作:

  • $W=QK^T\in \mathbb R^{L\times L}$
  • $\mathrm{Softmax}(W)\in \mathbb R^{L\times L}$

第一步的时间复杂度为$\Theta(dL^2)$,第二步的时间复杂度为$\Theta(cL^2)$,其中$c$表示计算每个元素的时间复杂度,现在的问题是$c$和$d$相比是否能够忽略?

此处由于涉及太底层的知识,所以直接通过模拟判断,最终的结论是,在单机A100上,$c$比$d$小几十倍,所以可以得出如下结论:

  • 相对于矩阵乘法,$\mathrm {Softmax}$的时间复杂度可以忽略;
  • $\mathrm{Softmax}(W)$的时间复杂度可以近似为$\Theta(dL^2)$;
实验代码以及实验结果

完整分析

有了前置知识,下面进行完整的分析。

首先回顾符号:

  • $X\in \mathbb R^{L\times E_1}$
  • $W_Q^i, W_K^i\in \mathbb R^{E_1\times d_2}, W_V^i\in \mathbb R^{E_1\times d_3},i=1,\ldots, n$;
  • $d_2\times n = E_2,d_3\times n =E_3$
  • $W_{o}\in \mathbb R^{E_3\times E_4}$:out_projection矩阵;

空间复杂度(忽略中间计算产生的结果):

矩阵 空间复杂度
$W_Q^i, W_K^i\in \mathbb R^{E_1\times d_2},i=1,\ldots,n$ $2\times E_1\times d_2 \times n =E_1E_2$
$W_V^i\in \mathbb R^{E_1\times d_3},i=1,\ldots, n$ $E_1\times d_3\times n =E_1 E_3$
$W_{o}\in \mathbb R^{E_3\times E_4}$ $E_3 E_4$

总空间复杂度:

空间复杂度为:

后续以表格的方式给出算法流程以及对应的时间复杂度:

公式 时间复杂度(忽略$\Theta$)
$Q_i=XW_Q^i\in \mathbb R^{L\times d_2},i=1,\ldots,n$ $LE_1 d_2n=LE_1E_2$
$K_i=XW_K^i\in \mathbb R^{L\times d_2},i=1,\ldots,n$ $LE_1 d_2n=LE_1 E_2$
$V_i=XW_V^i\in \mathbb R^{L\times d_3},i=1,\ldots,n$ $LE_1d_3n=LE_1E_3$
$W_i=\text{Softmax}(Q_iK_i^T)\in \mathbb R^{L\times L}$ $L^2d_2n=E_2L^2$
$O_i=W_i V_i\in \mathbb R^{L\times d_3},i=1,\ldots, n$ $L^2d_3n =E_3L^2$
$O=[O_1,\ldots,O_n] W_o\in \mathbb R^{L\times E_4}$ $LE_3 E_4$

总时间复杂度:

在Selfattention中,一般来说都有

所以总的空间时间复杂度分别为:

  • 空间:$3E^2$
  • 时间:$4E^2L + 2E L^2$

思考

待补充。

这部分对于之前的计算方式进行一些思考。

注意到