0%

Pytorch之cat()函数



写在前面:

参考文章:

  1. https://www.cnblogs.com/zhaoyingjie/p/14636468.html

1. cat函数

cat是concatnate的意思:拼接,联系在一起。

先说cat( )的普通用法

如果我们有两个tensor是A和B,想把他们拼接在一起,需要如下操作:

1
2
C = torch.cat( (A,B),0 )  #按维数0拼接(竖着拼)
C \= torch.cat( (A,B),1 ) #按维数1拼接(横着拼)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
\>>> import torch
\>>> A=torch.ones(2,3) #2x3的张量(矩阵)
>>> A
tensor(\[\[ 1., 1., 1.\],
\[ 1., 1., 1.\]\])
\>>> B=2\*torch.ones(4,3) #4x3的张量(矩阵)
>>> B
tensor(\[\[ 2., 2., 2.\],
\[ 2., 2., 2.\],
\[ 2., 2., 2.\],
\[ 2., 2., 2.\]\])
\>>> C=torch.cat((A,B),0) #按维数0(行)拼接
>>> C
tensor(\[\[ 1., 1., 1.\],
\[ 1., 1., 1.\],
\[ 2., 2., 2.\],
\[ 2., 2., 2.\],
\[ 2., 2., 2.\],
\[ 2., 2., 2.\]\])
\>>> C.size()
torch.Size(\[6, 3\])
\>>> D=2\*torch.ones(2,4) #2x4的张量(矩阵)
>>> C=torch.cat((A,D),1)#按维数1(列)拼接
>>> C
tensor(\[\[ 1., 1., 1., 2., 2., 2., 2.\],
\[ 1., 1., 1., 2., 2., 2., 2.\]\])
\>>> C.size()
torch.Size(\[2, 7\])