python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch reshape

pytorch中reshape的使用小结

作者:Caven77

本文主要介绍了pytorch中reshape的使用小结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

torch.reshape 是 PyTorch 中用于改变张量形状的函数。它的作用是重新安排张量的维度,使其符合指定的形状,但不会改变数据的顺序。

Returns a tensor with the same data and number of elements as input, but with the specified shape.

基本语法

input_reshape = torch.reshape(input, shape)

或者:

input_reshape = input.reshape(shape)

例子

import torch

# 创建一个形状为 (4, 3) 的张量
x = torch.arange(12).reshape(4, 3)
print(x)

输出:

tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

你可以将这个张量重新调整成其他形状,比如将它变成一个 (2, 6) 的张量:

y = x.reshape(2, 6)
print(y)

输出:

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])

特别的用法

使用1自动推算维度:

如果你不确定某个维度应该是多少,可以使用 -1 来自动推算。例如,如果你只知道张量的总元素数和某些维度,可以让 PyTorch 自动计算某个维度。

A single dimension may be -1, in which case it’s inferred from the remaining dimensions and the number of elements in input.

z = x.reshape(2, -1)  # PyTorch 会自动推算第二维度的大小
print(z)

输出:

tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])

torch.reshape(x, (-1,))将张量展平成一维(即向量)等效于x.flatten()

注意:这跟 (1,-1)有区别,(-1,)是展平成一维向量,
而(1,-1)是1*n,这被视作二维的(一个维度是1个数,另一个维度是n个数)

例子:

x = torch.tensor([[0, 1],
                  [2, 3]])

x_vec = x.reshape(-1,)
print(x_vec)

x_matrix = x.reshape(1,-1)
print(x_matrix)

输出:

tensor([0, 1, 2, 3])
tensor([[0, 1, 2, 3]])

注意事项

保持元素数量一致:

使用 reshape 时,新形状的元素数量必须和原形状一致。例如,原来是 (4, 3),总共有 12 个元素,不能reshape成 (3, 5) 因为会丢失元素。

# 错误的 reshape
# x.reshape(3, 5)  # 会抛出错误,无法reshape

总之,要接收返回的tensor

When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.

到此这篇关于pytorch中reshape的使用小结的文章就介绍到这了,更多相关pytorch reshape内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文