Pytorch使用DataLoader实现批量加载数据
作者:Vic·Tory
在进行模型训练时,需要把数据按照固定的形式分批次投喂给模型,在PyTorch中通过torch.utils.data库的DataLoader
完成分批次返回数据。
构造DataLoader首先需要一个Dataset
数据源,Dataset完成数据的读取并可以返回单个数据,然后DataLoader在此基础上完成数据清洗、打乱等操作并按批次返回数据。
Dataset
PyTorch将数据源分为两种类型:类似Map型(Map-style datasets)和可迭代型(Iterable-style datasets)。
Map风格的数据源可以通过索引idx对数据进行查找:dataset[idx]
,它需要继承Dataset
类,并且重写__getitem__()
方法完成根据索引值获取数据和__len__()
方法返回数据的总长度。
可迭代型可以迭代获取其数据,但没有固定的长度,因此也不能通过下标获得数据,通常用于无法获取全部数据或者流式返回的数据。它继承自IterableDataset
类,并且需要实现__iter__()
方法来完成对数据集的迭代和返回。
如下所示为自定义的数据源MySet
,它完成数据的读取,这里假定为[1, 9] 9个数据,然后重写了__getitem__() 和__len__() 方法
from torch.utils.data import Dataset, DataLoader, Sampler class MySet(Dataset): # 读取数据 def __init__(self): self.data = [1, 2, 3, 4, 5, 6, 7, 8, 9] # 根据索引返回数据 def __getitem__(self, idx): return self.data[idx] # 返回数据集总长度 def __len__(self): return len(self.data)
DataLoader
其构造函数如下:
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
dataset
:Dataset类型,从其中加载数据 batch_size:int,可选。每个batch加载多少样本batch_size
: 一个批次的数据个数shuffle
:bool,可选。为True时表示每个epoch都对数据进行洗牌sampler
:Sampler,可选。获取下一个数据的方法。batch_sampler
:获取下一批次数据的方法num_workers
:int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。collate_fn
:callable,可选,自定义处理数据并返回。pin_memory
:bool,可选,True代表将数据Tensor放入CUDA的pin储存drop_last
:bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。
Sampler索引
既然DataLoader根据索引值从Dataset中获取数据,那么如何获取一个批次数据的索引,索引值应该如何排列才能实现随机的效果?这就需要Sampler
了,它可以对索引进行shuffle操作来打乱顺序,并且根据batch size一次返回指定个数的索引序列。
在初始化DataLoader时通过sampler
属性指定获取下一个数据的索引的方法,或者batch_sampler
属性指定获取下一个批次数据的索引。
当我们设置DataLoader的shuffle
属性为True时,会根据batch_size
属性传入的批次大小自动构造sample返回下一个批次的索引。
当我们不启用shuffle属性时,就可以通过batch_sampler
属性自定义sample来返回下一批的索引,注意这时候不可用使用 batch_size
, shuffle
, sampler
, 和drop_last
属性。
如下所示为自定义MySampler
,它继承自Sampler
,由传入dataset
的长度产生对应的索引,例如上面有9个数据,那么产生索引[0, 8]。
根据批次大小batch_size
计算出总批次数,例如当batchsize是3,那么9/3=3,即总共有3个批次。
重写__iter__()
方法按批次返回索引,即第一批返回[0, 1, 2],第二批返回[3, 4, 5]以此类推。
__len__()
方法返回总的批次数,即3个批次。
class MySampler(Sampler): def __init__(self, dataset, batchsize): super(Sampler, self).__init__() self.dataset = dataset self.batch_size = batchsize # 每一批数据量 self.indices = range(len(dataset)) # 生成数据集的索引 self.count = int(len(dataset) / self.batch_size) # 一共有多少批 def __iter__(self): for i in range(self.count): yield self.indices[i * self.batch_size: (i + 1) * self.batch_size] def __len__(self): return self.count
collate处理数据
当我们拿到数据如果希望进行一些预处理而不是直接返回,这时候就需要collate_fn属性来指定处理和返回数据的方法,如果不指定该属性,默认会将普通的NumPy数组转换为PyTorch的tensor并直接返回。
如下所示为自定义的my_collate()
函数,默认传入获得的一个批次的数据data,例如之前返回一批数据[1, 2, 3],这里遍历数据并平方之后放在res数组中返回[1, 4, 9]
def my_collate(data): res = [] for d in data: res.append(d ** 2) return res
有了上面的索引获取类MySampler
和数据处理函数my_collate()
,就可以使用DataLoader自定义获取批数据了。
首先DataLoader通过my_sampler
返回的索引[0, 1, 2]去dataset
拿到数据[1, 2, 3],然后传递给my_collate进行平方操作,然后返回一个批次的结果为[1, 4, 9],一共有三个批次的数据。
dataset = MySet() # 定义数据集 my_sampler = MySampler(dataset, 3) # 实例化MySampler data_loader = DataLoader(dataset, batch_sampler=my_sampler, collate_fn=my_collate) for data in data_loader: # 按批次获取数据 print(data) ''' [1, 4, 9] [16, 25, 36] [49, 64, 81] '''
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。