python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch图片数据增强

基于pytorch实现对图片进行数据增强

作者:草莓橙子碗

图像数据增强是一种在训练机器学习和深度学习模型时常用的策略,尤其是在计算机视觉领域,具体而言,它通过创建和原始图像稍有不同的新图像来扩大训练集,本文给大家介绍了如何基于pytorch实现对图片进行数据增强,需要的朋友可以参考下

背景

在进行机器学习的任务中,我们的训练数据往往是有限的,在有限的数据集上获得较好的模型训练结果,我们不仅要在模型结构上下功夫,另一方面也需要对数据集进行数据增强

图片数据增强

图像数据增强是一种在训练机器学习和深度学习模型时常用的策略,尤其是在计算机视觉领域。具体而言,它通过创建和原始图像稍有不同的新图像来扩大训练集。 数据增强的主要目标有以下几点:

代码实现

我们使用torchvision的transforms库对图片数据进行数据增强,使用一张卡比巴拉的图片

首先读取图片数据,以下是准备工作

from PIL import Image
import numpy as np
import torchvision.transforms as tfs
import matplotlib.pyplot as plt
 
img_path = r"D:\CSDN_point\1_4\kabibala.jpg"
img = Image.open(img_path)
print("the shape of img is {}".format(np.array(img).shape))

图片伸缩

img_re = tfs.Resize((500,1000))(img)
plt.imshow(img_re)
plt.show()

tfs.Reszie((500,1000))把图像的高和宽分别拉伸到500像素和1000像素

图片裁剪

img_crop = tfs.RandomCrop(500)(img)
plt.imshow(img_crop)
plt.show()

tfs.RandomCrop(500)随机截取图片500\times500大小的区域

中心裁剪

img_crop_cen = tfs.CenterCrop(700)(img)
plt.imshow(img_crop_cen)
plt.show()

tfs.CenterCrop(700)裁剪图片中心位置700\times700大小的区域

随机水平翻转

# 随机水平翻转,概率是0.5
img_hori = tfs.RandomHorizontalFlip()(img)
# 随机垂直翻转,概率是0.5
img_ver = tfs.RandomVerticalFlip()(img)
 
plt.subplot(1,2,1)
plt.imshow(img_hori)
plt.title("RandomHorizontalFlip")
plt.subplot(1,2,2)
plt.imshow(img_ver)
plt.title("RandomVerticalFlip")
plt.show()

随机改变图片亮度、对比度和色相

img_j = tfs.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)(img)
plt.imshow(img_j)
plt.show()

tfs.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)中参数的含义:

多变换整合

可以使用Compose模块将多个变换进行整合

# 将多个变换整合到my_trans
my_trans = tfs.Compose([tfs.Resize(800),tfs.RandomHorizontalFlip(),tfs.RandomCrop(500),tfs.RandomVerticalFlip(),tfs.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)])
 
_,fig = plt.subplots(3,3)
for i in range(3):
    for j in range(3):
        fig[i][j].imshow(my_trans(img))
        plt.axis('off')
plt.show()

以上就是基于pytorch实现对图片进行数据增强的详细内容,更多关于pytorch图片数据增强的资料请关注脚本之家其它相关文章!

您可能感兴趣的文章:
阅读全文