transforms.Compose()函数的使用及说明
作者:HealthScience
这篇文章主要介绍了transforms.Compose()函数的使用及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
transforms.Compose()函数的使用
torchvision.transforms是pytorch中的图像预处理包。
一般用Compose把多个步骤整合到一起:
比如说:
transforms.Compose([ transforms.CenterCrop(10), transforms.ToTensor(), ])
这样就把两个步骤整合到了一起。
介绍transforms中的函数
Resize | 把给定的图片resize到given size |
Normalize | 用均值和标准差归一化张量图像 |
ToTensor | convert a PIL image to tensor (H*W*C) in range [0,255] to a torch.Tensor(C*H*W) in the range [0.0,1.0] |
CenterCrop | 在图片的中间区域进行裁剪 |
RandomCrop | 在一个随机的位置进行裁剪 |
FiceCrop | 把图像裁剪为四个角和一个中心 |
RandomResizedCrop | 将PIL图像裁剪成任意大小和纵横比 |
ToPILImage | convert a tensor to PIL image |
RandomHorizontalFlip | 以0.5的概率水平翻转给定的PIL图像 |
RandomVerticalFlip | 以0.5的概率竖直翻转给定的PIL图像 |
Grayscale | 将图像转换为灰度图像 |
RandomGrayscale | 将图像以一定的概率转换为灰度图像 |
ColorJitter | 随机改变图像的亮度对比度和饱和度 |
备注:Python图像库PIL(Python Image Library)是python的第三方图像处理库,但是由于其强大的功能与众多的使用人数,几乎已经被认为是python官方图像处理库了。
实例
经常看到
transform.ToTensor(), transform.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
那transform.Normalize()是怎么工作的呢?以上面代码为例,
- ToTensor()能够把灰度范围从0-255变换到0-1之间,
- 而后面的transform.Normalize()则把0-1变换到(-1,1).
具体地说,对每个通道而言,Normalize执行以下操作:image=(image-mean)/std其中mean和std分别通过(0.5,0.5,0.5)和(0.5,0.5,0.5)进行指定。
原来的0-1最小值0则变成(0-0.5)/0.5=-1,而最大值1则变成(1-0.5)/0.5=1.
transforms.Compose()类详解:串联多个transform操作
torchvision
是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision.transforms
主要是用于常见的一些图形变换。
以下是torchvision
的构成:
torchvision.datasets
: 一些加载数据的函数及常用的数据集接口;torchvision.models
: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;torchvision.transforms
: 常用的图片变换,例如裁剪、旋转等;torchvision.utils
: 其他的一些有用的方法。
本文的主题是其中的torchvision.transforms.Compose()
类。
这个类的主要作用是串联多个图片变换的操作。
这个类的构造很简单:
class torchvision.transforms.Compose(transforms): # Composes several transforms together. # Parameters: transforms (list of Transform objects) – list of transforms to compose. Example # 可以看出Compose里面的参数实际上就是个列表,而这个列表里面的元素就是你想要执行的transform操作。 >>> transforms.Compose([ >>> transforms.CenterCrop(10), >>> transforms.ToTensor(),])
事实上,Compose()
类会将transforms列表里面的transform操作进行遍历。
实现的代码很简单:
## 这里对源码进行了部分截取。 def __call__(self, img): for t in self.transforms: img = t(img) return img
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。