PyTorch中flatten() 函数的用法实例小结
作者:纽约恋情
在PyTorch中,flatten函数的作用是将一个多维的张量转换为一维的向量,它可以将任意形状的张量转换为一维,而不需要指定转换后的大小,这篇文章主要介绍了PyTorch中flatten() 函数的用法,需要的朋友可以参考下
一. 用法
Flatten层主要是用来将输入“压平”,即把多维的输入一维化,用在卷积层到全连接层的过渡。其不会影响batch的大小,可以理解为把高纬度的数组按照x轴或者y轴进行拉伸,变成一维的数组。
二. 参数
1.start_dim(可选参数):指定从哪个维度开始展平张量。默认情况下,start_dim
被设置为0,表示从第一个维度(通常是批大小)开始展平。如果设置为其他整数值,则会从指定的维度开始展平。
2.end_dim(可选参数):指定在哪个维度结束展平张量。默认情况下,end_dim
被设置为-1,表示展平直到最后一个维度。如果设置为其他整数值,则会在指定的维度结束展平。
三. 实例
(1). 首先随机定义一个满足正态分布的(2,3,4)的数据x
import torch x = torch.randn(2,3,4) print(x) x = x.flatten(0) print(x) ------------------------------------ tensor([[[ 0.1281, 1.6878, 0.2301, -0.0721], [ 1.2374, -0.6929, 1.1186, 0.4372], [ 0.5122, 1.4653, -0.1673, 0.7258]], [[ 0.2772, -1.9994, -1.2284, 0.2764], [-0.0451, -0.9195, 0.5749, 0.1942], [ 0.8539, -0.0434, -0.7313, 0.0234]]]) tensor([ 0.1281, 1.6878, 0.2301, -0.0721, 1.2374, -0.6929, 1.1186, 0.4372, 0.5122, 1.4653, -0.1673, 0.7258, 0.2772, -1.9994, -1.2284, 0.2764, -0.0451, -0.9195, 0.5749, 0.1942, 0.8539, -0.0434, -0.7313, 0.0234]) import torch x = torch.randn(2,3,4) print(x) x = x.flatten(0) print(x) ------------------------------------ tensor([[[ 0.1281, 1.6878, 0.2301, -0.0721], [ 1.2374, -0.6929, 1.1186, 0.4372], [ 0.5122, 1.4653, -0.1673, 0.7258]], [[ 0.2772, -1.9994, -1.2284, 0.2764], [-0.0451, -0.9195, 0.5749, 0.1942], [ 0.8539, -0.0434, -0.7313, 0.0234]]]) tensor([ 0.1281, 1.6878, 0.2301, -0.0721, 1.2374, -0.6929, 1.1186, 0.4372, 0.5122, 1.4653, -0.1673, 0.7258, 0.2772, -1.9994, -1.2284, 0.2764, -0.0451, -0.9195, 0.5749, 0.1942, 0.8539, -0.0434, -0.7313, 0.0234])
此时x的维度是2×3×4=24,x = flatten(0) 和 x = flatten()的结果相同。
(2).
import torch x = torch.randn(2,3,4) print(x) x = x.flatten(1) print(x) =========================================== tensor([[[-0.7137, -0.0859, -1.5284, 0.7284], [ 0.8425, 0.3606, 1.7639, 0.1848], [ 0.4040, -1.6575, 1.9134, -1.0787]], [[ 0.6981, 1.3494, -0.5817, -1.1824], [-0.4972, 0.4179, 2.1742, -0.2462], [ 0.2429, -1.9315, -0.3497, 0.7190]]]) tensor([[-0.7137, -0.0859, -1.5284, 0.7284, 0.8425, 0.3606, 1.7639, 0.1848, 0.4040, -1.6575, 1.9134, -1.0787], [ 0.6981, 1.3494, -0.5817, -1.1824, -0.4972, 0.4179, 2.1742, -0.2462, 0.2429, -1.9315, -0.3497, 0.7190]])
此时x是从1维度开始展开,最后的x维度为(2,3×4),也就是(2,12)
注意:start_dim
和end_dim
参数的取值范围应该在 -x.dim() <= start_dim <= end_dim < x.dim()
之间。
到此这篇关于PyTorch中flatten() 函数的用法的文章就介绍到这了,更多相关PyTorch flatten() 函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!