python pytorch中.view()函数的用法解读
作者:Dust_Evc
这篇文章主要介绍了python pytorch中.view()函数的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
python pytorch中.view()函数
在使用pytorch定义神经网络时,经常会看到类似如下的.view()用法,这里对其用法做出讲解与演示。
普通用法 (手动调整size)
view()相当于reshape、resize,重新调整Tensor的形状。
import torch a1 = torch.arange(0,16) print(a1)
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
a2 = a1.view(8, 2) a3 = a1.view(2, 8) a4 = a1.view(4, 4) print(a2) print(a3) print(a4)
tensor([[ 0, 1], [ 2, 3], [ 4, 5], [ 6, 7], [ 8, 9], [10, 11], [12, 13], [14, 15]]) tensor([[ 0, 1, 2, 3, 4, 5, 6, 7], [ 8, 9, 10, 11, 12, 13, 14, 15]]) tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15]])
特殊用法:参数-1 (自动调整size)
view中一个参数定为-1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。
import torch a1 = torch.arange(0,16) print(a1)
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
a2 = a1.view(-1, 16) a3 = a1.view(-1, 8) a4 = a1.view(-1, 4) a5 = a1.view(-1, 2) a6 = a1.view(4*4, -1) a7 = a1.view(1*4, -1) a8 = a1.view(2*4, -1) print(a2) print(a3) print(a4) print(a5) print(a6) print(a7) print(a8)
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]) tensor([[ 0, 1, 2, 3, 4, 5, 6, 7], [ 8, 9, 10, 11, 12, 13, 14, 15]]) tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15]]) tensor([[ 0, 1], [ 2, 3], [ 4, 5], [ 6, 7], [ 8, 9], [10, 11], [12, 13], [14, 15]]) tensor([[ 0], [ 1], [ 2], [ 3], [ 4], [ 5], [ 6], [ 7], [ 8], [ 9], [10], [11], [12], [13], [14], [15]]) tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15]]) tensor([[ 0, 1], [ 2, 3], [ 4, 5], [ 6, 7], [ 8, 9], [10, 11], [12, 13], [14, 15]])
python中view()函数怎么用
初学者在使用pytorch框架定义神经网络时,经常会在代码中看到:
这样的用法。
view()的作用相当于numpy中的reshape,重新定义矩阵的形状。
例1 普通用法:
import torch v1 = torch.range(1, 16) v2 = v1.view(4, 4)
其中v1为1*16大小的张量,包含16个元素。v2为4*4大小的张量,同样包含16个元素。注意view前后的元素个数要相同,不然会报错。
例2 参数使用-1
import torch v1 = torch.range(1, 16) v2 = v1.view(-1, 4)
和图例中的用法一样,view中一个参数定为-1,代表动态调整这个维度上的元素个数,以保证元素的总数不变。因此两个例子的结果是相同的。
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。