0%

Pytorch之tensor维度的扩展,挤压,扩张



参考文章:

  1. https://www.cnblogs.com/52dxer/p/13771279.html
  2. https://blog.csdn.net/real_ilin/article/details/105874641?utm_medium=distribute.pc_relevant.none-task-blog-baidujs_baidulandingword-1&spm=1001.2101.3001.4242
  3. Pytorch Tensor维度变换
  4. pytorch扩展tensor的一个维度或多个维度
数据本身不发生改变,数据的访问方式发生了改变

1. 维度的扩展

函数: unsqueeze()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# a是一个4维的
a = torch.randn(4, 3, 28, 28)
print('a.shape\n', a.shape)

print('\n维度扩展(变成5维的):')
print('第0维前加1维')
print(a.unsqueeze(0).shape)
print('第4维前加1维')
print(a.unsqueeze(4).shape)
print('在-1维前加1维')
print(a.unsqueeze(-1).shape)
print('在-4维前加1维')
print(a.unsqueeze(-4).shape)
print('在-5维前加1维')
print(a.unsqueeze(-5).shape)

输出结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
a.shape
torch.Size([4, 3, 28, 28])

维度扩展(变成5维的):
第0维前加1维
torch.Size([1, 4, 3, 28, 28])
第4维前加1维
torch.Size([4, 3, 28, 28, 1])
在-1维前加1维
torch.Size([4, 3, 28, 28, 1])
在-4维前加1维
torch.Size([4, 1, 3, 28, 28])
在-5维前加1维
torch.Size([1, 4, 3, 28, 28])

==注意,第5维前加1维,就会出错==

1
2
# print(a.unsqueeze(5).shape)
# Errot:Dimension out of range (expected to be in range of -5, 4], but got 5)
连续扩维:unsqueeze()
1
2
3
4
5
6
7
8
9
10
11
12
# b是一个1维的
b = torch.tensor([1.2, 2.3])
print('b.shape\n', b.shape)
print()
# 0维之前插入1维,变成1,2]
print(b.unsqueeze(0))
print()
# 1维之前插入1维,变成2,1]
print(b.unsqueeze(1))

# 连续扩维,然后再对某个维度进行扩张
print(b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape)

输出结果

1
2
3
4
5
6
7
8
b.shape
torch.Size([2])

tensor([[1.2000, 2.3000]])

tensor([[1.2000],
[2.3000]])
torch.Size([1, 2, 1, 1])

2. 挤压维度

函数:squeeze()
1
2
3
4
5
6
7
# 挤压维度,只会挤压shape为1的维度,如果shape不是1的话,当前值就不会变
c = torch.randn(1, 32, 1, 2)
print(c.shape)
print(c.squeeze(0).shape)
print(c.squeeze(1).shape) # shape不是1,不会变
print(c.squeeze(2).shape)
print(c.squeeze(3).shape) # shape不是1,不会变

输出结果

1
2
3
4
5
torch.Size([1, 32, 1, 2])
torch.Size([32, 1, 2])
torch.Size([1, 32, 1, 2])
torch.Size([1, 32, 2])
torch.Size([1, 32, 1, 2])

3. 维度扩张

函数1:expand():扩张到多少
1
2
3
4
5
6
# shape的扩张
# expand():对shape为1的进行扩展,对shape不为1的只能保持不变,因为不知道如何变换,会报错

d = torch.randn(1, 32, 1, 1)
print(d.shape)
print(d.expand(4, 32, 14, 14).shape)

输出结果

1
2
torch.Size([1, 32, 1, 1])
torch.Size([4, 32, 14, 14])
函数2:repeat()方法,扩张多少倍
1
2
3
d=torch.randn([1,32,4,5])
print(d.shape)
print(d.repeat(4,32,2,3).shape)

输出结果

1
2
torch.Size([1, 32, 4, 5])
torch.Size([4, 1024, 8, 15])