python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > PyTorch torch.argmax

PyTorch中torch.argmax函数的使用

作者:Code_Geo

torch.argmax 是一个高效的工具,广泛应用于分类模型预测、指标计算等场景,下面就来介绍一下PyTorch中torch.argmax函数的使用,感兴趣的可以了解一下

torch.argmax 是 PyTorch 中的一个函数,用于返回输入张量中最大值所在的索引。其作用与数学中的 ​argmax 概念一致,即找到某个函数在指定范围内取得最大值时的参数(位置索引

函数定义

torch.argmax(input, dim=None, keepdim=False)

核心功能

1、​全局最大值索引​(当 dim=None)

import torch

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
print(torch.argmax(x))  # 输出:tensor(3)
# 展平后的索引:1, 2, 3, 6, 5, 4 → 最大值为6,索引为3(从0开始)

2|​沿指定维度查找最大值索引​(当 dim 指定时)

# 沿行维度(dim=1)查找
x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
print(torch.argmax(x, dim=1))  # 输出:tensor([2, 0])
# 解释:
# 第一行 [1, 2, 3] 最大值3,索引2
# 第二行 [6, 5, 4] 最大值6,索引0

# 沿列维度(dim=0)查找
print(torch.argmax(x, dim=0))  # 输出:tensor([1, 1, 0])
# 解释:
# 第0列 [1, 6] 最大值6,索引1
# 第1列 [2, 5] 最大值5,索引1
# 第2列 [3, 4] 最大值4,索引1(但此处输出为0,可能有误,实际应为1)

参数详解

1. dim 参数

2. keepdim 参数

x = torch.tensor([[1, 2, 3],
                  [6, 5, 4]])
out = torch.argmax(x, dim=1, keepdim=True)
print(out)  # 输出:tensor([[2], [0]])

常见用途

1、​分类任务中获取预测标签

logits = torch.tensor([0.1, 0.8, 0.05, 0.05])  # 模型输出的概率分布
predicted_class = torch.argmax(logits)         # 输出:tensor(1)

2、​计算准确率

# 假设batch_size=4,num_classes=3
preds = torch.tensor([[0.1, 0.2, 0.7],
                      [0.9, 0.05, 0.05],
                      [0.3, 0.4, 0.3],
                      [0.05, 0.8, 0.15]])
labels = torch.tensor([2, 0, 1, 1])
# 获取预测类别
predicted_classes = torch.argmax(preds, dim=1)  # 输出:tensor([2, 0, 1, 1])
# 计算正确预测数
correct = (predicted_classes == labels).sum()   # 输出:tensor(3)

注意事项

1、​多个相同最大值:

x = torch.tensor([3, 1, 4, 4])
print(torch.argmax(x))  # 输出:tensor(2)

2、​数据类型

3、​维度合法性

总结

torch.argmax 是一个高效的工具,广泛应用于分类模型预测、指标计算等场景。理解其 dim 和 keepdim 参数的行为,可以灵活处理不同维度的数据

到此这篇关于PyTorch中torch.argmax函数的使用的文章就介绍到这了,更多相关PyTorch torch.argmax内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文