python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch中torch.stack()函数

pytorch中torch.stack()函数用法解读

作者:RealWeakCoder

这篇文章主要介绍了pytorch中torch.stack()函数用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch.stack()函数用法

一、基本功能

pytroch官方文档对于这个函数的描述很简略。

只有一句话:

在维度上连接(concatenate)若干个张量。(这些张量形状相同)。

经过代码总结归纳,可以得到stack(tensors,dim=0,out=None)函数的功能:

将若干个张量在dim维度上连接,生成一个扩维的张量,比如说原来你有若干个2维张量,连接可以得到一个3维的张量。

设待连接张量维度为n,dim取值范围为-n-1~n,这里得提一下为负的意义:-i为倒数第i个维度。

举个例子:

对于2维的待连接张量,-1维即3维,-2维即2维。

上代码:

a=torch.tensor([[1,2,3],[4,5,6]])
b=torch.tensor([[10,20,30],[40,50,60]])
c=torch.tensor([[100,200,300],[400,500,600]])
print(torch.stack([a,b,c],dim=0))
print(torch.stack([a,b,c],dim=1))
print(torch.stack([a,b,c],dim=2))
print(torch.stack([a,b,c],dim=0).size())
print(torch.stack([a,b,c],dim=1).size())
print(torch.stack([a,b,c],dim=2).size())
#输出结果为:
tensor([[[  1,   2,   3],
         [  4,   5,   6]],

        [[ 10,  20,  30],
         [ 40,  50,  60]],

        [[100, 200, 300],
         [400, 500, 600]]])
tensor([[[  1,   2,   3],
         [ 10,  20,  30],
         [100, 200, 300]],

        [[  4,   5,   6],
         [ 40,  50,  60],
         [400, 500, 600]]])
tensor([[[  1,  10, 100],
         [  2,  20, 200],
         [  3,  30, 300]],

        [[  4,  40, 400],
         [  5,  50, 500],
         [  6,  60, 600]]])
torch.Size([3, 2, 3])
torch.Size([2, 3, 3])
torch.Size([2, 3, 3])

二、规律分析

通过代码运行结果,我们不难发现,stack(tensors,dim=0,out=None)函数的运行机制可以等价为:

可以得到一个结论:n维(n>=2)待连接张量按dim=x的方式连接等价于:

注:以上规律是在未看函数实现源码基础上未加证明的猜测。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文