保存和加载整个模型
torch.save(model_object, 'model.pkl') model = torch.load('model.pkl')
仅保存和加载模型参数(推荐使用,需要提前手动构建模型)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
但是要注意几个细节:
1.若使用nn.DataParallel在一台电脑上使用了多个GPU,load模型的时候也必须先DataParallel,这和keras类似。
2.load提供了很多重载的功能,其可以把在GPU上训练的权重加载到CPU上跑。内容参考于:
https://www.ptorch.com/news/74.html <https://www.ptorch.com/news/74.html>
torch.load('tensors.pt') # 把所有的张量加载到CPU中 torch.load('tensors.pt',
map_location=lambda storage, loc: storage)# 把所有的张量加载到GPU 1中 torch.load(
'tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) # 把张量从GPU 1
移动到 GPU 0 torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
在cpu上加载预先训练好的GPU模型,有一种强制所有GPU张量在CPU中的方式:
torch.load('my_file.pt', map_location=lambda storage, loc: storage)
上述代码只有在模型在一个GPU上训练时才起作用。如果我在多个GPU上训练我的模型,保存它,然后尝试在CPU上加载,我得到这个错误:KeyError:
‘unexpected key “module.conv1.weight” in state_dict’ 如何解决?
您可能已经使用模型保存了模型nn.DataParallel,该模型将模型存储在该模型中module,而现在您正试图加载模型DataParallel。您可以nn.DataParallel在网络中暂时添加一个加载目的,也可以加载权重文件,创建一个没有module前缀的新的有序字典,然后加载它。
# original saved file with DataParallel state_dict = torch.load(
'myfile.pth.tar') # create new OrderedDict that does not contain `module.` from
collections import OrderedDict new_state_dict = OrderedDict()for k, v in
state_dict.items(): name = k[7:] # remove `module.` new_state_dict[name] = v #
loadparams model.load_state_dict(new_state_dict)
笔者封装了一个简单的函数,可以直接加载多GPU权重到CPU上(只加载匹配的权重)
# 加载模型,解决命名和维度不匹配问题,解决多个gpu并行 def load_state_keywise(model, model_path):
model_dict = model.state_dict() pretrained_dict = torch.load(model_path,
map_location='cpu') key = list(pretrained_dict.keys())[0] # 1. filter out
unnecessary keys # 1.1 multi-GPU ->CPU if (str(key).startswith('module.')):
pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in
model_dictand v.size() == model_dict[k[7:]].size()} else: pretrained_dict = {k:
vfor k, v in pretrained_dict.items() if k in model_dict and v.size() ==
model_dict[k].size()}# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)# 3. load the new state dict
model.load_state_dict(model_dict)
有朋友问,上面为什么去掉‘’module.‘’就可以在单GPU上跑了,看下面的一个栗子
import torch from torch import nn import torchvision #使用alexnet做测试,使用单个GPU或CUP
alexnet=torchvision.models.alexnet() state_dict=alexnet.state_dict()for k, v in
state_dict.items(): print(k) print("-"*20) #华丽的分割线 #使用多GPU model =
nn.DataParallel(alexnet) state_dict=model.state_dict()for k, v in
state_dict.items(): print(k)
看结果就知道了,其就多了个前缀‘module.’
features.0.weight features.0.bias features.3.weight features.3.bias features.6
.weight features.6.bias features.8.weight features.8.bias features.10.weight
features.10.bias classifier.1.weight classifier.1.bias classifier.4.weight
classifier.4.bias classifier.6.weight classifier.6.bias --------------------
module.features.0.weight module.features.0.bias module.features.3.weight module
.features.3.bias module.features.6.weight module.features.6.bias module.features
.8.weight module.features.8.bias module.features.10.weight module.features.10
.bias module.classifier.1.weight module.classifier.1.bias module.classifier.4
.weight module.classifier.4.bias module.classifier.6.weight module.classifier.6
.bias
热门工具 换一换