pytorch之torch_scatter.scatter_max()用法
作者:A2333fun
torch_scatter.scatter_max()
torch_scatter.scatter_max(src, index, dim=-1, out=None, dim_size=None, fill_value=None)
- 根据index将src分组,求每一组中的最大值输出到out
- dim是维度
from torch_scatter import scatter_max src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) out = src.new_zeros((2, 6)) '''src根据index进行分组''' out, argmax = scatter_max(src, index, out=out) print(out) print(argmax)
输出
tensor([[0., 0., 4., 3., 2., 0.],
[2., 4., 3., 0., 0., 0.]])
tensor([[-1, -1, 3, 4, 0, 1],
[ 1, 4, 3, -1, -1, -1]])
解释
torch_scatter.scatter()使用
1. 参数
具体来讲,scatter函数的作用就是将index中相同索引对应位置的src元素进行某种方式的操作,例如 sum
、 mean
等,然后将这些操作结果按照索引顺序进行拼接。
下面我用具体的例子来进行讲解。
2. 示例
2.1 简单示例
首先初始化src和index:
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # (3, 3) index = torch.tensor([0, 0, 1], dtype=torch.int64)
接着使用scatter函数:
out = scatter(src, index, dim=0, reduce='mean')
我们观察 index=[0, 0, 1]
,第0个位置和第1个位置都为0,第2个位置为1。也就是说,我们需要将src中第0个元素和第1个元素求平均变成一个元素,然后第2个元素求mean也就是本身为一个元素。如果 index=[1, 0, 0]
,则意味着我们需要将src中第1个元素和第2个元素求平均变成一个元素,而第0个元素保持不变。
那么src中第几个元素到底是如何定义的呢?这就需要用到 dim
参数了。
dim=0
意味着我们需要对src的维度0进行操作:
tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
即src中第0个元素为 [1, 2, 3]
,第1个元素为 [4, 5, 6]
,第2个元素为 [7, 8, 9]
。
而如果 dim=1
,则第0个元素为 [1, 4, 7]
,第1个元素为 [2, 5, 8]
,第2个元素为 [3, 6, 9]
。
因此,如果有以下代码:
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # (3, 3) index = torch.tensor([0, 0, 1], dtype=torch.int64) out = scatter(src, index, dim=0, reduce='mean')
那么我们就应该将src中的第0个元素为 [1, 2, 3]
和第1个元素为 [4, 5, 6]
求平均为 [2.5, 3.5, 4.5]
,然后第2个元素 [7, 8, 9]
保持不变,即:
tensor([[2.5000, 3.5000, 4.5000], [7.0000, 8.0000, 9.0000]])
2.2 顺序问题
上面的例子中 index=[0, 0, 1]
,最后结果是将src中第0个元素和第1个元素求平均放到了位置0,然后src中第2个元素保持不变放到了位置1。
如果 index=[1, 1, 0]
,结果为:
tensor([[7.0000, 8.0000, 9.0000], [2.5000, 3.5000, 4.5000]])
可以发现,上述结果是将src中第2个元素 [7, 8, 9]
保持不变放到了位置0,然后将src中第0个元素 [1, 2, 3]
和第1个元素 [4, 5, 6]
求平均保持不变放到了位置1。
也就是说,无论index怎么变化,都是优先将index中0对应位置的操作结果进行放置。
2.3 维度问题
如果src的维度为(4, 3),而我们需要对 dim=0
操作,也就是一共有四个元素,那么index的长度应该为4,即以下操作是不合法的:
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) # (4, 3) index = torch.tensor([1, 1, 0], dtype=torch.int64) out = scatter(src, index, dim=0, reduce='mean') print(out)
报错为:
RuntimeError: The expanded size of the tensor (4) must match the existing size (3) at non-singleton dimension 0. Target sizes: [4, 3]. Tensor sizes: [3, 1]
正确做法应该是:
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) # (4, 3) index = torch.tensor([1, 1, 0, 2], dtype=torch.int64) out = scatter(src, index, dim=0, reduce='mean') print(out)
输出为:
tensor([[ 7.0000, 8.0000, 9.0000],
[ 2.5000, 3.5000, 4.5000],
[10.0000, 11.0000, 12.0000]])
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。