在Pytorch中自定义dataset读取数据的实现代码
作者:Kelly_Ai_Bai
这里使用的是经典的花分类数据集
下载地址:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
下载结束后进行解压,可以得到五种不同种类花的图片,如上图所示
主函数 main
def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("using {} device.".format(device)) train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root) data_transform = { "train": transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), "val": transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])} train_data_set = MyDataSet(images_path=train_images_path, images_class=train_images_label, transform=data_transform["train"]) batch_size = 8 nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers print('Using {} dataloader workers'.format(nw)) train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True, num_workers=nw, collate_fn=train_data_set.collate_fn) # plot_data_loader_image(train_loader) for step, data in enumerate(train_loader): images, labels = data
其中,
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data ( root )
传入参数 root(就是该数据集所在的路径),没有传入参数val_rate就取其默认值0.2( 即验证集占整个数据集的 20% ), 调用函数 read_split_data
def read_split_data(root: str, val_rate: float = 0.2): random.seed(0) # 保证随机结果可复现 assert os.path.exists(root), "dataset root: {} does not exist.".format(root) # 遍历文件夹,一个文件夹对应一个类别 flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] # 排序,保证顺序一致 flower_class.sort() # 生成类别名称以及对应的数字索引 class_indices = dict((k, v) for v, k in enumerate(flower_class)) json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) with open('class_indices.json', 'w') as json_file: json_file.write(json_str) train_images_path = [] # 存储训练集的所有图片路径 train_images_label = [] # 存储训练集图片对应索引信息 val_images_path = [] # 存储验证集的所有图片路径 val_images_label = [] # 存储验证集图片对应索引信息 every_class_num = [] # 存储每个类别的样本总数 supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型 # 遍历每个文件夹下的文件 for cla in flower_class: cla_path = os.path.join(root, cla) # 遍历获取supported支持的所有文件路径 images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) if os.path.splitext(i)[-1] in supported] # 获取该类别对应的索引 image_class = class_indices[cla] # 记录该类别的样本数量 every_class_num.append(len(images)) # 按比例随机采样验证样本 val_path = random.sample(images, k=int(len(images) * val_rate)) for img_path in images: if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集 val_images_path.append(img_path) val_images_label.append(image_class) else: # 否则存入训练集 train_images_path.append(img_path) train_images_label.append(image_class) print("{} images were found in the dataset.".format(sum(every_class_num))) print("{} images for training.".format(len(train_images_path))) print("{} images for validation.".format(len(val_images_path))) plot_image = False if plot_image: # 绘制每种类别个数柱状图 plt.bar(range(len(flower_class)), every_class_num, align='center') # 将横坐标0,1,2,3,4替换为相应的类别名称 plt.xticks(range(len(flower_class)), flower_class) # 在柱状图上添加数值标签 for i, v in enumerate(every_class_num): plt.text(x=i, y=v + 5, s=str(v), ha='center') # 设置x坐标 plt.xlabel('image class') # 设置y坐标 plt.ylabel('number of images') # 设置柱状图的标题 plt.title('flower class distribution') plt.show() return train_images_path, train_images_label, val_images_path, val_images_label
运行上述代码, 得到 class_indices.json 文件,该文件存储了类别名称以及每个类别对应的索引
设置变量 plot_image 为True,可以将每个类别的样本数以柱状图的形式可视化出来
函数 read_split_data 执行结束后,返回四个列表 : train_images_path 、train_images_label 、val_images_path 和 val_images_label,分别表示训练集的图像和标签路径以及验证集的图像和标签路径,对数据集完成了训练集和验证集的划分!
然后对训练集和验证集中的数据进行数据预处理,比如裁剪、翻转、归一化等等操作
接下来,重点来了!
train_data_set = MyDataSet(images_path=train_images_path, images_class=train_images_label, transform=data_transform["train"])
传入训练集图像的路径列表、标签列表以及数据预处理的方法,对类 MyDataSet 进行初始化,得到类 MyDataSet 的实例对象 train_data_set
MyDataSet 是一个自定义的数据类,代码如下:
from PIL import Image import torch from torch.utils.data import Dataset class MyDataSet(Dataset): """自定义数据集""" def __init__(self, images_path: list, images_class: list, transform=None): self.images_path = images_path self.images_class = images_class self.transform = transform def __len__(self): return len(self.images_path) def __getitem__(self, item): img = Image.open(self.images_path[item]) # RGB为彩色图片,L为灰度图片 if img.mode != 'RGB': raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item])) label = self.images_class[item] if self.transform is not None: img = self.transform(img) return img, label @staticmethod def collate_fn(batch): # 官方实现的default_collate可以参考 # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py images, labels = tuple(zip(*batch)) images = torch.stack(images, dim=0) labels = torch.as_tensor(labels) return images, labels
该类继承类Dataset,主要实现初始化函数__init__( )、计算数据集中样本数量的函数__len__( )、根据索引返回相应的图片和标签的函数__getitem__( ) 以及 collate_fn( ) 函数
我想要重点阐述一下关于函数 collate_fn( ) 函数的作用
collate_fn( ) 函数决定了如何将数据进行打包处理
@staticmethod def collate_fn(batch): # 官方实现的default_collate可以参考 # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py images, labels = tuple(zip(*batch)) images = torch.stack(images, dim=0) labels = torch.as_tensor(labels) return images, labels
传入函数的参数 batch 是由 (images,labels) 组成的一个个的元组
如果在此处设置batch_size的值为8,那么这个函数就从数据集中获取8张图片以及这8张图片所对应的标签
可以设置断点来看一下:
因为 batch_size 取 8,所以可以看到 batch 是一个长度为8的列表,列表是由8个元组元素组成的,每个元组是由图像和其所对应的标签组成的
最后,通过 DataLoader 从实例化对象 train_data_set 中加载数据,打包成一个一个 batch 送入网络中进行训练
train_loader = torch.utils.data.DataLoader(train_data_set, batch_size=batch_size, shuffle=True, num_workers=nw, collate_fn=train_data_set.collate_fn)
这样就可以得到用于加载训练数据的数据加载器 train_loader
可以将 数据加载器 train_loader 传给函数,通过调用函数 plot_data_loader_image 后
plot_data_loader_image(train_loader)
这样就能可视化出数据加载器 train_loader 中的内容,如图所示(此处需要将 num_workers 设置为0)
以上就是在Pytorch中自定义dataset读取数据的实现代码的详细内容,更多关于在Pytorch自定义dataset的资料请关注脚本之家其它相关文章!