PyTorch:张量与模型的保存与加载
2023.11.08 12:41浏览量:99简介:pytorch保存张量 pytorch保存模型用于继续训练
pytorch保存张量 pytorch保存模型用于继续训练
PyTorch是一种广泛使用的深度学习框架,它支持张量的保存和加载,以及模型的保存和加载。这些功能使得PyTorch非常适合用于训练和部署深度学习模型。在本文中,我们将讨论如何使用PyTorch保存张量和模型,以便在以后的训练中继续使用。
保存和加载张量
在PyTorch中,可以使用torch.save()
函数将张量保存到磁盘上。这个函数接受两个参数:要保存的张量和保存的路径。例如,以下代码将一个名为x
的张量保存到名为x.pt
的文件中:
import torch
x = torch.tensor([1, 2, 3])
torch.save(x, 'x.pt')
要加载保存的张量,可以使用torch.load()
函数。例如,以下代码将加载刚刚保存的x.pt
文件:
x = torch.load('x.pt')
print(x)
这将输出tensor([1, 2, 3])
。
保存和加载模型
与张量的保存和加载类似,PyTorch也提供了用于保存和加载模型的函数。可以使用torch.save()
函数将模型保存到磁盘上。以下是一个简单的例子:
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
model = Net()
torch.save(model, 'model.pt')
这将把一个名为model.pt
的文件保存到模型对象model
中。要加载保存的模型,可以使用torch.load()
函数:
model = torch.load('model.pt')
现在,您可以使用加载的模型进行推理或继续训练。如果您想在加载模型后继续训练,您需要将模型设置为训练模式并设置正确的优化器和损失函数。例如:
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
发表评论
登录后可评论,请前往 登录 或 注册