东方耀AI技术分享

 找回密码
 立即注册

QQ登录

只需一步,快速开始

搜索
热搜: 活动 交友 discuz
查看: 2112|回复: 0
打印 上一主题 下一主题

[PyTorch] 08、pytorch模型加载与保存

[复制链接]

1365

主题

1856

帖子

1万

积分

管理员

Rank: 10Rank: 10Rank: 10

积分
14435
QQ
跳转到指定楼层
楼主
发表于 2020-5-25 10:32:20 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
pytorch模型加载与保存


pytorch模型加载与保存


保存和与加载模型,有三个核心函数需要熟悉:


torch.save: 保存一个序列化的对象至硬盘,。该函数使用了Python的pickle包用于序列化。模型、张量和各种对象的字典都可以使用该函数保存;
torch.load:使用pickle的反序列化功能将序列化的对象文件反序列化到内存。这个功能可以用帮助设备加载数据。
torch.nn.Module.load_state_dict:使用一个反序列化的state_dict加载一个模型的参数字典。




# 用于测试的模型加载与保存
# 保存/加载 state_dict(推荐)
torch.save(model.state_dict(), "mnist.pth")


# 使用torch.save()保存模型的state_dict可以在加载模型时给予最大的灵活度,
# 这也是为什么推荐使用的保存模型的方法。




# 注意:load_state_dict()函数的参数是一个字典对象,不是保存对象的文件名。
# 这意味着,在将保存的state_dict传递给load_state_dict()之前必须将其反序列化
model = 网络结构的build()
model.load_state_dict(torch.load("mnist.pth"))
model.eval()





pytorch模型存储的两种方式
1.保存整个网络结构信息和模型参数信息:

torch.save(model_object, './model.pth')
直接加载即可使用:

model = torch.load('./model.pth')
model.eval()

这种保存/加载的方法使用了最直观的语法,涉及到最少量的程序。使用这种方法保存整个模型使用了Python的pickle模块。这种方法的劣势是,序列化的数据被限制于模型包存时所使用的特定类型和准确的字典结构。这是因为pickle不能保存模型类型本身。而是保存一个加载时使用的包含类的文件的地址。因此,你的程序在其他项目中使用、或改变之后可能会各种崩溃。

**切记:**在测试之前, 必须调用model.eval()将dropout和batch normalization层设置为测试模式,不然会导致不一致的结果



2.只保存网络的模型参数-推荐使用

torch.save(model_object.state_dict(), './params.pth')
加载则要先从本地网络模块导入网络,然后再加载参数:

model = 构建网络结构的build()
model.load_state_dict(torch.load('./params.pth'))











让天下人人学会人工智能!人工智能的前景一片大好!
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|手机版|小黑屋|人工智能工程师的摇篮 ( 湘ICP备2020019608号-1 )

GMT+8, 2024-4-26 18:12 , Processed in 0.170049 second(s), 18 queries .

Powered by Discuz! X3.4

© 2001-2017 Comsenz Inc.

快速回复 返回顶部 返回列表