解决pytorch model代码内tensor device不一致的问题
作者:_LvP
pytorch model代码内tensor device不一致的问题
在编写一段处理两个tensor的代码如下,需要在forward函数内编写函数创建一个新的tensor进行索引的掩码计算
# todo(liang)空间交换 def compute_sim_and_swap(t1, t2, threshold=0.7): n, c, h, w = t1.shape sim = torch.nn.functional.cosine_similarity(t1, t2, dim=1) # n, h, w sim = sim.unsqueeze(0) # c, n, h, w expand_tensor = sim.clone() # 使用拼接构建相同的维度 for _ in range(c-1): # c, n, h, w sim = torch.cat([sim, expand_tensor], dim=0) sim = sim.permute(1, 0, 2, 3) # n, c, h, w # 创建逻辑掩码,小于 threshold 的将掩码变为 True 用于交换 mask = sim < threshold indices = torch.rand(mask.shape) < 0.5 t1[mask&indices], t2[mask&indices] = t2[mask&indices], t1[mask&indices] return t1, t2
这段代码报了这个错误
File "xxx/network.py", line 347, in compute_sim_and_swap
t1[mask&indices], t2[mask&indices] = t2[mask&indices], t1[mask&indices]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
统一下进行掩码计算的张量的设备即可
device = mask.Device indices = indices.to(device)
PyTorch 多GPU使用torch.nn.DataParallel训练参数不一致问题
在多GPU训练时,遇到了下述的错误:
1. Expected tensor for argument 1 'input' to have the same device as tensor for argument 2 'weight'; but device 0 does not equal 1
2. RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
造成这个错误的可能性有挺多,总起来是模型、输入、模型内参数不在一个GPU上。本人是在调试RandLA-Net pytorch源码,希望使用双GPU训练,经过尝试解决这个问题,此处做一个记录,希望给后来人一个提醒。经过调试,发现报错的地方主要是在数据拼接的时候,即一个数据在GPU0上,一个数据在GPU1上,这就会出现错误,相关代码如下:
return torch.cat(( self.mlp(concat), features.expand(B, -1, N, K) ), dim=-3)
上述代码中,必须保证self.mlp(concat)与features.expand(B, -1, N, K)在同一个GPU中。在多GPU运算时,features(此时是输入变量)有可能放在任何一个GPU中,因此此处在拼接前,获取一下features的GPU,然后将concat放入相应的GPU中,再进行数据拼接就可以了,代码如下:
device = features.device concat = concat.to(device) return torch.cat(( self.mlp(concat), features.expand(B, -1, N, K) ), dim=-3)
该源码中默认状态下device是一个固定的值,在多GPU训练状态下就会报错,代码中还有几处数据融合,大家可以依据上述思路做修改。此外该源码中由于把device的值写死了,训练好的模型也必须在相应的GPU中做推理,如在cuda0中训练的模型如果在cuda1中推理就会报错,各位可以依据此思路对源码做相应的修改。如果修改有困难,可以私信我,我可以把相关修改后的源码分享。
到此这篇关于pytorch model代码内tensor device不一致的问题的文章就介绍到这了,更多相关pytorch tensor device不一致内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!