python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > 更新tensor指定index位置值scatter_add_

pytorch更新tensor中指定index位置的值scatter_add_问题

作者:腾阳山泥若

这篇文章主要介绍了pytorch更新tensor中指定index位置的值scatter_add_问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

使用scatter_add_更新tensor张量中指定index位置的值

例子

import torch
a = torch.zeros((3, 4))
print(a)
"""
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])
"""
b = torch.rand((2, 4))
print(b)
"""
tensor([[0.6293, 0.3050, 0.9608, 0.5577],
        [0.3469, 0.1025, 0.8185, 0.5085]])
"""
# 将a中第0行和第2行的值修改为b
a = a.scatter_add_(0, torch.tensor([[0, 0, 0], [2, 2, 2]]), b)
print(a)
"""
tensor([[0.6293, 0.3050, 0.9608, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.3469, 0.1025, 0.8185, 0.0000]])
"""

torch_scatter.scatter_add、Tensor.scatter_add_ 、Tensor.scatter_、Tensor.scatter_add 、Tensor.scatter

torch_scatter.scatter_add

官方文档:

torch_scatter.scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0)

Sums all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in input for dimensions outside of dim and by the corresponding value in index for dimension dim. If multiple indices reference the same location, their contributions add.

看着挺疑惑的,自己试了一把:

src = torch.tensor([10, 20, 30, 40, 1, 2, 2, 2, 9])
index = torch.tensor([2, 1, 1, 1, 1, 1, 1, 1, 0])
out=scatter_add(src, index)
print(out)

输出结果为:tensor([ 9, 97, 10])

说白了就是:index就是out的下标,将src所有和此下标对应的值加起来,就是out的值。

例如上面的例子:index中等于1的,对应于src是【20, 30, 40, 1, 2, 2, 2】,将这些值加起来是97,于是,out[1]=97

同理:out[0]=src[8]=9     out[2]=src[0]=10

另一个函数

Tensor.scatter_add_

官方文档:

scatter_add_(self, dim, index, other):
For a 3-D tensor, :attr:`self` is updated as::
    self[index[i][j][k]][j][k] += other[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] += other[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] += other[i][j][k]  # if dim == 2

官方例子:

            >>> x = torch.rand(2, 5)
            >>> x
            tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328],
                    [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]])
            >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
            tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328],
                    [1.0000, 1.0427, 1.0000, 1.6782, 1.0000],
                    [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])

以index来遍历,就比较容易看懂。self中并不是每个值都要改变的。

以上面为例

index[0][0]=0  self[index[0][0]][0]=self[0][0] =self[0][0]+ x[0][0]=1 +0.7404=1.7404
index[0][1]=1  self[index[0][1]][1]=self[1][1] =self[1][1]+ x[0][1] =1 +0.0427 =1.0427

。。。

以此类推,将index遍历一遍,就得到最终的结果

所以,self中需要改变的是index中列出的坐标,其他的是不动的。

Tensor.scatter_

scatter_(self, dim, index, src)

和Tensor.scatter_add_的区别是直接将src中的值填充到self中,不做相加

例子:

>>> x = torch.rand(2, 5)
            >>> x
            tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
                    [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
            >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
            tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
                    [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
                    [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])
            >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
            >>> z
            tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
                    [ 0.0000,  0.0000,  0.0000,  1.2300]])

另外,pytorch中还有

scatter_add和scatter函数,和上面两个函数不同的是这个两个函数不改变self,会返回结果值;上面两个函数(scatter_add_和scatter_)是直接在原数据self上进行修改

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文