参考文章:
- https://www.cnblogs.com/52dxer/p/13771279.html
- https://blog.csdn.net/real_ilin/article/details/105874641?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_baidulandingword-1&spm=1001.2101.3001.4242
- Pytorch Tensor维度变换
- pytorch扩展tensor的一个维度或多个维度
1. 维度的扩展
函数: unsqueeze()1 | # a是一个4维的 |
输出结果
1 | a.shape |
==注意,第5维前加1维,就会出错==
1 | # print(a.unsqueeze(5).shape) |
1 | # b是一个1维的 |
输出结果
1 | b.shape |
2. 挤压维度
函数:squeeze()1 | # 挤压维度,只会挤压shape为1的维度,如果shape不是1的话,当前值就不会变 |
输出结果
1 | torch.Size([1, 32, 1, 2]) |
3. 维度扩张
函数1:expand():扩张到多少1 | # shape的扩张 |
输出结果
1 | torch.Size([1, 32, 1, 1]) |
1 | d=torch.randn([1,32,4,5]) |
输出结果
1 | torch.Size([1, 32, 4, 5]) |