python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > Dataset和Dataloader详解

Python中的Dataset和Dataloader详解

作者:菜菜01

这篇文章主要介绍了Python中的Dataset和Dataloader详解,DataLoader与DataSet是PyTorch数据读取的核心,是构建一个可迭代的数据装载器,每次执行循环的时候,就从中读取一批Batchsize大小的样本进行训练,需要的朋友可以参考下

Dataset,Dataloader是什么?

为什么要了解Dataloader

​ 因为你的神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。

因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。

​ 通常,我们在默认或知名数据集(如 MNIST 或 CIFAR)上训练神经网络,可以轻松地实现预测和分类类型问题的超过 90% 的准确度。

但是那是因为这些数据集组织整齐且易于预处理。

但是处理自己的数据集时,我们常常无法达到这样高的准确率

Dataloader 的使用

载入相关类

from torch.utils.data import Dataloader

设置相关参数

from torch.utils.data import DataLoader
DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
 )
"""
dataset:是数据集
batch_size:是指一次迭代中使用的训练样本数。通常我们将数据分成训练集和测试集,并且我们可能有不同的批量大小。
shuffle:是传递给 DataLoader 类的另一个参数。该参数采用布尔值(真/假)。如果 shuffle 设置为 True,则所有样本都被打乱并分批加载。否则,它们会被一个接一个地发送,而不会进行任何洗牌。
num_workers:允许多处理来增加同时运行的进程数
collate_fn:合并数据集
pin_memory:锁页内存:将张量固定在内存中
"""

以minist为例子

# Import MNIST
from torchvision.datasets import MNIST
# Download and Save MNIST 
data_train = MNIST('~/mnist_data', train=True, download=True)
# Print Data
print(data_train)
print(data_train[12])
#Dataset MNIST Number of datapoints: 60000 Root location: /Users/viharkurama/mnist_data Split: Train (<PIL.Image.Image image mode=L size=28x28 at 0x11164A100>, 3)

现在让尝试提取元组,其中第一个值对应于图像,第二个值对应于其各自的标签。

下面是代码片段:

import matplotlib.pyplot as plt
random_image = data_train[0][0]
random_image_label = data_train[0][1]
# Print the Image using Matplotlib
plt.imshow(random_image)
print("The label of the image is:", random_image_label)

让我们使用 DataLoader 类来加载数据集,如下所示。

import torch
from torchvision import transforms
data_train = torch.utils.data.DataLoader(
    MNIST(
          '~/mnist_data', train=True, download=True, 
          transform = transforms.Compose([
              transforms.ToTensor()
          ])),
          batch_size=64,
          shuffle=True
          )
for batch_idx, samples in enumerate(data_train):
      print(batch_idx, samples)

这就是我们使用 DataLoader 加载简单数据集的方式。 但是,我们不能总是对每个数据集都依赖已经有的数据集,要是自己的数据集怎么办

定义自己的数据集

我们将创建一个由数字和文本组成的简单自定义数据集

先介绍两个方法

#__getitem__() 方法通过索引返回数据集中选定的样本。
#__len__() 方法返回数据集的总大小。例如,如果您的数据集包含 1,00,000 个样本,则 len 方法应返回 1,00,000。
class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
    def __len__(self):
        raise NotImplementedError

​ 创建自定义数据集并不复杂,但作为加载数据的典型过程的附加步骤,有必要构建一个接口以获得良好的抽象(至少可以说是一个很好的语法糖)。

现在我们将创建一个包含数字及其平方值的新数据集。 让我们将数据集称为 SquareDataset。 其目的是返回 [a,b] 范围内的值的平方。

下面是相关代码:

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
class SquareDataset(Dataset):
     def __init__(self, a=0, b=1):
         super(Dataset, self).__init__()
         assert a <= b
         self.a = a
         self.b = b
     def __len__(self):
         return self.b - self.a + 1
     def __getitem__(self, index):
        assert self.a <= index <= self.b
        return index, index**2
data_train = SquareDataset(a=1,b=64)
data_train_loader = DataLoader(data_train, batch_size=64, shuffle=True)
print(len(data_train))

​ 在上面的代码块中,我们创建了一个名为 SquareDataset 的 Python 类,它继承了 PyTorch 的 Dataset 类。

接下来,我们调用了一个 init() 构造函数,其中 a 和 b 分别被初始化为 0 和 1。 超类用于从继承的 Dataset 类中访问 len 和 get_item 方法。

接下来我们使用 assert 语句来检查 a 是否小于或等于 b,因为我们想要创建一个数据集,其中值将位于 a 和 b 之间。

​ 然后,我们使用 SquareDataset 类创建了一个数据集,其中数据值的范围为 1 到 64。我们将其加载到名为 data_train 的变量中。

最后,Dataloader 类在 data_train_loader 中存储的数据上创建了一个迭代器,batch_size 初始化为 64,shuffle 设置为 True。

如何使用transform

​ 当你学会怎么定义自己的数据集的时候,你可能会想要更近 一步的操作,对于你自己的数据集进行剪切或者变换

​ 以CIFAR10为例子

导入必要的模块

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

接下来,我们将定义一个名为 transforms 的变量,我们在其中按顺序编写所有预处理步骤。我们使用 Compose 类将所有转换操作链接在一起。

transform = transforms.Compose([
    # resize
    transforms.Resize(32),
    # center-crop
    transforms.CenterCrop(32),
    # to-tensor
    transforms.ToTensor(),
    # normalize
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
"""
resize:此调整大小转换将所有图像转换为定义的大小。在这种情况下,我们要将所有图像的大小调整为 32×32。因此,我们将 32 作为参数传递。
center-crop:接下来,我们使用 CenterCrop 变换裁剪图像。 我们发送的参数也是分辨率/大小,但由于我们已经将图像大小调整为 32x32,因此图像将与此裁剪中心对齐。 这意味着图像将从中心裁剪 32 个单位(垂直和水平)。
to-tensor:我们使用 ToTensor() 方法将图像转换为张量数据类型。
normalize:这将张量中的所有值归一化,使它们位于 0.5 和 1 之间。
"""

在下一步中,在执行我们刚刚定义的转换之后,我们将使用 trainloader 将 CIFAR 数据集加载到训练集中。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=False)

到此这篇关于Python中的Dataset和Dataloader详解的文章就介绍到这了,更多相关Dataset和Dataloader详解内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文