Pytorch Geometric实践——利用Pytorch搭建GNN

作者:渣渣辉2024.01.07 17:23浏览量:28

简介:本文将介绍如何使用PyTorch Geometric(PyG)搭建图神经网络(GNN)。我们将从安装PyG开始,逐步深入到构建基本的GNN模型,并通过一个实例展示如何使用PyG进行图数据预处理和模型训练。

千帆应用开发平台“智能体Pro”全新上线 限时免费体验

面向慢思考场景,支持低代码配置的方式创建“智能体Pro”应用

立即体验

PyTorch Geometric(PyG)中,我们可以方便地实现各种图神经网络(GNN)模型。PyG基于PyTorch,提供了丰富的图神经网络层和算法,使得我们能够快速搭建和训练GNN。下面,我们将分步骤介绍如何使用PyG搭建一个简单的GNN模型。
一、安装PyTorch Geometric
要使用PyG,首先需要安装PyTorch和PyG。在终端中输入以下命令进行安装:

  1. pip install torch torchvision

在安装好PyTorch之后,可以使用以下命令安装PyG:

  1. pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0.html

二、数据预处理
在搭建GNN之前,我们需要对图数据进行预处理。PyG提供了torch_geometric.data模块来处理图数据。下面是一个简单的示例,演示如何加载数据并将其转换为PyG所需的格式:

  1. from torch_geometric.datasets import Planetoid
  2. from torch_geometric.data import DataLoader
  3. dataset = Planetoid(root='/tmp/Cora', name='Cora') # 加载Cora数据集
  4. loader = DataLoader(dataset, batch_size=32, shuffle=True) # 创建数据加载器

三、搭建GNN模型
接下来,我们将使用PyG搭建一个简单的GNN模型。首先,导入所需的模块:

  1. from torch_geometric.nn import GCNConv
  2. from torch_geometric.data import DataLoader

然后,定义一个简单的GCN模型:

  1. def create_model():
  2. return GCNConv(dataset.num_features, dataset.num_classes)

四、训练模型
现在,我们可以开始训练模型了。首先,创建一个优化器和一个损失函数:

  1. import torch.optim as optim
  2. from torch_geometric.nn import loss as gnn_loss
  3. optimizer = optim.Adam(model.parameters(), lr=0.01)
  4. loss_fn = gnn_loss.CrossEntropyLoss()

接下来,在训练循环中迭代优化器、损失函数和模型:

  1. best_model = None
  2. best_acc = -1.
  3. for epoch in range(200):
  4. model.train()
  5. optimizer.zero_grad()
  6. out = model(data.x, data.edge_index)
  7. loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
  8. loss.backward()
  9. optimizer.step()
  10. train_acc = (out[data.train_mask].argmax(dim=1) == data.y[data.train_mask]).sum().item() / data.y[data.train_mask].size(0)
  11. if train_acc > best_acc:
  12. best_acc = train_acc
  13. best_model = model
  14. print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Accuracy: {train_acc:.4f}')
article bottom image

相关文章推荐

发表评论