https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py

## Function interface

def multi_head_attention_forward(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
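training: apply dropout if True.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.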
use_separate_proj_weight: the function accepts the proj. weights for query, key,
and value in different forms. If false, in_proj_weight will be used, which is
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
static_k, static_v: static key and value used for attention operators.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
will be unchanged. If a BoolTensor is provided, the positions with the
value of True will be ignored while the position with the value of False will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with True
are not allowed to attend while False values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""

• query: (L, N, E)
• key: (S, N, E)
• value: (S, N, E)

• attn_output: (L, N, E)
• attn_output_weights: (N, L, S)

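• query, key, and value are the inputs. Note that the q, k, v of the paper are the results of applying linear projections to query, key, and value; for self-attention, all three inputs are the same;
• attn_output is $\text{softmax}(QK^T)VW_{o}$; attn_output_weights is the attention weights, $\text{softmax}(QK^T)$;
• The scaling factor $1/\sqrt{d_k}$ from the original paper is omitted here;
• L is the target sequence length, N is the batch size, and E is the embedding dimension.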
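The single-head version of these formulas can be sketched in plain Python as follows. This is only an illustration under assumptions: the batch dimension is dropped, the projections of query, key, and value are taken as already applied, there is no scaling factor (matching the note above), and the helper names `matmul`, `softmax_rows`, and `attention` are made up for this example rather than taken from PyTorch.

```python
import math

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax_rows(m):
    """Apply softmax independently to every row of m."""
    out = []
    for row in m:
        peak = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention(q, k, v, w_o):
    """Return (attn_output, attn_output_weights) for one head, no batch, no scaling."""
    k_t = [list(col) for col in zip(*k)]       # K^T
    weights = softmax_rows(matmul(q, k_t))     # softmax(Q K^T), shape (L, S)
    output = matmul(matmul(weights, v), w_o)   # softmax(Q K^T) V W_o, shape (L, E)
    return output, weights

# Toy sizes: L = 2 target positions, S = 3 source positions, E = 2.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
w_o = [[1.0, 0.0], [0.0, 1.0]]  # identity output projection
attn_output, attn_output_weights = attention(q, k, v, w_o)
```

The output has one row per target position (L rows of length E), while the weights have shape (L, S) and each row sums to 1.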

## Notation

• To simplify the problem, ignore the batch size $N$;
• For now, consider only self-attention, i.e. query = key = value;

• $X\in \mathbb R^{L\times E_1}$: corresponds to query = key = value;
• $n$: num_heads;
• $W_Q^i, W_K^i\in \mathbb R^{E_1\times d_2}, W_V^i\in \mathbb R^{E_1\times d_3},i=1,\ldots, n$: the linear projection matrices for $Q,K,V$ ($n$ heads);
• $d_2\times n = E_2,d_3\times n =E_3$: $E_2,E_3$ are usually given values;
• $W_{o}\in \mathbb R^{E_3\times E_4}$: the out_projection matrix;
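As a concrete instance of this notation, the shape bookkeeping can be written out in a few lines of Python. The numbers below ($E_1=E_2=E_3=E_4=512$, $n=8$, $L=10$) are assumed purely for illustration, and `shapes` is a hypothetical helper, not part of PyTorch.

```python
def shapes(L, E1, E2, E3, E4, n):
    """Return the shape of every matrix in the notation above for n heads."""
    if E2 % n or E3 % n:
        raise ValueError("E2 and E3 must be divisible by the number of heads n")
    d2, d3 = E2 // n, E3 // n  # per-head dimensions: d2 * n = E2, d3 * n = E3
    return {
        "X": (L, E1),
        "W_Q^i": (E1, d2),
        "W_K^i": (E1, d2),
        "W_V^i": (E1, d3),
        "W_o": (E3, E4),
    }

example = shapes(L=10, E1=512, E2=512, E3=512, E4=512, n=8)
# With 8 heads, each per-head projection maps 512 -> 64 dimensions.
```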

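## Time and space complexity

### Prerequisites

#### Time complexity of matrix multiplication

• For $X\in \mathbb R^{d_1\times d_2}, Y\in \mathbb R^{d_2\times d_3}$, the matrix product $XY\in \mathbb R^{d_1\times d_3}$ has time complexity $\Theta(d_1d_2d_3)$;

• $XY$ has $d_1d_3$ entries, and each entry is a $d_2$-dimensional inner product, i.e. $d_2$ multiplications and $d_2-1$ additions. Each entry therefore costs $\Theta(d_2)$, and the total time complexity is $\Theta(d_1d_2d_3)$;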
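The counting argument can be checked with a naive triple loop that tallies scalar operations. This is a minimal sketch; `matmul_with_counts` is a name invented for this example, and the sizes $d_1=3, d_2=4, d_3=5$ are arbitrary.

```python
def matmul_with_counts(x, y):
    """Naive matrix product that also counts scalar multiplications and additions."""
    d1, d2, d3 = len(x), len(y), len(y[0])
    mults = adds = 0
    out = [[0.0] * d3 for _ in range(d1)]
    for i in range(d1):
        for j in range(d3):
            # First term of the inner product: one multiplication, no addition.
            acc = x[i][0] * y[0][j]
            mults += 1
            # Remaining d2 - 1 terms: one multiplication and one addition each.
            for k in range(1, d2):
                acc += x[i][k] * y[k][j]
                mults += 1
                adds += 1
            out[i][j] = acc
    return out, mults, adds

d1, d2, d3 = 3, 4, 5
x = [[1.0] * d2 for _ in range(d1)]
y = [[1.0] * d3 for _ in range(d2)]
product, mults, adds = matmul_with_counts(x, y)
```

For these sizes the loop performs $d_1d_2d_3 = 60$ multiplications and $d_1(d_2-1)d_3 = 45$ additions, in line with the $\Theta(d_1d_2d_3)$ bound.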

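#### Comparing the time complexity of matrix multiplication and Softmax

Update (2021-08-15):

• The timing test for Softmax was somewhat flawed, so the discussion below ignores the time spent computing Softmax;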

• $W=QK^T\in \mathbb R^{L\times L}$
• $\mathrm{Softmax}(W)\in \mathbb R^{L\times L}$

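• Compared with matrix multiplication, the time complexity of $\mathrm {Softmax}$ is negligible: the softmax itself touches each of the $L^2$ entries of $W$ a constant number of times, i.e. $\Theta(L^2)$;
• Computing $\mathrm{Softmax}(W)$, including the product $QK^T$ with $Q,K\in \mathbb R^{L\times d}$, therefore costs approximately $\Theta(dL^2)$;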
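A rough operation count makes the comparison concrete. The sketch below counts the calls to `exp` in a row-wise softmax over an $L\times L$ matrix and compares them with the multiplications needed to form $W=QK^T$. The sizes $L=6$, $d=64$ and the helper `softmax_rows_with_count` are assumptions for illustration, and counting operations says nothing about the constant factor of each operation on real hardware.

```python
import math

def softmax_rows_with_count(w):
    """Row-wise softmax that also counts calls to exp."""
    exp_calls = 0
    out = []
    for row in w:
        peak = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(v - peak) for v in row]
        exp_calls += len(row)
        total = sum(exps)
        out.append([e / total for e in exps])
    return out, exp_calls

L, d = 6, 64
w = [[float(i - j) for j in range(L)] for i in range(L)]
probs, exp_calls = softmax_rows_with_count(w)
# Multiplications needed to form W = Q K^T when Q and K have shape (L, d).
matmul_mults = L * L * d
```

Here the softmax needs $L^2$ exponentials, while the product needs $dL^2$ multiplications, a factor of $d$ more.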

### Full analysis

• $X\in \mathbb R^{L\times E_1}$
• $W_Q^i, W_K^i\in \mathbb R^{E_1\times d_2}, W_V^i\in \mathbb R^{E_1\times d_3},i=1,\ldots, n$
• $d_2\times n = E_2,d_3\times n =E_3$
• $W_{o}\in \mathbb R^{E_3\times E_4}$: the out_projection matrix;

| Parameter | Space |
| --- | --- |
| $W_Q^i, W_K^i\in \mathbb R^{E_1\times d_2},i=1,\ldots,n$ | $2\times E_1\times d_2 \times n =2E_1E_2$ |
| $W_V^i\in \mathbb R^{E_1\times d_3},i=1,\ldots, n$ | $E_1\times d_3\times n =E_1 E_3$ |
| $W_{o}\in \mathbb R^{E_3\times E_4}$ | $E_3 E_4$ |

| Computation | Time |
| --- | --- |
| $Q_i=XW_Q^i\in \mathbb R^{L\times d_2},i=1,\ldots,n$ | $LE_1 d_2n=LE_1E_2$ |
| $K_i=XW_K^i\in \mathbb R^{L\times d_2},i=1,\ldots,n$ | $LE_1 d_2n=LE_1 E_2$ |
| $V_i=XW_V^i\in \mathbb R^{L\times d_3},i=1,\ldots,n$ | $LE_1d_3n=LE_1E_3$ |
| $W_i=\text{Softmax}(Q_iK_i^T)\in \mathbb R^{L\times L},i=1,\ldots,n$ | $L^2d_2n=E_2L^2$ |
| $O_i=W_i V_i\in \mathbb R^{L\times d_3},i=1,\ldots, n$ | $L^2d_3n =E_3L^2$ |
| $O=[O_1,\ldots,O_n] W_o\in \mathbb R^{L\times E_4}$ | $LE_3 E_4$ |

Taking $E_1=E_2=E_3=E_4=E$:

• Space: $2E^2+E^2+E^2=4E^2$ (the $W_Q^i$ and $W_K^i$ together contribute $2E^2$);
• Time: $4E^2L + 2E L^2$
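The totals can be cross-checked by summing the per-step counts directly. The functions below are a hypothetical sketch (not PyTorch code) that ignores biases and the Softmax cost, as in the discussion above; $E=512$ and $L=128$ are arbitrary example sizes. Note that $W_Q^i$ and $W_K^i$ are counted separately, so together they contribute $2E_1E_2$ weights.

```python
def parameter_count(E1, E2, E3, E4):
    """Number of weights in W_Q, W_K, W_V (all heads) and W_o; biases ignored."""
    return 2 * E1 * E2 + E1 * E3 + E3 * E4

def multiplication_count(L, E1, E2, E3, E4):
    """Leading-order multiplication count for self-attention, Softmax ignored."""
    projections = 2 * L * E1 * E2 + L * E1 * E3   # Q_i, K_i and V_i over all heads
    attention = L * L * E2 + L * L * E3           # Q_i K_i^T and W_i V_i over all heads
    output = L * E3 * E4                          # [O_1, ..., O_n] W_o
    return projections + attention + output

E, L = 512, 128
params = parameter_count(E, E, E, E)
mults = multiplication_count(L, E, E, E, E)
```

With all four dimensions equal to $E$, the parameter count reduces to $4E^2$ and the multiplication count to $4E^2L+2EL^2$. The number of heads $n$ does not appear because $d_2n=E_2$ and $d_3n=E_3$ absorb it.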