pytorch中的scatter_add_函数的使用解读
作者:*Lisen
pytorch scatter_add_函数的使用
关于这个函数,很少博客详细的介绍。下面就我个人理解简单介绍下。
函数:
self_tensor.scatter_add_(dim, index_tensor, other_tensor) → 输出tensor
该函数的意思是:
将other_tensor中的数据,按照index_tensor中的索引位置,添加至self_tensor矩阵中。
参数:
dim
:表示需要改变的维度,但是注意,假如dim=1,并不是说self_tensor在dim=0上的数据不会改变,这个dim只是在取矩阵数据时不固定dim=1维的索引,使用index_tensor矩阵中的索引。可能这样说还是不太理解,下面会用例子说明。其中self_tensor表示我们需要改变的tensor矩阵index_tensor
:索引矩阵;other_tensor
:需要添加到self_tensor中的tensor
要求:
1、self_tensor,index_tensor, other_tensor 的维度需要相同,即self.tensor.dim() = index_tensor.dim() = other_tensor.dim();
2、假设dim=d,那么index_tensor矩阵中的所有数据必须小于d-1;
3、假设dim=d,index_tensor矩阵在d维度上的size必须小于self_tensor和other_tensor的size;即index.size(d) <= other_tensor.size(d) 且index.size(d) <= self_tensor.size(d)
三维计算公式:
self[index[i][j][k]][j][k] += other[i][j][k] # 如果 dim == 0 self[i][index[i][j][k]][k] += other[i][j][k] # 如果 dim == 1 self[i][j][index[i][j][k]] += other[i][j][k] # 如果 dim == 2
二维计算公式:
self[index[i][j]][j] += other[i][j] # 如果 dim == 0 self[i][index[i][j]] += other[i][j] # 如果 dim == 1
index_tensor = torch.tensor([[0,1],[1,1]]) print('index_tensor: \n', index_tensor) self_tensor = torch.arange(0, 4).view(2, 2) print('self_tensor: \n', self_tensor) other_tensor = torch.arange(5, 9).view(2, 2) print('other_tensor: \n', other_tensor) dim = 0 for i in range(index_tensor.size(0)): for j in range(index_tensor.size(1)): replace_index = index_tensor[i][j] if dim == 0: # self矩阵的第0维索引 self_tensor[replace_index][j] += other_tensor[i][j] elif dim == 1: # self矩阵的第1维索引 self_tensor[i][replace_index] += other_tensor[i][j] print(self_tensor)
结果:
index_tensor:
tensor([[0, 1],
[1, 1]])
self_tensor:
tensor([[0, 1],
[2, 3]])
other_tensor:
tensor([[5, 6],
[7, 8]])
tensor([[ 5, 1],
[ 9, 17]])
使用函数计算:
index_tensor = torch.tensor([[0,1],[1,1]]) print('index_tensor: \n', index_tensor) self_tensor = torch.arange(0, 4).view(2, 2) print('self_tensor: \n', self_tensor) other_tensor = torch.arange(5, 9).view(2, 2) print('other_tensor: \n', other_tensor) self_tensor.scatter_add_(0, index_tensor, other_tensor) print(self_tensor)
结果:
index_tensor:
tensor([[0, 1],
[1, 1]])
self_tensor:
tensor([[0, 1],
[2, 3]])
other_tensor:
tensor([[5, 6],
[7, 8]])
tensor([[ 5, 1],
[ 9, 17]])
scatter_add()函数通俗理解
self [ index[i,j] , j ] += src [ i , j ] # if dim == 0 self [ i , index[i,j] ] += src[ i, j ] # if dim == 1
理解scatter_add()函数,看index就行了,index有多少个,self坐标就会变多少次。
self是一个二维的数组,self[第一维,第二维],dim==0,就是将src对应坐标,对应到 index 坐标里面的值,放置到self的第一维中。
例如:
src[i,j]对应到index[i,j],假设 index[i,j] ==0,则self[第一维,第二维] 为self[0,j],只改变第一维,第二维的值和src第二维一样。
然后self[0,j]的值就会变为 self[0,j]=self[0,j]+src[i,j]
代码中 self=torch.zeros(3,5), dim=0, index=[0,1,2,0,0], src=torch.ones(2,5)
我们只看 src,当 src[0,0]=1, index[0,0]=0, self[0,0]=self[0,0]+src[0,0]=1。
当src[0,1]=1, index[0,1]=1, self[1,1]=self[1,1]+src[0,1]=1, self的第一维是index的值决定的为1,第二维是src的第二维坐标决定也为1。
当index的值没有时,就停止变换,self没有变换过的坐标值就保持不变。
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。