pytorch中torch.stack()函数用法解读
作者:RealWeakCoder
这篇文章主要介绍了pytorch中torch.stack()函数用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
torch.stack()函数用法
一、基本功能
pytroch官方文档对于这个函数的描述很简略。
只有一句话:
在维度上连接(concatenate)若干个张量。(这些张量形状相同)。
经过代码总结归纳,可以得到stack(tensors,dim=0,out=None)
函数的功能:
将若干个张量在dim维度上连接,生成一个扩维的张量,比如说原来你有若干个2维张量,连接可以得到一个3维的张量。
设待连接张量维度为n,dim取值范围为-n-1~n,这里得提一下为负的意义:-i为倒数第i个维度。
举个例子:
对于2维的待连接张量,-1维即3维,-2维即2维。
上代码:
a=torch.tensor([[1,2,3],[4,5,6]]) b=torch.tensor([[10,20,30],[40,50,60]]) c=torch.tensor([[100,200,300],[400,500,600]]) print(torch.stack([a,b,c],dim=0)) print(torch.stack([a,b,c],dim=1)) print(torch.stack([a,b,c],dim=2)) print(torch.stack([a,b,c],dim=0).size()) print(torch.stack([a,b,c],dim=1).size()) print(torch.stack([a,b,c],dim=2).size()) #输出结果为: tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 10, 20, 30], [ 40, 50, 60]], [[100, 200, 300], [400, 500, 600]]]) tensor([[[ 1, 2, 3], [ 10, 20, 30], [100, 200, 300]], [[ 4, 5, 6], [ 40, 50, 60], [400, 500, 600]]]) tensor([[[ 1, 10, 100], [ 2, 20, 200], [ 3, 30, 300]], [[ 4, 40, 400], [ 5, 50, 500], [ 6, 60, 600]]]) torch.Size([3, 2, 3]) torch.Size([2, 3, 3]) torch.Size([2, 3, 3])
二、规律分析
通过代码运行结果,我们不难发现,stack(tensors,dim=0,out=None)
函数的运行机制可以等价为:
- dim=0时,将tensor在一维上连接,简单来说就是,就是将tensor1,tensor2…tensor n,连接为【tensor1,tensor2… tensor n】(就是在这里产生了扩维)
- dim=1时,将每个tensor的第i行按行连接组成一个新的2维tensor,再将这些新tensor按照dim=0的方式连接。
- dim=2时,将每个tensor的第i行转置后按列连接组成一个新的2维tensor,再将这些新tesnor按照dim=0的方式连接
可以得到一个结论:n维(n>=2)待连接张量按dim=x的方式连接等价于:
- 若x=0,参照上面的规律进行连接
- 若x>0,对每个张量的第一个维度下的张量对应地按照dim=x-1的方式进行连接得到若干个新张量,这些新张量按照dim=0的方式进行连接。
- 很明显,该规律具有递归的特性,x=0,1,2的基础情况已经给出。
注:以上规律是在未看函数实现源码基础上未加证明的猜测。
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。