参考文章:
- https://blog.csdn.net/weixin_45999482/article/details/115728139?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_title-1&spm=1001.2101.3001.4242
- https://blog.csdn.net/ao1886/article/details/107749007?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_title-5&spm=1001.2101.3001.4242
1. 参考一
仔细看了PyTorch的文档才搞懂这两个函数
PyTorch: torch.Tensor.scatter
另一个文档: pytorch_scatter
scatter()
这个是scatter_()的out-of-place版本,即函数修改的不是原tensor
在vscode里面看这个函数有两种:
1 | scatter(self: Tensor, dim: _int, index: Tensor, src: Tensor) -> Tensor |
1 | scatter(self: Tensor, dim: _int, index: Tensor, value: Number) -> Tensor |
两个的区别在于最后一个参数,可以用Tensor作为src进行填充,也可以指定某个数值作为填充
scatter_()
一句话总结:在一个tensor的基础上,在dim
维上,根据index
选择src
的一些数填到原始的那个tensor里。
对于scatter,向原始tensor填数得到另外一个tensor,原tensor不变;对于scatter
2. 参考二
pytorch 深入理解 tensor.scatter_ ()用法
在 pytorch 库下理解 torch.tensor.scatter()的用法。作者在网上搜索了很多方法,最后还是觉得自己写一篇更为详细的比较好,转载请注明。
首先,scatter() 和 scatter_() 的作用是一样的,但是 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会修改原先的 Tensor。
1 API格式
1 | torch.Tensor.scatter_(dim, index, src) → Tensor |
字面意思:对一个 torch.Tensor 进行操作,dim,index,src三个为输入的参数。
- dim 就是在哪个维度进行操作,注意,dim 的不同,在其他条件相同的条件下得到的output 也不同。
- index 是输入的索引。
- src 就是输入的向量,也就是 input。
最后,函数返回一个 Tensor。
2 具体示例
1 | import torch as th |
下面来解释一下,b,c 内的元素分别是怎么得到的。
2.1 dim = 0 下的结果分析
先说 b,也就是 dim =0 下得到的结果。我们来看下官方给的说明文字:
1 | self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 |
因为这时 dim = 0,且只有 2 个维度,所以我们只用看第一行就行。
self [index[i][j]] [j] = src[i][j] # if dim == 0
仅用这一个公式就确定了 b 中所有元素的取值,与 a 的映射关系。这里等号左边的 self 可看做 output,也就是 b;src 是我们的输入向量,也就是 a。这里的 i,j 分别是输入向量 src 的 size 的取值。比如,本例中 a 的 size 为 (2,5),也就是说,对于 a 中的元素,i 的取值为 0,1;j 的取值为 0,1,2,3,4。a 中的元素的索引也就是(0,0),(0,1),… (0,4);(1,0),(1,1),…(1,4) 完毕,一共 2*5 = 10 个元素。
了解了这些以后,通过举例来说明 b 中的元素都是如何确定的。
1 | index = th.LongTensor([[0, 1, 2, 0, 0],[2, 0, 0, 1, 2]]), |
dim = 1的时候,同理。只是换了一种映射机制,如法炮制。
有任何关于内容不够详细,解释不清,错误等欢迎留言。转载请注明,支持原创,谢谢。