python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch torch.gather函数

pytorch torch.gather函数的使用

作者:qq_27390023

torch.gather 是 PyTorch 中用于在指定维度上通过索引从源张量中提取元素的函数,它需要输入张量、维度索引和索引张量,示例代码展示了如何使用 torch.gather 从输入张量中按索引提取元素,返回的结果张量形状与索引张量相同

pytorch torch.gather函数

torch.gather 是 PyTorch 中的一个用于从给定维度上按索引取值的函数。

它根据一个索引张量 index,从源张量 input 中收集值,并返回一个新的张量。

torch.gather 常用于需要从张量的特定位置抽取元素的操作。

1. 函数签名

torch.gather(input, dim, index, *, sparse_grad=False, out=None)

2. 工作原理

torch.gatherdim 维度上,通过 index 指定的索引,从 input 中选取元素。

返回的张量的形状与 index 的形状相同。

3. 示例代码

以下是一个简单的示例代码,演示如何使用 torch.gather 函数:

import torch

# 创建一个源张量
input = torch.tensor([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])

# 创建一个索引张量
index = torch.tensor([[0, 2, 1],
                      [2, 0, 1],
                      [1, 2, 0]])

# 在 dim=1 维度上使用 gather 函数
result = torch.gather(input, dim=1, index=index)

print("Input Tensor:")
print(input)
print("\nIndex Tensor:")
print(index)
print("\nResult Tensor:")
print(result)

4. 输出结果

Input Tensor:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

Index Tensor:
tensor([[0, 2, 1],
        [2, 0, 1],
        [1, 2, 0]])

Result Tensor:
tensor([[1, 3, 2],
        [6, 4, 5],
        [8, 9, 7]])

5. 解释

在这个例子中:

总结

torch.gather 通过索引在指定维度上提取张量中的元素,是用于基于索引选择数据的有用工具。

函数对批处理数据特别有用,例如在分类任务中提取对应类别的概率或得分。

索引张量的形状必须与源张量在指定维度的形状相匹配,以确保正确的取值操作。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文