python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > Pytorch加载部分权重

Pytorch如何加载部分权重

作者:Mr_寒路

这篇文章主要介绍了Pytorch如何加载部分权重问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

1.修改网络层输出

比如在人脸检测项目中,已经训练好人脸框的回归,但是此时需要再加入人脸关键点。

为了节约大量时间,我们可以加载部分权重。加载的网络权重

if os.path.exists(self.load_params):
	pretext_model = torch.load(self.load_params)

打印出来,会看到网络权重存储在一个字典中,需要修改哪一层,用字典的键索引值进行修改。

比如原本输出层为4,我将网络输出层修改为14,又由于输出的都是坐标值,属于同一分布,所以我将原参4复制扩充为了14,效果非常好。

w = pretext_model["fc2.weight"]
b = pretext_model["fc2.bias"]
pretext_model["fc2.weight"] = torch.cat((w,w,w,w[:2]),dim=0)
pretext_model["fc2.bias"] = torch.cat((b,b,b,b[:2]),dim=0)

最后加载修改后的参数

self.net.load_state_dict(pretext_model)

2.删除或增加了网络层

查看模型的参数,也是存放在一个字典中

if os.path.exists(self.load_params):
	pretext_model = torch.load(self.load_params) #加载的参数
	model_dict = net.state_dict()  #模型参数
	print(model_dict)
	print(pretext_model)
#如果模型有k层,就加载
state_dict = {k: v for k, v in pretext_model.items() if k in model_dict.keys()}
model_dict.update(state_dict)
net.load_state_dict(model_dict)

3.迁移学习

有时我们也会用别人的模型,加载与训练参数,但是需要对输出层做一些修改,一般有两种方法,直接修改输出层个数或增加网络层

修改输出层个数

net = models.vgg19(pretrained=True) #下载与训练参数
print(net)  #查看网络结构
net.classifier[6] = torch.nn.Linear(4096,10) #将输出层修改为10分类

增加输出网络层

num_fc_ftr = net.classifier[6]
net.fc = nn.Linear(num_fc_ftr, 128)
net.out = nn.Linear(128, 10)

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文