Pytorch PyG实现EdgePool图分类

 更新时间:2023年04月21日 09:55:27   作者:实力  
这篇文章主要为大家介绍了Pytorch PyG实现EdgePool图分类示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助,祝大家多多进步,早日升职加薪

Python客栈送红包、纸质书

EdgePool简介

EdgePool是一种用于图分类的卷积神经网络(Convolutional Neural Network,CNN)模型。其主要思想是通过 edge pooling 上下采样优化图像大小,减少空间复杂度,提高分类性能。

实现步骤

 数据准备

一般来讲,在构建较大规模数据集时,我们都需要对数据进行规范、归一和清洗处理,以便后续语义分析或深度学习操作。而在图像数据集中,则需使用特定的框架或工具库完成。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 导入MNIST数据集
from torch_geometric.datasets import MNISTSuperpixels
# 加载数据、划分训练集和测试集
dataset = MNISTSuperpixels(root='./mnist', transform=Compose([ToTensor(), NormalizeMeanStd()]))
data = dataset[0]
# 定义超级参数
num_features = dataset.num_features
num_classes = dataset.num_classes
# 构建训练集和测试集索引文件
train_mask = torch.zeros(data.num_nodes, dtype=torch.uint8)
train_mask[:60000] = 1
test_mask = torch.zeros(data.num_nodes, dtype=torch.uint8)
test_mask[60000:] = 1
# 创建数据加载器
train_loader = DataLoader(data[train_mask], batch_size=32, shuffle=True)
test_loader = DataLoader(data[test_mask], batch_size=32, shuffle=False)

实现模型

在定义EdgePool模型时,我们需要重新考虑网络结构中的上下采样操作,以便让整个网络拥有更强大的表达能力,从而学习到更复杂的关系。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from torch.nn import Linear
from torch_geometric.nn import EdgePooling
class EdgePool(torch.nn.Module):
    def __init__(self, dataset):
        super(EdgePool, self).__init__()
        # 定义输入与输出维度数
        self.input_dim = dataset.num_features
        self.hidden_dim = 128
        self.output_dim = 10
        # 定义卷积层、归一化层和pooling层等
        self.conv1 = GCNConv(self.input_dim, self.hidden_dim)
        self.norm1 = BatchNorm1d(self.hidden_dim)
        self.pool1 = EdgePooling(self.hidden_dim)
        self.conv2 = GCNConv(self.hidden_dim, self.hidden_dim)
        self.norm2 = BatchNorm1d(self.hidden_dim)
        self.pool2 = EdgePooling(self.hidden_dim)
        self.conv3 = GCNConv(self.hidden_dim, self.hidden_dim)
        self.norm3 = BatchNorm1d(self.hidden_dim)
        self.pool3 = EdgePooling(self.hidden_dim)
        self.lin = torch.nn.Linear(self.hidden_dim, self.output_dim)
    def forward(self, x, edge_index, batch):
        x = F.relu(self.norm1(self.conv1(x, edge_index)))
        x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
        x = F.relu(self.norm2(self.conv2(x, edge_index)))
        x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
        x = F.relu(self.norm3(self.conv3(x, edge_index)))
        x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch)
        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return x

在上述代码中,我们使用了不同的卷积层、池化层和全连接层等神经网络功能块来构建EdgePool模型。其中,每个 GCNConv 层被保持为128的隐藏尺寸;BatchNorm1d是一种旨在提高收敛速度并增强网络泛化能力的方法;EdgePooling是一种在 GraphConvolution 上附加的特殊类别,它将给定图下采样至其一半的大小,并返回缩小后的图与两个跟踪full-graph-to-pool双向映射(keep and senders)的 edge index(edgendarcs)。 在这种情况下传递 None ,表明 batch 未更改。

模型训练

在定义好 EdgePool 网络结构之后,需要指定合适的优化器、损失函数,并控制训练轮数、批量大小与学习率等超参数。同时还要记录大量日志信息,方便后期跟踪和驾驶员。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 定义训练计划,包括损失函数、优化器及迭代次数等
train_epochs = 50
learning_rate = 0.01
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(edge_pool.parameters(), lr=learning_rate)
losses_per_epoch = []
accuracies_per_epoch = []
for epoch in range(train_epochs):
    running_loss = 0.0
    running_corrects = 0.0
    count = 0.0
    for samples in train_loader:
        optimizer.zero_grad()
        x, edge_index, batch = samples.x, samples.edge_index, samples.batch
        out = edge_pool(x, edge_index, batch)
        label = samples.y
        loss = criterion(out, label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() / len(train_loader.dataset)
        pred = out.argmax(dim=1)
        running_corrects += pred.eq(label).sum().item() / len(train_loader.dataset)
        count += 1
    losses_per_epoch.append(running_loss)
    accuracies_per_epoch.append(running_corrects)
    if (epoch + 1) % 10 == 0:
        print("Train Epoch {}/{} Loss {:.4f} Accuracy {:.4f}".format(
            epoch + 1, train_epochs, running_loss, running_corrects))

在训练过程中,我们遍历了每个批次的数据,并通过反向传播算法进行优化,并更新了 loss 和 accuracy 输出值。 同时方便可视化与记录,需要将训练过程中的 loss 和 accuracy 输出到相应的容器中,以便后期进行分析和处理。

以上就是Pytorch PyG实现EdgePool图分类的详细内容,更多关于Pytorch PyG EdgePool图分类的资料请关注脚本之家其它相关文章!

蓄力AI

微信公众号搜索 “ 脚本之家 ” ,选择关注

程序猿的那些事、送书等活动等着你

原文链接:https://juejin.cn/post/7224127112709652538

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若内容造成侵权/违法违规/事实不符,请将相关资料发送至 reterry123@163.com 进行投诉反馈,一经查实,立即处理!

相关文章

  • python调用百度地图WEB服务API获取地点对应坐标值

    python调用百度地图WEB服务API获取地点对应坐标值

    这篇文章主要为大家详细介绍了python调用百度地图WEB服务API获取地点对应坐标值,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2019-01-01
  • python pandas模块进行数据分析

    python pandas模块进行数据分析

    Python的Pandas模块是一个强大的数据处理工具,可以用来读取、处理和分析各种数据,本文主要介绍了python pandas模块进行数据分析,具有一定的参考价值,感兴趣的可以了解一下
    2024-01-01
  • Python中bisect的用法及示例详解

    Python中bisect的用法及示例详解

    这篇文章主要介绍了Python中bisect的用法及示例详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-07-07
  • python:列表详解

    python:列表详解

    这篇文章主要介绍了Python中列表(List)的详解操作方法,包含创建、访问、更新、删除、其它操作等,需要的朋友可以参考下
    2021-10-10
  • python处理图片之PIL模块简单使用方法

    python处理图片之PIL模块简单使用方法

    这篇文章主要介绍了python处理图片之PIL模块简单使用方法,涉及Python使用PIL模块实现针对图片的锐化、绘制直线、绘制椭圆等相关技巧,需要的朋友可以参考下
    2015-05-05
  • 基于Python实现的扫雷游戏实例代码

    基于Python实现的扫雷游戏实例代码

    这篇文章主要介绍了基于Python实现的扫雷游戏实例代码,对于Python的学习以及Python游戏开发都有一定的借鉴价值,需要的朋友可以参考下
    2014-08-08
  • python实现数学模型(插值、拟合和微分方程)

    python实现数学模型(插值、拟合和微分方程)

    这篇文章主要介绍了python实现数学模型(插值、拟合和微分方程),本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-11-11
  • 解决pycharm无法识别本地site-packages的问题

    解决pycharm无法识别本地site-packages的问题

    今天小编就为大家分享一篇解决pycharm无法识别本地site-packages的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-10-10
  • Python 多线程抓取图片效率对比

    Python 多线程抓取图片效率对比

    Python由于有全锁局的存在,并不能利用多核优势。所以,如果你的多线程进程是CPU密集型的,那多线程并不能带来效率上的提升,相反还可能会因为线程的频繁切换,导致效率下降;如果是IO密集型,多线程进程可以利用IO阻塞等待时的空闲时间执行其他线程,提升效率。
    2016-02-02
  • Python3离线安装Requests模块问题

    Python3离线安装Requests模块问题

    这篇文章主要介绍了Python3离线安装Requests模块问题,本文实例代码给大家介绍的非常详细,非常不错,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-10-10

最新评论