torch/numpy中max和argmax的关系
今天在使用torch/numpy中的argmax碰到一点问题,花了一些时间总算理解了其和max关系,这里简单总结一下。
参考资料:
- https://stackoverflow.com/questions/40357335/numpy-how-to-get-a-max-from-an-argmax-result
- https://numpy.org/doc/stable/reference/generated/numpy.ix_.html
- https://pytorch.org/docs/stable/generated/torch.meshgrid.html
- https://docs.scipy.org/doc/numpy-1.10.1/reference/arrays.indexing.html#advanced-indexing
数学角度理解
从数学角度来说,这个问题很简单,假设$m$维数组$x\in \mathbb R^{n_0\times n_1\times \ldots \times n_{m-1}}$,假设在最后一个维度求$\arg\max$,那么:
实测
以$m=3$为例,考虑如下代码:
import torch
b = 2
n = 3
d = 4
x = torch.arange(b * n * d).reshape(b, n, -1)
index = torch.argmax(x, dim=-1)
print(x)
结果:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
现在在最后一个维度求最大值,分别使用max和argmax:
print(torch.max(x, dim=-1)[0])
print(x[index])
结果如下:
tensor([[ 3, 7, 11],
[15, 19, 23]])
Traceback (most recent call last):
File "argmax.py", line 11, in <module>
print(x[index])
IndexError: index 3 is out of bounds for dimension 0 with size 2
max的结果正确了,但是argmax的结果却不正确,这是什么原因呢?从报错中可以看到应该是维度不匹配
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Doraemonzzz!
评论
ValineLivere