python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch torch_scatter.scatter_max()

pytorch之torch_scatter.scatter_max()用法

作者:A2333fun

这篇文章主要介绍了pytorch之torch_scatter.scatter_max()用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch_scatter.scatter_max()

torch_scatter.scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None)

from torch_scatter import scatter_max
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))
'''src根据index进行分组'''
out, argmax = scatter_max(src, index, out=out)
print(out)
print(argmax)

输出

tensor([[0., 0., 4., 3., 2., 0.],
        [2., 4., 3., 0., 0., 0.]])
tensor([[-1, -1,  3,  4,  0,  1],
        [ 1,  4,  3, -1, -1, -1]])

解释

torch_scatter.scatter()使用

1. 参数

具体来讲,scatter函数的作用就是将index中相同索引对应位置的src元素进行某种方式的操作,例如 sum mean 等,然后将这些操作结果按照索引顺序进行拼接。

下面我用具体的例子来进行讲解。

2. 示例

2.1 简单示例

首先初始化src和index:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
index = torch.tensor([0, 0, 1], dtype=torch.int64)

接着使用scatter函数:

out = scatter(src, index, dim=0, reduce='mean')

我们观察 index=[0, 0, 1] ,第0个位置和第1个位置都为0,第2个位置为1。也就是说,我们需要将src中第0个元素和第1个元素求平均变成一个元素,然后第2个元素求mean也就是本身为一个元素。如果 index=[1, 0, 0] ,则意味着我们需要将src中第1个元素和第2个元素求平均变成一个元素,而第0个元素保持不变。

那么src中第几个元素到底是如何定义的呢?这就需要用到 dim 参数了。

dim=0 意味着我们需要对src的维度0进行操作:

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])

即src中第0个元素为 [1, 2, 3] ,第1个元素为 [4, 5, 6] ,第2个元素为 [7, 8, 9]

而如果 dim=1 ,则第0个元素为 [1, 4, 7] ,第1个元素为 [2, 5, 8] ,第2个元素为 [3, 6, 9]

因此,如果有以下代码:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # (3, 3)
index = torch.tensor([0, 0, 1], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')

那么我们就应该将src中的第0个元素为 [1, 2, 3] 和第1个元素为 [4, 5, 6] 求平均为 [2.5, 3.5, 4.5] ,然后第2个元素 [7, 8, 9] 保持不变,即:

tensor([[2.5000, 3.5000, 4.5000],
        [7.0000, 8.0000, 9.0000]])

2.2 顺序问题

上面的例子中 index=[0, 0, 1] ,最后结果是将src中第0个元素和第1个元素求平均放到了位置0,然后src中第2个元素保持不变放到了位置1。

如果 index=[1, 1, 0] ,结果为:

tensor([[7.0000, 8.0000, 9.0000],
        [2.5000, 3.5000, 4.5000]])

可以发现,上述结果是将src中第2个元素 [7, 8, 9] 保持不变放到了位置0,然后将src中第0个元素 [1, 2, 3] 和第1个元素 [4, 5, 6] 求平均保持不变放到了位置1。

也就是说,无论index怎么变化,都是优先将index中0对应位置的操作结果进行放置。

2.3 维度问题

如果src的维度为(4, 3),而我们需要对 dim=0 操作,也就是一共有四个元素,那么index的长度应该为4,即以下操作是不合法的:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
index = torch.tensor([1, 1, 0], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)

报错为:

RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 0.  Target sizes: [4, 3].  Tensor sizes: [3, 1]

正确做法应该是:

src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])  # (4, 3)
index = torch.tensor([1, 1, 0, 2], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)

输出为:

tensor([[ 7.0000,  8.0000,  9.0000],
        [ 2.5000,  3.5000,  4.5000],
        [10.0000, 11.0000, 12.0000]])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文