Pytorch PyG实现EdgePool图分类
EdgePool简介
EdgePool是一种用于图分类的卷积神经网络(Convolutional Neural Network,CNN)模型。其主要思想是通过 edge pooling 上下采样优化图像大小,减少空间复杂度,提高分类性能。
实现步骤
数据准备
一般来讲,在构建较大规模数据集时,我们都需要对数据进行规范、归一和清洗处理,以便后续语义分析或深度学习操作。而在图像数据集中,则需使用特定的框架或工具库完成。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | # 导入MNIST数据集 from torch_geometric.datasets import MNISTSuperpixels # 加载数据、划分训练集和测试集 dataset = MNISTSuperpixels(root = './mnist' , transform = Compose([ToTensor(), NormalizeMeanStd()])) data = dataset[ 0 ] # 定义超级参数 num_features = dataset.num_features num_classes = dataset.num_classes # 构建训练集和测试集索引文件 train_mask = torch.zeros(data.num_nodes, dtype = torch.uint8) train_mask[: 60000 ] = 1 test_mask = torch.zeros(data.num_nodes, dtype = torch.uint8) test_mask[ 60000 :] = 1 # 创建数据加载器 train_loader = DataLoader(data[train_mask], batch_size = 32 , shuffle = True ) test_loader = DataLoader(data[test_mask], batch_size = 32 , shuffle = False ) |
实现模型
在定义EdgePool模型时,我们需要重新考虑网络结构中的上下采样操作,以便让整个网络拥有更强大的表达能力,从而学习到更复杂的关系。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | from torch.nn import Linear from torch_geometric.nn import EdgePooling class EdgePool(torch.nn.Module): def __init__( self , dataset): super (EdgePool, self ).__init__() # 定义输入与输出维度数 self .input_dim = dataset.num_features self .hidden_dim = 128 self .output_dim = 10 # 定义卷积层、归一化层和pooling层等 self .conv1 = GCNConv( self .input_dim, self .hidden_dim) self .norm1 = BatchNorm1d( self .hidden_dim) self .pool1 = EdgePooling( self .hidden_dim) self .conv2 = GCNConv( self .hidden_dim, self .hidden_dim) self .norm2 = BatchNorm1d( self .hidden_dim) self .pool2 = EdgePooling( self .hidden_dim) self .conv3 = GCNConv( self .hidden_dim, self .hidden_dim) self .norm3 = BatchNorm1d( self .hidden_dim) self .pool3 = EdgePooling( self .hidden_dim) self .lin = torch.nn.Linear( self .hidden_dim, self .output_dim) def forward( self , x, edge_index, batch): x = F.relu( self .norm1( self .conv1(x, edge_index))) x, edge_index, _, batch, _ = self .pool1(x, edge_index, None , batch) x = F.relu( self .norm2( self .conv2(x, edge_index))) x, edge_index, _, batch, _ = self .pool2(x, edge_index, None , batch) x = F.relu( self .norm3( self .conv3(x, edge_index))) x, edge_index, _, batch, _ = self .pool3(x, edge_index, None , batch) x = global_mean_pool(x, batch) x = self .lin(x) return x |
在上述代码中,我们使用了不同的卷积层、池化层和全连接层等神经网络功能块来构建EdgePool模型。其中,每个 GCNConv 层被保持为128的隐藏尺寸;BatchNorm1d是一种旨在提高收敛速度并增强网络泛化能力的方法;EdgePooling是一种在 GraphConvolution 上附加的特殊类别,它将给定图下采样至其一半的大小,并返回缩小后的图与两个跟踪full-graph-to-pool双向映射(keep and senders)的 edge index(edgendarcs)。 在这种情况下传递 None ,表明 batch
未更改。
模型训练
在定义好 EdgePool 网络结构之后,需要指定合适的优化器、损失函数,并控制训练轮数、批量大小与学习率等超参数。同时还要记录大量日志信息,方便后期跟踪和驾驶员。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | # 定义训练计划,包括损失函数、优化器及迭代次数等 train_epochs = 50 learning_rate = 0.01 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(edge_pool.parameters(), lr = learning_rate) losses_per_epoch = [] accuracies_per_epoch = [] for epoch in range (train_epochs): running_loss = 0.0 running_corrects = 0.0 count = 0.0 for samples in train_loader: optimizer.zero_grad() x, edge_index, batch = samples.x, samples.edge_index, samples.batch out = edge_pool(x, edge_index, batch) label = samples.y loss = criterion(out, label) loss.backward() optimizer.step() running_loss + = loss.item() / len (train_loader.dataset) pred = out.argmax(dim = 1 ) running_corrects + = pred.eq(label). sum ().item() / len (train_loader.dataset) count + = 1 losses_per_epoch.append(running_loss) accuracies_per_epoch.append(running_corrects) if (epoch + 1 ) % 10 = = 0 : print ( "Train Epoch {}/{} Loss {:.4f} Accuracy {:.4f}" . format ( epoch + 1 , train_epochs, running_loss, running_corrects)) |
在训练过程中,我们遍历了每个批次的数据,并通过反向传播算法进行优化,并更新了 loss 和 accuracy 输出值。 同时方便可视化与记录,需要将训练过程中的 loss 和 accuracy 输出到相应的容器中,以便后期进行分析和处理。
以上就是Pytorch PyG实现EdgePool图分类的详细内容,更多关于Pytorch PyG EdgePool图分类的资料请关注脚本之家其它相关文章!

微信公众号搜索 “ 脚本之家 ” ,选择关注
程序猿的那些事、送书等活动等着你
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若内容造成侵权/违法违规/事实不符,请将相关资料发送至 reterry123@163.com 进行投诉反馈,一经查实,立即处理!
相关文章
解决pycharm无法识别本地site-packages的问题
今天小编就为大家分享一篇解决pycharm无法识别本地site-packages的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧2018-10-10
最新评论