pytorch加载预训练模型与自己模型不匹配的解决方案
作者:找不到服务器1703
这篇文章主要介绍了pytorch加载预训练模型与自己模型不匹配的解决方案,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教
pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。
两个有序字典找不同
模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。
model = ResNet18(1) model_dict1 = torch.load('resnet18.pth') model_dict2 = model.state_dict() model_list1 = list(model_dict1.keys()) model_list2 = list(model_dict2.keys()) len1 = len(model_list1) len2 = len(model_list2) minlen = min(len1, len2) for n in range(minlen): if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape: err = 1
自己搭建模型的注意事项
搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。
model = ResNet18(1) model_dict1 = torch.load('resnet18.pth') model_dict2 = model.state_dict() model_list1 = list(model_dict1.keys()) model_list2 = list(model_dict2.keys()) len1 = len(model_list1) len2 = len(model_list2) minlen = min(len1, len2) for n in range(minlen): if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape: continue model_dict1[model_list1[n]] = model_dict2[model_list2[n]] model.load_state_dict(model_dict2)
完整的代码见自己搭建resnet18网络并加载torchvision自带权重
新增的改进代码
model_dict1 = torch.load('yolov5.pth') model_dict2 = model.state_dict() model_list1 = list(model_dict1.keys()) model_list2 = list(model_dict2.keys()) len1 = len(model_list1) len2 = len(model_list2) m, n = 0, 0 while True: if m >= len1 or n >= len2: break layername1, layername2 = model_list1[m], model_list2[n] w1, w2 = model_dict1[layername1], model_dict2[layername2] if w1.shape != w2.shape: continue model_dict2[layername2] = model_dict1[layername1] m += 1 n += 1 model.load_state_dict(model_dict2)
如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。
补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配
看代码吧~
#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层, #以及到第二个全连接层的全部网络还有他们对应的参数 class Classification_att(nn.Module): def __init__(self, rgb_range): super(Classification_att, self).__init__() self.vgg19 =models.vgg19(pretrained=True) vgg = models.vgg19(pretrained=True).features conv_modules = [m for m in vgg] self.vgg_conv = nn.Sequential(*conv_modules[:37]) classfi = models.vgg19(pretrained=True).classifier classif_modules = [n for n in classfi] self.vgg_class = nn.Sequential(*classif_modules[:4]) vgg_mean = (0.485, 0.456, 0.406) vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) for p in self.vgg_conv.parameters(): p.requires_grad = False for p in self.vgg_class.parameters(): p.requires_grad = False self.classifi = nn.Sequential( nn.Linear(4096, 1024), nn.ReLU(True), nn.Linear(1024, 256), nn.ReLU(True), nn.Linear(256, 64), ) def forward(self, x): x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', align_corners=False) x = self.sub_mean(x) x = self.vgg_conv(x) x = self.vgg_class(x) #执行这部报错,说张量不匹配
原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的
查看vgg的pytorch源码发现是
x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) x = self.classifier(x) #自己的代码没有torch.flatten(x, 1)这步
所以自己的少了一步
x = torch.flatten(x, 1)
补上就好了!
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。