python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > torch.flatten()函数及x=x.view()函数理解

关于torch.flatten()函数及x=x.view()函数的理解

作者:浩瀚之水_csdn

这篇文章主要介绍了关于torch.flatten()函数及x=x.view()函数的理解,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

一、x = x.view()

x = x.view(x.size(0), -1)

在PyTorch中,x.view(x.size(0), -1)是一种常用的操作,用于改变张量(Tensor)的形状而不改变其数据。

这里的x是一个多维张量,而.view()函数是用来重新塑形这个张量的,同时保持其元素的总数不变。

具体来说,x.view(x.size(0), -1)的含义是:

因此,x.view(x.size(0), -1)的作用是将张量x重新塑形为一个二维张量,其中第一维的大小保持不变(即原始张量的第一个维度的大小),而第二维的大小则自动调整,以包含所有剩余的元素。

这种操作在需要将多维数据“展平”为二维数据以进行某些操作(如全连接层)时非常有用。

例如,如果x是一个形状为(64, 3, 28, 28)的张量(通常表示一个包含64个图像,每个图像有3个颜色通道,每个通道的大小为28x28像素的数据集),那么x.view(x.size(0), -1)将会把x重新塑形为一个形状为(64, 3*28*28)的张量,其中每个样本都被展平成了一个长向量。

二、torch.flatten()函数

x = torch.flatten(x, start_dim=0, end_dim=2)

x = torch.flatten(x, 0)

当你使用 x = torch.flatten(x, 0) 时,这里的 0start_dim 参数的值,而 end_dim 参数仍然默认为 -1。这意呀着展平操作将从张量 x 的第一个维度(索引为0的维度)开始,并且一直进行到张量的最后一个维度。

然而,由于 start_dim 被设置为0,并且 end_dim 默认为 -1,实际上这会将整个张量 x 完全展平为一个一维张量。换句话说,无论原始张量 x 的形状如何,调用 torch.flatten(x, 0) 后,x 将变成一个一维张量,其长度等于原始张量中所有元素的总数。

例如,如果原始张量 x 的形状是 (a, b, c, d),那么调用 x = torch.flatten(x, 0) 后,x 的新形状将是 (a*b*c*d,),即一个包含 a*b*c*d 个元素的一维张量。

这种完全展平的操作在需要将多维数据转换为适合某些特定操作(如完全连接层的前馈传播)的一维形式时非常有用。然而,它也意味着你丢失了原始数据的形状信息,除非你在其他地方记录了这些信息或者你的操作不需要保留这些形状信息。

x = torch.flatten(x, 1)

在PyTorch中,torch.flatten(x, start_dim=0, end_dim=-1)函数用于将张量x在指定的维度范围内展平(或扁平化),而不改变其数据。这里的start_dim是开始展平的维度(包含该维度),end_dim是结束展平的维度(不包含该维度),默认情况下end_dim为-1,即最后一个维度。

当你使用x = torch.flatten(x, 1)时,你告诉PyTorch从第二个维度(索引为1,因为索引是从0开始的)开始,一直到最后一个维度,将所有的这些维度都展平成一个维度。这意味着,如果x是一个多维张量,那么除了第一个维度之外的所有维度都将被合并成一个维度。

例如,如果x的形状是(64, 3, 28, 28)(代表64个图像,每个图像有3个颜色通道,每个通道的大小为28x28像素),那么x = torch.flatten(x, 1)将会把x展平成一个形状为(64, 3*28*28)的张量。这里,第一个维度(样本数64)保持不变,而剩下的三个维度(3, 28, 28)被合并成了一个维度。

这种操作在处理图像数据时特别有用,尤其是在需要将图像数据传递给全连接层之前,因为全连接层通常期望输入是二维的(尽管在实践中,通常会先通过一个或多个卷积层来处理图像数据)。通过展平操作,你可以将多维的图像数据转换成二维的形式,以便进行后续处理。

x = torch.flatten(x, 2)

在PyTorch中,torch.flatten(input, start_dim=0, end_dim=-1) 函数用于将多维张量(tensor)展平(flatten)为一维张量,但你可以通过指定start_dimend_dim参数来控制从哪一维度开始展平,以及在哪一维度结束(不包括该维度)。这意味着你可以保留张量的某些维度不变,而将其他维度展平。

对于你的代码 x = torch.flatten(x, 2),这里:

因此,torch.flatten(x, 2) 的意思是从张量x的第3维(因为索引从0开始)开始,将之后的所有维度都展平成一个维度。如果x的形状是例如 (a, b, c, d, e),那么torch.flatten(x, 2)之后,x的形状将变为 (a, b, c*d*e)。这里,ab维度保持不变,而cde三个维度被合并成了一个新的维度。

这种操作在处理多维数据时非常有用,特别是当你需要将一部分数据的维度保持不变,而将其他部分数据“展平”以便于后续处理(如全连接层处理)时。

三、示例

import torch

A = torch.tensor([[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[13,14,15,16],[17,18,19,20],[21,22,23,24]]])
# print(A.size)
print(A.shape)

B = torch.flatten(A,1)
print(B.shape)
# print(B)

C = torch.flatten(A,0,1)
print(C.shape)
# print(C)

D = torch.flatten(A,2)
print(D.shape)
# print(D)

E = torch.flatten(A,0)
print(E.shape)
# print(E)

F = A.view(A.size(0), -1)
print(F.shape)
# print(F)

G = A.view(A.size(0), -1, 1)
print(G.shape)

输出:

torch.Size([2, 3, 4])
torch.Size([2, 12])
torch.Size([6, 4])
torch.Size([2, 3, 4])
torch.Size([24])
torch.Size([2, 12])
torch.Size([2, 12, 1])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文