东方耀AI技术分享

标题: 08、pytorch模型加载与保存 [打印本页]

作者: 东方耀    时间: 2020-5-25 10:32
标题: 08、pytorch模型加载与保存
pytorch模型加载与保存


pytorch模型加载与保存


保存和与加载模型,有三个核心函数需要熟悉:


torch.save: 保存一个序列化的对象至硬盘,。该函数使用了Python的pickle包用于序列化。模型、张量和各种对象的字典都可以使用该函数保存;
torch.load:使用pickle的反序列化功能将序列化的对象文件反序列化到内存。这个功能可以用帮助设备加载数据。
torch.nn.Module.load_state_dict:使用一个反序列化的state_dict加载一个模型的参数字典。




# 用于测试的模型加载与保存
# 保存/加载 state_dict(推荐)
torch.save(model.state_dict(), "mnist.pth")


# 使用torch.save()保存模型的state_dict可以在加载模型时给予最大的灵活度,
# 这也是为什么推荐使用的保存模型的方法。




# 注意:load_state_dict()函数的参数是一个字典对象,不是保存对象的文件名。
# 这意味着,在将保存的state_dict传递给load_state_dict()之前必须将其反序列化
model = 网络结构的build()
model.load_state_dict(torch.load("mnist.pth"))
model.eval()





pytorch模型存储的两种方式
1.保存整个网络结构信息和模型参数信息:

torch.save(model_object, './model.pth')
直接加载即可使用:

model = torch.load('./model.pth')
model.eval()

这种保存/加载的方法使用了最直观的语法,涉及到最少量的程序。使用这种方法保存整个模型使用了Python的pickle模块。这种方法的劣势是,序列化的数据被限制于模型包存时所使用的特定类型和准确的字典结构。这是因为pickle不能保存模型类型本身。而是保存一个加载时使用的包含类的文件的地址。因此,你的程序在其他项目中使用、或改变之后可能会各种崩溃。

**切记:**在测试之前, 必须调用model.eval()将dropout和batch normalization层设置为测试模式,不然会导致不一致的结果



2.只保存网络的模型参数-推荐使用

torch.save(model_object.state_dict(), './params.pth')
加载则要先从本地网络模块导入网络,然后再加载参数:

model = 构建网络结构的build()
model.load_state_dict(torch.load('./params.pth'))
















欢迎光临 东方耀AI技术分享 (http://www.ai111.vip/) Powered by Discuz! X3.4