PyTorch中tensor.squeeze() 使用小结
作者:Geoking.
在深度学习中经常会处理各种形状(shape)复杂的张量(tensor)。
有时候,模型的输入或输出会多出一些没用的“维度”,例如 (1, 3, 1, 224, 224)。
这时,PyTorch 提供了一个非常实用的函数 —— torch.squeeze(),可以帮我们轻松去除大小为 1 的维度。
本文将带你从入门到实战,彻底掌握 squeeze() 的使用方法与常见坑。🚀
一、函数简介
官方定义
torch.squeeze(input, dim=None) → Tensor
作用:
返回一个新的张量,去掉所有大小为 1 的维度。
如果指定了 dim 参数,则只会在那个维度上去除大小为 1 的维度。
squeeze 本身英文释义如下:

二、为什么需要squeeze()?
在深度学习中,模型输入输出的维度往往需要严格匹配。
但是,在数据加载或卷积操作之后,可能会出现一些“冗余维度”。
举个例子:
import torch x = torch.randn(1, 3, 1, 4) print(x.shape)
输出:
torch.Size([1, 3, 1, 4])
可以看到,这个张量的第 0 和第 2 维都是大小为 1 的“空维度”。
这些维度不会存储实际信息,但可能会导致维度不匹配错误。
这时,我们就可以用:
x.squeeze()
输出:
torch.Size([3, 4])
✅ 所有大小为 1 的维度都被自动去掉!
三、函数语法与参数说明
| 参数 | 类型 | 说明 |
|---|---|---|
| input | Tensor | 输入张量 |
| dim | int, 可选 | 指定要压缩的维度 |
| 返回值 | Tensor | 新张量(共享存储,不复制数据) |
四、示例讲解
1、去除所有大小为 1 的维度
x = torch.randn(1, 3, 1, 4, 1)
print("原形状:", x.shape)
y = torch.squeeze(x)
print("压缩后:", y.shape)
输出:
原形状: torch.Size([1, 3, 1, 4, 1])
压缩后: torch.Size([3, 4])
说明:squeeze() 去掉了所有维度为 1 的轴。
2、指定某个维度压缩
有时候我们不想去掉所有维度,只想处理特定的一个。
x = torch.randn(1, 3, 1, 4)
print("原形状:", x.shape)
y = torch.squeeze(x, dim=0)
print("压缩后:", y.shape)
输出:
原形状: torch.Size([1, 3, 1, 4])
压缩后: torch.Size([3, 1, 4])
👉 只去掉了第 0 维,因为它的大小是 1。
其他维度保持不变。
3、如果指定的维度不是 1,会怎样?
x = torch.randn(2, 1, 3) y = torch.squeeze(x, dim=0) print(y.shape)
输出:
torch.Size([2, 1, 3])
没有任何变化,因为第 0 维的大小是 2,不是 1。squeeze() 只会压缩大小为 1 的维度,不会报错。
五、与unsqueeze()的关系
如果说 squeeze() 是“去掉维度”,
那 unsqueeze() 就是“增加维度”。
x = torch.tensor([1, 2, 3]) print(x.shape) # torch.Size([3]) y = x.unsqueeze(0) print(y.shape) # torch.Size([1, 3]) z = y.squeeze(0) print(z.shape) # torch.Size([3])
✅ unsqueeze() 与 squeeze() 是一对反操作。
一个增加维度,一个去除维度。
六、常见应用场景
1、数据集加载时去掉多余维度
# 读取图片后通常是 (1, H, W) img = torch.randn(1, 224, 224) img = img.squeeze(0) print(img.shape) # torch.Size([224, 224])
2、模型输出后去掉 batch 维度
# 例如分类模型输出 [1, num_classes] output = torch.randn(1, 10) pred = output.squeeze(0) print(pred.shape) # torch.Size([10])
3、多维卷积层结果调整
在 Conv2d、LSTM 等层输出中,有时需要将 [batch, seq_len, 1] 变成 [batch, seq_len]。
out = torch.randn(32, 100, 1) out = out.squeeze(-1) print(out.shape) # torch.Size([32, 100])
⚠️ 七、注意事项与坑点
| 问题 | 说明 |
|---|---|
| ❌ 误删维度 | 默认不传 dim 会删除所有大小为 1 的维度,可能导致形状变化过多 |
| ✅ 建议 | 当只想去掉某个维度时,一定要写 dim 参数 |
| ⚙️ 内存共享 | squeeze() 返回的张量与原张量共享内存,不会复制数据 |
八、扩展:与 NumPy 对比
PyTorch 的 squeeze() 和 NumPy 的 numpy.squeeze() 功能几乎一致。
import numpy as np a = np.random.randn(1, 3, 1, 4) print(a.shape) # (1, 3, 1, 4) print(a.squeeze().shape) # (3, 4)
如果熟悉 NumPy 的用法,PyTorch 中也能无缝衔接。
九、总结
| 功能 | 说明 |
|---|---|
| 函数 | torch.squeeze(input, dim=None) |
| 作用 | 删除大小为 1 的维度 |
| 参数 | dim:指定要压缩的维度(可选) |
| 返回 | 新张量(共享内存) |
| 反操作 | unsqueeze() |
| 常用场景 | 模型输出处理、数据预处理、维度调整 |
十、参考资料

到此这篇关于PyTorch中tensor.squeeze() 使用小结的文章就介绍到这了,更多相关PyTorch tensor.squeeze()内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
您可能感兴趣的文章:
- pytorch中tensor.expand()和tensor.expand_as()函数详解
- PyTorch中 tensor.detach() 和 tensor.data 的区别详解
- pytorch中Tensor.to(device)和model.to(device)的区别及说明
- pytorch中torch.max和Tensor.view函数用法详解
- PyTorch中tensor.backward()函数的详细介绍及功能实现
- PyTorch中关于tensor.repeat()的使用
- PyTorch中 tensor.detach() 和 tensor.data 的区别解析
- pytorch中Tensor.new()的使用解析
- pytorch中函数tensor.numpy()的数据类型解析
