今天在使用torch/numpy中的argmax碰到一点问题,花了一些时间总算理解了其和max关系,这里简单总结一下。

参考资料:

数学角度理解

从数学角度来说,这个问题很简单,假设$m$维数组$x\in \mathbb R^{n_0\times n_1\times \ldots \times n_{m-1}}$,假设在最后一个维度求$\arg\max$,那么:

实测

以$m=3$为例,考虑如下代码:

import torch

b = 2
n = 3
d = 4

x = torch.arange(b * n * d).reshape(b, n, -1)
index = torch.argmax(x, dim=-1)
print(x)

结果:

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

现在在最后一个维度求最大值,分别使用max和argmax:

print(torch.max(x, dim=-1)[0])
print(x[index])

结果如下:

tensor([[ 3,  7, 11],
        [15, 19, 23]])
Traceback (most recent call last):
  File "argmax.py", line 11, in <module>
    print(x[index])
IndexError: index 3 is out of bounds for dimension 0 with size 2

max的结果正确了,但是argmax的结果却不正确,这是什么原因呢?从报错中可以看到应该是维度不匹配