python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > PyTorch的Dataloader与Dataset

浅谈PyTorch的数据读取机制Dataloader与Dataset

作者:木千之

这篇文章主要介绍了浅谈PyTorch的数据读取机制Dataloader与Dataset,DataLoader的作用是构建一个可迭代的数据装载器,每次执行循环的时候,就从中读取一批Batchsize大小的样本进行训练,需要的朋友可以参考下

Dataloader与DataSet数据读取方法

DataLoader与DataSet是PyTorch数据读取的核心。

torch.utils.DataLoader”的作用是构建一个可迭代的数据装载器,每次执行循环的时候,就从中读取一批Batchsize大小的样本进行训练。

其主要参数有五项:

  1. dataset:隶属DataSet类,表示数据从哪里读取以及如何读取
  2. batchsize:批大小
  3. num_works:是否多进程读取数据
  4. shuffle:每个epoch是否乱序
  5. drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

上述主要参数中num_works通常在单进程时默认为“0”,也可以在支持多进程的设备上设置为允许的“4 | 8 | 16”等。

shuffle则通常设置为使用乱序(True),以使得每次数据读取具有随机性。

这里颇为重要的是“Epoch、Iteration和Batchsize”之间的关系:

1)Epoch表示所有训练样本都已输入到模型中,记为一个Epoch;

2)Iteration表示一批样本输入到模型中,记为一个Iteration;

3)Batchsize表示批大小,决定一个Epoch中有多少个Iteration。当样本数可以被Batchsize整除时,三者成立关系,即全体样本分成Batchsize分批次输入模型,每批次记为一次Iteration。

若样本总数80个,当Batchsize=8时,可以知道“1 Epoch = 10 Iteration”。

若样本总数87个,当Batchsize-8时,可以知道:1)若“drop_last=True”,则“1 Epoch = 10 Iteration”;2)

若“drop_last=False”,则“1 Epoch = 11 Iteration”,其最后一个Iteration时样本个数为7,小于既定Batchsize。

torch.utils.data.Dataset”主要用于定义数据从哪里读取以及如何读取的问题。其定义为DataSet抽象类,所有自定义的Dataset都需要继承它,并复写“getitem()”内构函数,该函数接受一个索引,并返回一个样本。

在这里插入图片描述

DataLoader与DataSet数据读取机制

PyTorch的数据读取机制通常围绕三个核心问题展开,即:

读取哪些数据?从哪里读取数据?怎么读取数据?

事实上,通过在PyCharm中进行代码调试,我们可以简要回答上述问题:

1)通过Sampler取样器按序或随机挑选出Batchsize数量的索引列表;

2)使用DataSet中的data_dir指定硬盘上的数据访问路径;

3)使用DataSet中自定义的getitem()方法,基于Sampler返回的索引列表读取相应数据和标签,并拼接成新的列表数据。

在这里插入图片描述

事实上,PyTorch的数据读取经过了诸多函数的跳转。在for循环中首先调用了“DataLoader”,进而使用Sampler、Dataset和getitem解决“数据读哪些?从哪读?怎么读?”的问题。

最后,我们提供一份PyTorch中DataLoader数据读取机制的函数跳转流程图,供大家参考学习。

在这里插入图片描述

到此这篇关于浅谈PyTorch的数据读取机制Dataloader与Dataset的文章就介绍到这了,更多相关PyTorch的Dataloader与Dataset内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文