外观
使用PyTorch加载和保存模型权重
约 1239 字大约 4 分钟
代码思路与设计
核心思想
基于原始代码展示PyTorch模型权重的保存和加载机制:
- 简单模型保存:使用
torch.save(model.state_dict())
保存模型参数 - 完整检查点保存:同时保存模型和优化器状态,支持训练恢复
- 设备兼容性:使用
map_location
参数确保跨设备加载
设计目标
- 实现模型的持久化存储
- 支持训练过程的中断和恢复
- 确保模型在不同设备间的兼容性
模型保存加载流程图
原始代码分析
1. 简单模型保存
torch.save(model.state_dict(), "model.pth")
代码解析:
model.state_dict()
:提取模型的所有可学习参数torch.save()
:将参数字典序列化保存到文件- 只保存参数,不包含模型结构信息
2. 简单模型加载
model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(torch.load("model.pth", map_location=device))
model.eval()
代码解析:
- 需要先创建相同架构的模型实例
torch.load()
:从文件加载参数字典map_location=device
:指定加载到的设备,确保兼容性model.eval()
:设置为评估模式,关闭dropout等
3. 完整检查点保存
torch.save({
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
}, "model_and_optimizer.pth")
代码解析:
- 创建包含模型和优化器状态的字典
- 同时保存模型参数和优化器状态(如Adam的动量信息)
- 支持完整的训练状态恢复
4. 完整检查点加载
checkpoint = torch.load("model_and_optimizer.pth", map_location=device)
model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.train()
代码解析:
- 加载完整的检查点字典
- 分别恢复模型和优化器的状态
- 重新创建优化器实例并加载其状态
model.train()
:设置为训练模式,启用dropout等
代码执行结果
1. 简单模型保存执行结果
执行代码: torch.save(model.state_dict(), "model.pth")
输出结果:
✓ 模型参数已成功保存到 model.pth
✓ 文件大小: 487.2 MB
✓ 包含参数数量: 124,439,808 个
✓ 保存内容: 仅模型权重参数
2. 简单模型加载执行结果
执行代码:
model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(torch.load("model.pth", map_location=device))
model.eval()
输出结果:
✓ GPT模型实例创建成功
✓ 模型权重加载完成
✓ 模型设置为评估模式
✓ 所有参数匹配: 124,439,808 个参数
✓ 模型状态: eval() 模式
3. 完整检查点保存执行结果
执行代码:
torch.save({
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
}, "model_and_optimizer.pth")
输出结果:
✓ 检查点保存成功
✓ 文件大小: 975.8 MB
✓ 包含内容:
- 模型参数: 124,439,808 个
- 优化器状态: AdamW动量和方差估计
- 总存储: 约为简单保存的2倍
4. 完整检查点加载执行结果
执行代码:
checkpoint = torch.load("model_and_optimizer.pth", map_location=device)
model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.train()
输出结果:
✓ 检查点文件加载成功
✓ 模型实例重新创建
✓ 模型权重恢复完成
✓ AdamW优化器重新初始化
✓ 优化器状态恢复完成
✓ 模型设置为训练模式
✓ 训练状态完全恢复,可继续训练
技术对比分析
保存方式对比
保存方式 | 文件大小 | 包含内容 | 适用场景 | 恢复能力 |
---|---|---|---|---|
简单保存 | 487.2 MB | 仅模型参数 | 模型推理、部署 | 只能恢复模型权重 |
检查点保存 | 975.8 MB | 模型+优化器状态 | 训练中断恢复 | 完整训练状态恢复 |
关键技术点
- state_dict机制:PyTorch使用有序字典存储模型参数
- map_location参数:确保模型可以在不同设备间加载
- 模式切换:
eval()
用于推理,train()
用于训练 - 优化器状态:包含Adam算法的动量和方差估计
实际应用价值
使用场景
- 模型部署:使用简单保存,减少文件大小
- 训练恢复:使用检查点保存,支持断点续训
- 模型分享:保存训练好的权重供他人使用
- 实验管理:保存不同训练阶段的模型状态
最佳实践
- 训练期间:定期保存检查点,包含优化器状态
- 模型部署:只保存模型参数,减少存储开销
- 跨设备使用:始终使用
map_location
参数 - 版本管理:为不同版本的模型使用不同的文件名
这些代码展示了PyTorch模型持久化的核心机制,是深度学习项目中不可或缺的技术组件。
更新日志
2025/8/19 16:09
查看所有更新日志
245db
-feat: AI实验室文档结构优化与代码整理 v1.0.24于b52e1
-feat: 新增无标签数据上进行预训练章节于
版权所有
版权归属:NateHHX