python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch的dataloader使用

pytorch中的dataloader使用方法详解

作者:惊瑟

这篇文章主要介绍了pytorch中的dataloader使用方法详解,构建自己的dataloader是模型训练的第一步,本篇文章介绍下pytorch与dataloader以及与其相关的类的用法,需要的朋友可以参考下

pytorch中的dataloader使用方法详解

DataLoader类中有一个必填参数为dataset,因此在构建自己的dataloader前,先要定义好自己的Dataset类。这里先大致介绍下这两个类的作用:

有小伙伴可能会疑惑,自己写一个加载数据的工具似乎也没有多“困难”,为何大费周章要继承pytorch中类,按照它的规则加载数据呢?

总结一下就是:

使用方法

两步走:

  1. 定义自己的Dataset类,具体要做的事:
    • 告诉它去哪儿读数据,并将数据resize为统一的shape(可以思考下为什么呢)
    • 重写__len__()以及__getitem__,其中__getitem__中要确定自己想要哪些数据,然后将其return出来。
  2. 将自己的Dataset实例传到Dataloder中并设置想要的参数,构建自己的dataloader

下面简单加载一个目录下的图片以及label:

import os
import numpy as np

from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import cv2

# Your Data Path
img_dir = '/home/jyz/Downloads/classify_example/val/骏马/'
anno_file = '/home/jyz/Downloads/classify_example/val/label.txt'


class MyDataset(Dataset):
    def __init__(self, img_dir, anno_file, imgsz=(640, 640)):
        self.img_dir = img_dir
        self.anno_file = anno_file
        self.imgsz = imgsz
        self.img_namelst = os.listdir(self.img_dir)

    # need to overload
    def __len__(self):
        return len(self.img_namelst)

    # need to overload
    def __getitem__(self, idx):
        with open(self.anno_file, 'r') as f:
            label = f.readline().strip()
        img = cv2.imread(os.path.join(img_dir, self.img_namelst[idx]))
        img = cv2.resize(img, self.imgsz)
        return img, label


dataset = MyDataset(img_dir, anno_file)
dataloader = DataLoader(dataset=dataset, batch_size=2)

# display
for img_batch, label_batch in dataloader:
    img_batch = img_batch.numpy()
    print(img_batch.shape)
    # img = np.concatenate(img_batch, axis=0)
    if img_batch.shape[0] == 2:
        img = np.hstack((img_batch[0], img_batch[1]))
    else:
        img = np.squeeze(img_batch, axis=0)  # 最后一张图时,删除第一个维度
    print(img.shape)
    cv2.imshow(label_batch[0], img)
    cv2.waitKey(0)

上面是一次加载两张图片,效果如下:

在这里插入图片描述

其实从这里可以看出,为什么要在Dataset中将数据resize为统一的shape。因为dataloader加载数据时,将一个batch_size的数据拼接成一个大的tensor,如果shape不同,就无法拼接了。

就像这两张图片加入shape不一样就无法通过拼接的方式show出来一样。

结论

到此这篇关于pytorch中的dataloader使用方法详解的文章就介绍到这了,更多相关pytorch的dataloader使用内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文