python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > PyTorch tensor.squeeze()

PyTorch中tensor.squeeze() 使用小结

作者:Geoking.

本文主要介绍了PyTorch中tensor.squeeze()使用小结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

在深度学习中经常会处理各种形状(shape)复杂的张量(tensor)。
有时候,模型的输入或输出会多出一些没用的“维度”,例如 (1, 3, 1, 224, 224)
这时,PyTorch 提供了一个非常实用的函数 —— torch.squeeze(),可以帮我们轻松去除大小为 1 的维度

本文将带你从入门到实战,彻底掌握 squeeze() 的使用方法与常见坑。🚀

一、函数简介

官方定义

torch.squeeze(input, dim=None) → Tensor

作用:

返回一个新的张量,去掉所有大小为 1 的维度。
如果指定了 dim 参数,则只会在那个维度上去除大小为 1 的维度。

squeeze 本身英文释义如下:

二、为什么需要squeeze()?

在深度学习中,模型输入输出的维度往往需要严格匹配。
但是,在数据加载或卷积操作之后,可能会出现一些“冗余维度”。

举个例子:

import torch

x = torch.randn(1, 3, 1, 4)
print(x.shape)

输出:

torch.Size([1, 3, 1, 4])

可以看到,这个张量的第 0 和第 2 维都是大小为 1 的“空维度”。
这些维度不会存储实际信息,但可能会导致维度不匹配错误。

这时,我们就可以用:

x.squeeze()

输出:

torch.Size([3, 4])

✅ 所有大小为 1 的维度都被自动去掉!

三、函数语法与参数说明

参数类型说明
inputTensor输入张量
dimint, 可选指定要压缩的维度
返回值Tensor新张量(共享存储,不复制数据)

四、示例讲解

1、去除所有大小为 1 的维度

x = torch.randn(1, 3, 1, 4, 1)
print("原形状:", x.shape)
y = torch.squeeze(x)
print("压缩后:", y.shape)

输出:

原形状: torch.Size([1, 3, 1, 4, 1])
压缩后: torch.Size([3, 4])

说明:
squeeze() 去掉了所有维度为 1 的轴。

2、指定某个维度压缩

有时候我们不想去掉所有维度,只想处理特定的一个。

x = torch.randn(1, 3, 1, 4)
print("原形状:", x.shape)
y = torch.squeeze(x, dim=0)
print("压缩后:", y.shape)

输出:

原形状: torch.Size([1, 3, 1, 4])
压缩后: torch.Size([3, 1, 4])

👉 只去掉了第 0 维,因为它的大小是 1。
其他维度保持不变。

3、如果指定的维度不是 1,会怎样?

x = torch.randn(2, 1, 3)
y = torch.squeeze(x, dim=0)
print(y.shape)

输出:

torch.Size([2, 1, 3])

没有任何变化,因为第 0 维的大小是 2,不是 1
squeeze() 只会压缩大小为 1 的维度,不会报错。

五、与unsqueeze()的关系

如果说 squeeze() 是“去掉维度”,
unsqueeze() 就是“增加维度”。

x = torch.tensor([1, 2, 3])
print(x.shape)  # torch.Size([3])

y = x.unsqueeze(0)
print(y.shape)  # torch.Size([1, 3])

z = y.squeeze(0)
print(z.shape)  # torch.Size([3])

unsqueeze()squeeze() 是一对反操作。
一个增加维度,一个去除维度。

六、常见应用场景

1、数据集加载时去掉多余维度

# 读取图片后通常是 (1, H, W)
img = torch.randn(1, 224, 224)
img = img.squeeze(0)
print(img.shape)  # torch.Size([224, 224])

2、模型输出后去掉 batch 维度

# 例如分类模型输出 [1, num_classes]
output = torch.randn(1, 10)
pred = output.squeeze(0)
print(pred.shape)  # torch.Size([10])

3、多维卷积层结果调整

Conv2dLSTM 等层输出中,有时需要将 [batch, seq_len, 1] 变成 [batch, seq_len]

out = torch.randn(32, 100, 1)
out = out.squeeze(-1)
print(out.shape)  # torch.Size([32, 100])

⚠️ 七、注意事项与坑点

问题说明
❌ 误删维度默认不传 dim 会删除所有大小为 1 的维度,可能导致形状变化过多
✅ 建议当只想去掉某个维度时,一定要写 dim 参数
⚙️ 内存共享squeeze() 返回的张量与原张量共享内存,不会复制数据

八、扩展:与 NumPy 对比

PyTorch 的 squeeze() 和 NumPy 的 numpy.squeeze() 功能几乎一致。

import numpy as np

a = np.random.randn(1, 3, 1, 4)
print(a.shape)  # (1, 3, 1, 4)
print(a.squeeze().shape)  # (3, 4)

如果熟悉 NumPy 的用法,PyTorch 中也能无缝衔接。

九、总结

功能说明
函数torch.squeeze(input, dim=None)
作用删除大小为 1 的维度
参数dim:指定要压缩的维度(可选)
返回新张量(共享内存)
反操作unsqueeze()
常用场景模型输出处理、数据预处理、维度调整

十、参考资料

到此这篇关于PyTorch中tensor.squeeze() 使用小结的文章就介绍到这了,更多相关PyTorch tensor.squeeze()内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文