pytorch torch.gather函数的使用
作者:qq_27390023
torch.gather 是 PyTorch 中用于在指定维度上通过索引从源张量中提取元素的函数,它需要输入张量、维度索引和索引张量,示例代码展示了如何使用 torch.gather 从输入张量中按索引提取元素,返回的结果张量形状与索引张量相同
pytorch torch.gather函数
torch.gather
是 PyTorch 中的一个用于从给定维度上按索引取值的函数。
它根据一个索引张量 index
,从源张量 input
中收集值,并返回一个新的张量。
torch.gather
常用于需要从张量的特定位置抽取元素的操作。
1. 函数签名
torch.gather(input, dim, index, *, sparse_grad=False, out=None)
input
:输入张量,表示要从中收集元素的源张量。dim
:要收集的维度索引。例如,对于一个二维张量,0 表示沿着行的维度,1 表示沿着列的维度。index
:索引张量,其形状应与input
张量在除了dim
维度之外的其他维度上保持一致。索引张量中的值表示在input
张量对应维度上要收集的元素的索引。out
(可选):输出张量,如果提供,结果将存储在这个张量中。
2. 工作原理
torch.gather
在 dim
维度上,通过 index
指定的索引,从 input
中选取元素。
返回的张量的形状与 index
的形状相同。
3. 示例代码
以下是一个简单的示例代码,演示如何使用 torch.gather
函数:
import torch # 创建一个源张量 input = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # 创建一个索引张量 index = torch.tensor([[0, 2, 1], [2, 0, 1], [1, 2, 0]]) # 在 dim=1 维度上使用 gather 函数 result = torch.gather(input, dim=1, index=index) print("Input Tensor:") print(input) print("\nIndex Tensor:") print(index) print("\nResult Tensor:") print(result)
4. 输出结果
Input Tensor:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])Index Tensor:
tensor([[0, 2, 1],
[2, 0, 1],
[1, 2, 0]])Result Tensor:
tensor([[1, 3, 2],
[6, 4, 5],
[8, 9, 7]])
5. 解释
- 输入张量 (
input
) 是一个3x3
的矩阵,每个元素代表一个值。 - 索引张量 (
index
) 指定了要从input
中提取的元素的索引。 - 结果张量 (
result
) 是根据index
从input
中提取的元素形成的张量。
在这个例子中:
- 对于
input
的第一行,index
提取了索引0, 2, 1
对应的元素1, 3, 2
。 - 对于
input
的第二行,index
提取了索引2, 0, 1
对应的元素6, 4, 5
。 - 对于
input
的第三行,index
提取了索引1, 2, 0
对应的元素8, 9, 7
。
总结
torch.gather
通过索引在指定维度上提取张量中的元素,是用于基于索引选择数据的有用工具。
函数对批处理数据特别有用,例如在分类任务中提取对应类别的概率或得分。
索引张量的形状必须与源张量在指定维度的形状相匹配,以确保正确的取值操作。
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。