logo

PyTorch:张量与模型的保存与加载

作者:很酷cat2023.11.08 12:41浏览量:99

简介:pytorch保存张量 pytorch保存模型用于继续训练

pytorch保存张量 pytorch保存模型用于继续训练
PyTorch是一种广泛使用的深度学习框架,它支持张量的保存和加载,以及模型的保存和加载。这些功能使得PyTorch非常适合用于训练和部署深度学习模型。在本文中,我们将讨论如何使用PyTorch保存张量和模型,以便在以后的训练中继续使用。

保存和加载张量

在PyTorch中,可以使用torch.save()函数将张量保存到磁盘上。这个函数接受两个参数:要保存的张量和保存的路径。例如,以下代码将一个名为x的张量保存到名为x.pt的文件中:

  1. import torch
  2. x = torch.tensor([1, 2, 3])
  3. torch.save(x, 'x.pt')

要加载保存的张量,可以使用torch.load()函数。例如,以下代码将加载刚刚保存的x.pt文件:

  1. x = torch.load('x.pt')
  2. print(x)

这将输出tensor([1, 2, 3])

保存和加载模型

与张量的保存和加载类似,PyTorch也提供了用于保存和加载模型的函数。可以使用torch.save()函数将模型保存到磁盘上。以下是一个简单的例子:

  1. import torch.nn as nn
  2. class Net(nn.Module):
  3. def __init__(self):
  4. super(Net, self).__init__()
  5. self.fc1 = nn.Linear(10, 5)
  6. self.fc2 = nn.Linear(5, 2)
  7. def forward(self, x):
  8. x = self.fc1(x)
  9. x = self.fc2(x)
  10. return x
  11. model = Net()
  12. torch.save(model, 'model.pt')

这将把一个名为model.pt的文件保存到模型对象model中。要加载保存的模型,可以使用torch.load()函数:

  1. model = torch.load('model.pt')

现在,您可以使用加载的模型进行推理或继续训练。如果您想在加载模型后继续训练,您需要将模型设置为训练模式并设置正确的优化器和损失函数。例如:

  1. model.train()
  2. optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
  3. criterion = nn.CrossEntropyLoss()

相关文章推荐

发表评论