torch模块下的数学操作符
1 . torch.numel() 返回一个tensor变量内所有元素个数,可以理解为矩阵内元素的个数
2 . torch.squeeze()
对于tensor变量进行维度压缩,去除维数为1的的维度。例如一矩阵维度为A*1*B*C*1*D,通过squeeze()返回向量的维度为A*B*C*D。squeeze(a),表示将a的维数位1的维度删掉,squeeze(a,N)表示,如果第N维维数为1,则压缩去掉,否则a矩阵不变
3 . torch.unsqueeze()
是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加
4 . torch.stack(sequence, dim=0,
out=None),做tensor的拼接。sequence表示Tensor列表,dim表示拼接的维度,注意这个函数和concatenate是不同的,torch的concatenate函数是torch.cat,是在已有的维度上拼接,而stack是建立一个新的维度,然后再在该纬度上进行拼接。
例子:
import torch a=torch.Tensor([[1,2,3],[4,5,6]]) b=torch.Tensor(
[[7,8,9],[10,11,12]]) d=torch.stack( (a,b) ,dim = 1) print(d)
输出:
tensor([[[ 1., 2., 3.], [ 7., 8., 9.]], [[ 4., 5., 6.], [ 10., 11., 12.]]])
5 .
expand_as(a)这是tensor变量的一个内置方法,如果使用b.expand_as(a)就是将b进行扩充,扩充到a的维度,需要说明的是a的低维度需要比b大,例如b的shape是3*1,如果a的shape是3*2不会出错,但是是2*2就会报错了
热门工具 换一换