Pytorch 图像变换函数集合小结
作者:libo-coder
一、必要的 python 模块
PyTorch 的 Vision 模块提供了图像变换的很多函数.
torchvision/transforms/functional.py
from __future__ import division import torch import sys import math from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION try: import accimage except ImportError: accimage = None import numpy as np import numbers import collections import warnings import matplotlib as plt if sys.version_info < (3, 3): Sequence = collections.Sequence Iterable = collections.Iterable else: Sequence = collections.abc.Sequence Iterable = collections.abc.Iterable
以下图为例:
# Load the sample picture that the examples below operate on.
img_file = "test.jpe"
img = Image.open(img_file)
width, height = img.width, img.height  # (750, 815)
img.show()
二、PyTorch 图像变换函数
2.1 判断图像数据类型
# Image type checks: PIL image, torch tensor, numpy array.
def _is_pil_image(img):
    """Return True if ``img`` is a PIL image (or an accimage image when that backend is installed)."""
    accepted = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, accepted)


def _is_tensor_image(img):
    """Return True if ``img`` is a torch tensor with exactly three dimensions."""
    if not torch.is_tensor(img):
        return False
    return img.ndimension() == 3


def _is_numpy_image(img):
    """Return True if ``img`` is a numpy array with two or three dimensions."""
    if not isinstance(img, np.ndarray):
        return False
    return img.ndim in (2, 3)
# example: run each predicate on the PIL image and on its array form
_is_pil_image(img)  # True
_is_tensor_image(img)  # False
_is_numpy_image(img)  # False
img_array = np.array(img)
_is_numpy_image(img_array)  # True
2.2 to_tensor(pic)
将 PIL Image
或 numpy.ndarray
转换为 tensor
def to_tensor(pic):
    """Convert a PIL Image or numpy.ndarray to a tensor.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
            An ndarray is read as (H, W, C); a 2-D ndarray is treated as a
            single-channel (H, W) image.

    Returns:
        Tensor: Converted image in (C, H, W) layout. Byte (uint8) data is
        converted to float and divided by 255; other dtypes are returned
        with their values unchanged.

    Raises:
        TypeError: If ``pic`` is neither a PIL Image nor a 2-D/3-D ndarray.
    """
    if not (_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        # _is_numpy_image also admits 2-D arrays; add a channel axis so the
        # three-axis HWC -> CHW transpose below does not fail on them.
        if pic.ndim == 2:
            pic = pic[:, :, None]
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    # np.array is called without copy=False: NumPy 2 raises ValueError when a
    # no-copy conversion is requested but a copy is needed (e.g. a dtype cast).
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

    # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        # for the remaining modes the mode string has one letter per channel
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)

    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    return img
2.3 to_pil_image(pic, mode=None)
将 tensor
或 ndarray
转换为 PIL Image
def to_pil_image(pic, mode=None):
    """Convert a tensor or an ndarray to a PIL Image.

    Args:
        pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
            Tensors are read as (C, H, W) or (H, W); ndarrays as (H, W, C)
            or (H, W). The input object itself is not modified.
        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).

    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes

    Returns:
        PIL Image: Image converted to PIL Image.

    Raises:
        TypeError: If ``pic`` is not a Tensor/ndarray, or no mode can be
            chosen for its dtype.
        ValueError: If ``pic`` is not 2- or 3-dimensional, or ``mode`` is not
            allowed for the channel count / dtype.
    """
    if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
        raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))

    elif isinstance(pic, torch.Tensor):
        if pic.ndimension() not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension()))

        elif pic.ndimension() == 2:
            # if 2D image, add channel dimension (CHW)
            # out-of-place unsqueeze: the in-place variant reshaped the
            # caller's tensor as a side effect
            pic = pic.unsqueeze(0)

    elif isinstance(pic, np.ndarray):
        if pic.ndim not in {2, 3}:
            raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))

        elif pic.ndim == 2:
            # if 2D image, add channel dimension (HWC)
            pic = np.expand_dims(pic, 2)

    npimg = pic
    if isinstance(pic, torch.FloatTensor):
        pic = pic.mul(255).byte()
    if isinstance(pic, torch.Tensor):
        # CHW -> HWC
        npimg = np.transpose(pic.numpy(), (1, 2, 0))

    if not isinstance(npimg, np.ndarray):
        raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' +
                        'not {}'.format(type(npimg)))

    if npimg.shape[2] == 1:
        expected_mode = None
        npimg = npimg[:, :, 0]
        if npimg.dtype == np.uint8:
            expected_mode = 'L'
        elif npimg.dtype == np.int16:
            expected_mode = 'I;16'
        elif npimg.dtype == np.int32:
            expected_mode = 'I'
        elif npimg.dtype == np.float32:
            expected_mode = 'F'
        if mode is not None and mode != expected_mode:
            # report the array's actual dtype (previously the np.dtype class
            # itself was formatted into the message)
            raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}"
                             .format(mode, npimg.dtype, expected_mode))
        mode = expected_mode

    elif npimg.shape[2] == 4:
        permitted_4_channel_modes = ['RGBA', 'CMYK']
        if mode is not None and mode not in permitted_4_channel_modes:
            raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes))

        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGBA'
    else:
        permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
        if mode is not None and mode not in permitted_3_channel_modes:
            raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes))
        if mode is None and npimg.dtype == np.uint8:
            mode = 'RGB'

    if mode is None:
        raise TypeError('Input type {} is not supported'.format(npimg.dtype))

    return Image.fromarray(npimg, mode=mode)
2.4 normalize(tensor, mean, std)
归一化 tensor
的图像. in-place
计算.
def normalize(tensor, mean, std):
    """Normalize a tensor image channel by channel, in place.

    Args:
        tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.

    Returns:
        Tensor: The same tensor object, normalized.
    """
    if _is_tensor_image(tensor):
        # Per-channel in-place update; the original notes this is faster than
        # broadcasting, so keep the loop unless a benchmark says otherwise.
        for channel, (channel_mean, channel_std) in zip(tensor, zip(mean, std)):
            channel.sub_(channel_mean).div_(channel_std)
        return tensor
    raise TypeError('tensor is not a torch image.')
# example
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# normalize() expects a tensor image, so convert the PIL image first
# (img_tensor was previously used without ever being defined).
img_tensor = to_tensor(img)
img_normalize = normalize(img_tensor, mean, std)

# vis
ax1 = plt.subplot(1, 2, 1)
ax1.imshow(img)
ax1.axis("off")
ax1.set_title("orig img")
ax2 = plt.subplot(1, 2, 2)
ax2.imshow(to_pil_image(img_normalize))
ax2.axis("off")
ax2.set_title("normalize img")
plt.show()
2.5 resize(img, size, interpolation=Image.BILINEAR)
对输入的 PIL Image 进行 resize 到给定尺寸.
参数 size 为调整后的尺寸.
如果 size 是数组(h, w),则直接调整到该 (h, w) 尺寸.
如果 size 是一个 int 值,则调整后图像的最短边是该值,且保持固定的长宽比.
def resize(img, size, interpolation=Image.BILINEAR):
    """Resize a PIL Image to the requested size.

    Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. A two-element sequence
            is read as (h, w). An int sets the length of the shorter side,
            and the longer side is scaled by the same ratio.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``

    Returns:
        PIL Image: Resized image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    size_is_int = isinstance(size, int)
    if not size_is_int and not (isinstance(size, Iterable) and len(size) == 2):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if not size_is_int:
        # PIL wants (w, h) while the argument is (h, w)
        return img.resize(size[::-1], interpolation)

    width, height = img.size
    if min(width, height) == size:
        # shorter side already has the requested length
        return img
    if width < height:
        new_w = size
        new_h = int(size * height / width)
    else:
        new_h = size
        new_w = int(size * width / height)
    return img.resize((new_w, new_h), interpolation)
# example:
img_resize_256x256 = resize(img, (256, 256))  # (256, 256)
img_resize_256 = resize(img, 256)  # (256, 278)

# vis: one subplot per (picture, title) pair
panels = [
    (img, "orig img"),
    (img_resize_256x256, "resize_256x256 img"),
    (img_resize_256, "resize_256 img"),
]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 3, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.6 pad(img, padding, fill=0, padding_mode='constant')
根据指定的 padding
模式和填充值,对给定的 PIL Image
的所有边进行 pad
处理.
参数 padding - int 或 tuple 形式.
padding:
- 如果是 int 值 ,则对所有的边都 padding 该 int 值.
- 如果是长度为 2 的tuple,则对 left/right 和 top/bottom 分别进行 padding.
- 如果是长度为 4 的 tuple,则对 left,top,right, bottom 边分别进行 padding.
参数 fill - 像素填充值,默认为 0. 如果值是长度为 3 的 tuple,则分别对 R,G,B 通道进行填充. 仅用于当 padding_mode='constant'
的情况.
参数 padding_mode - 填充的类型,可选:constant,edge,reflect,symmetric. 默认为 constant. 填充常数值.
constant - padding 填充常数值 fill.
edge - padding 图像边缘的最后一个值.
reflect - padding 图像的反射(reflection)值,(不对图像边缘的最后一个像素值进行重复)
如,[1, 2, 3, 4] 在 reflect 模式下在 两边 padding 2 个元素值,会得到:
[3, 2, 1, 2, 3, 4, 3, 2]
symmetric - padding 图像的反射(reflection)值,(对图像边缘的最后一个像素值进行重复).
如,[1, 2, 3, 4] 在 symmetric 模式下在 两边 padding 2 个元素值,会得到:
[2, 1, 1, 2, 3, 4, 4, 3]
def pad(img, padding, fill=0, padding_mode='constant'):
    """Pad all borders of a PIL Image.

    Args:
        img (PIL Image): Image to be padded.
        padding (int or tuple): Padding on each border. A single number pads
            every border by that amount; a 2-tuple is (left/right,
            top/bottom); a 4-tuple is (left, top, right, bottom).
        fill: Pixel fill value for constant fill. Default is 0. Only used
            when ``padding_mode`` is 'constant'.
        padding_mode: Type of padding. Should be: constant, edge, reflect or
            symmetric. Default is constant.

    Returns:
        PIL Image: Padded image.

    Raises:
        TypeError: If ``img``, ``padding``, ``fill`` or ``padding_mode`` has
            an unsupported type.
        ValueError: If ``padding`` is a sequence whose length is not 2 or 4.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError('Got inappropriate padding arg')
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError('Got inappropriate fill arg')
    if not isinstance(padding_mode, str):
        raise TypeError('Got inappropriate padding_mode arg')

    if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
        raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
                         "{} element tuple".format(len(padding)))

    assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
        'Padding mode should be either constant, edge, reflect or symmetric'

    if padding_mode == 'constant':
        if img.mode == 'P':
            # re-attach the palette to the expanded image
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, fill=fill)
            image.putpalette(palette)
            return image

        return ImageOps.expand(img, border=padding, fill=fill)

    # Non-constant modes go through numpy. After the validation above,
    # ``padding`` is either a 2-/4-element tuple or a single number. Any
    # number accepted by that check (not only a builtin int) is treated as a
    # uniform border, so the four pad_* values are always bound.
    if isinstance(padding, Sequence):
        if len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        else:
            pad_left, pad_top, pad_right, pad_bottom = padding
    else:
        pad_left = pad_right = pad_top = pad_bottom = padding

    if img.mode == 'P':
        palette = img.getpalette()
        img = np.asarray(img)
        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
        img = Image.fromarray(img)
        img.putpalette(palette)
        return img

    img = np.asarray(img)
    # RGB image
    if len(img.shape) == 3:
        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
    # Grayscale image
    if len(img.shape) == 2:
        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

    return Image.fromarray(img)
# example:
img_padding = pad(img, (10, 20, 30, 40), fill=128)  # (750, 815) -> (790, 875)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_padding, "padding img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 2, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.7 crop(img, i, j, h, w)
裁剪给定的 PIL Image.
def crop(img, i, j, h, w):
    """Cut a rectangular region out of a PIL Image.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.

    Returns:
        PIL Image: Cropped image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # PIL boxes are (left, upper, right, lower)
    left, upper = j, i
    right, lower = j + w, i + h
    return img.crop((left, upper, right, lower))
# example
img_crop = crop(img, 100, 100, 500, 500)  # (750, 815) -> (500, 500)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_crop, "crop img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 2, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.8 center_crop(img, output_size)
def center_crop(img, output_size):
    """Crop the central region of an image.

    Args:
        img: Image exposing ``size`` as (width, height); it is passed on to ``crop``.
        output_size (sequence or number): (h, w) of the crop. A single number
            is truncated to int and used for both sides.

    Returns:
        Whatever ``crop`` returns for the centred box.
    """
    if isinstance(output_size, numbers.Number):
        side = int(output_size)
        output_size = (side, side)
    img_w, img_h = img.size
    target_h, target_w = output_size
    # offsets that centre the target box inside the image
    top = int(round((img_h - target_h) / 2.))
    left = int(round((img_w - target_w) / 2.))
    return crop(img, top, left, target_h, target_w)
# example
img_centercrop = center_crop(img, (256, 256))  # (750, 815) -> (256, 256)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_centercrop, "centercrop img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 2, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.9 resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR)
对给定 PIL Image 进行裁剪,并 resize 到特定尺寸.
def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR):
    """Crop a PIL Image and resize the crop to ``size``.

    Args:
        img (PIL Image): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.

    Returns:
        PIL Image: Cropped and resized image.
    """
    assert _is_pil_image(img), 'img should be PIL Image'
    cropped = crop(img, i, j, h, w)
    return resize(cropped, size, interpolation)
# example
img_resizedcrop = resized_crop(img, 100, 100, 500, 500, (256, 256))
# (750, 815) -> (500, 500) -> (256, 256)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_resizedcrop, "resizedcrop img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 2, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.10 hflip(img)
水平翻转 (Horizontally flip) 给定的 PIL Image.
def hflip(img):
    """Mirror a PIL Image left-to-right.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Horizontally flipped image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
2.11 vflip(img)
垂直翻转 (Vertically flip) 给定的 PIL Image.
def vflip(img):
    """Mirror a PIL Image top-to-bottom.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Vertically flipped image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_TOP_BOTTOM)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
# example:
img_hflip = hflip(img)
img_vflip = vflip(img)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_hflip, "hflip img"), (img_vflip, "vflip img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 3, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.12 five_crop(img, size)
Crop the given PIL Image into four corners and the central crop.
从给定 PIL Image 的四个角和中间裁剪出五个子图像.
def five_crop(img, size):
    """Crop an image into its four corners plus the central crop.

    Args:
        img: Image exposing ``size`` as (width, height) and a ``crop(box)`` method.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.

    Returns:
        tuple: tuple (tl, tr, bl, br, center)
            Corresponding top left, top right, bottom left, bottom right and
            center crop.

    Raises:
        ValueError: If the requested crop is larger than the image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    width, height = img.size
    crop_h, crop_w = size
    if crop_w > width or crop_h > height:
        raise ValueError("Requested crop size {} is bigger than input size {}".format(size, (height, width)))

    # boxes are (left, upper, right, lower), in tl, tr, bl, br order
    right_start = width - crop_w
    bottom_start = height - crop_h
    corner_boxes = (
        (0, 0, crop_w, crop_h),
        (right_start, 0, width, crop_h),
        (0, bottom_start, crop_w, height),
        (right_start, bottom_start, width, height),
    )
    corners = tuple(img.crop(box) for box in corner_boxes)
    return corners + (center_crop(img, (crop_h, crop_w)),)
# example:
img_tl, img_tr, img_bl, img_br, img_center = five_crop(img, (400, 400))

# vis: 2 x 3 grid, one subplot per (picture, title) pair
panels = [
    (img, "orig img"),
    (img_tl, "tl img"),
    (img_tr, "tr img"),
    (img_bl, "bl img"),
    (img_br, "br img"),
    (img_center, "center img"),
]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(2, 3, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.13 ten_crop(img, size, vertical_flip=False)
将给定 PIL Image 裁剪出的四个角和中间部分的五个子图像,每个子图像进行翻转处理. 默认时水平翻转.
def ten_crop(img, size, vertical_flip=False):
    """Return the five crops of an image followed by the five crops of its flipped copy.

    Args:
        img: Image handed to ``five_crop`` and to ``hflip`` / ``vflip``.
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
            Corresponding top left, top right, bottom left, bottom right and
            center crop and same for the flipped image.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    else:
        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

    flip = vflip if vertical_flip else hflip
    plain_crops = five_crop(img, size)
    flipped_crops = five_crop(flip(img), size)
    return plain_crops + flipped_crops
2.14 adjust_brightness(img, brightness_factor)
def adjust_brightness(img, brightness_factor):
    """Change the brightness of a PIL Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image, 2 increases the brightness by a factor of 2.

    Returns:
        PIL Image: Brightness adjusted image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Brightness(img).enhance(brightness_factor)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
# example:
img_adjust_brightness = adjust_brightness(img, 2.5)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_adjust_brightness, "adjust_brightness img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 2, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.15 adjust_contrast(img, contrast_factor)
调整对比度.
def adjust_contrast(img, contrast_factor):
    """Change the contrast of a PIL Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image, 2 increases the contrast by a factor of 2.

    Returns:
        PIL Image: Contrast adjusted image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Contrast(img).enhance(contrast_factor)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
# example:
img_adjust_contrast = adjust_contrast(img, 2.5)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_adjust_contrast, "adjust_contrast img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 2, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.16 adjust_saturation(img, saturation_factor)
调整颜色饱和度.
def adjust_saturation(img, saturation_factor):
    """Change the color saturation of a PIL Image.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image
            while 2 will enhance the saturation by a factor of 2.

    Returns:
        PIL Image: Saturation adjusted image.
    """
    if _is_pil_image(img):
        return ImageEnhance.Color(img).enhance(saturation_factor)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
# example
img_adjust_saturation = adjust_saturation(img, 2.5)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_adjust_saturation, "adjust_saturation img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 2, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.17 adjust_hue(img, hue_factor)
调整图像 HUE.
通过将图像转换为 HSV 空间,并周期地移动在 hue 通道(H) 的强度,以实现图像 hue 的调整.
最后,再将结果转换回原始的图像模式.参数 hue_factor - H 通道平移的因子,其值必须在区间 [-0.5, 0.5].
def adjust_hue(img, hue_factor):
    """Shift the hue of a PIL Image.

    The image is converted to HSV, the H channel is shifted cyclically, and
    the result is converted back to the input mode.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image: Hue adjusted image. Images in mode 'L', '1', 'I' or 'F'
        are returned unchanged.

    Raises:
        ValueError: If ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: If ``img`` is not a PIL Image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        # include the offending value (the old message called .format on a
        # string without a placeholder, so the value was silently dropped)
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # Reduce the shift modulo 256 with integer arithmetic before building the
    # uint8: casting a negative float straight to uint8 is platform-dependent.
    shift = np.uint8(int(hue_factor * 255) % 256)
    # uint8 addition take cares of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += shift
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img
# example:
img_adjust_hue = adjust_hue(img, 0.5)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_adjust_hue, "adjust_hue img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 2, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.18 adjust_gamma(img, gamma, gain=1)
对图像进行伽马校正(gamma correction). 也被叫作 Power Law Transform.
def adjust_gamma(img, gamma, gain=1):
    """Apply gamma correction (power-law transform) to a PIL Image.

    Each 8-bit level ``v`` is mapped to ``255 * gain * (v / 255) ** gamma``.

    Args:
        img (PIL Image): PIL Image to be adjusted.
        gamma (float): Non negative real number, the exponent of the
            transform. gamma larger than 1 make the shadows darker, while
            gamma smaller than 1 make dark regions lighter.
        gain (float): The constant multiplier.

    Returns:
        PIL Image: Gamma adjusted image, converted back to the input mode.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if gamma < 0:
        raise ValueError('Gamma should be a non-negative real number')

    input_mode = img.mode

    # one 256-entry lookup table, repeated for each of the three RGB bands
    single_band_table = []
    for level in range(256):
        single_band_table.append(255 * gain * pow(level / 255., gamma))
    lookup_table = single_band_table * 3

    # use PIL's point-function to accelerate this part
    corrected = img.convert('RGB').point(lookup_table)
    return corrected.convert(input_mode)
# example:
img_adjust_gamma = adjust_gamma(img, 0.5)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_adjust_gamma, "adjust_gamma img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 2, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.19 rotate(img, angle, resample=False, expand=False, center=None)
旋转图像.
参数 resample
可选值:PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC.
如果参数 resample
被忽略,或图像的模式是 1 或 P,则resample=PIL.Image.NEAREST.
参数 expand
如果 expand=True,则延展输出图像,以能包含旋转后的全部图像.
如果 expand=False 或被忽略,则保持输出图像与输入图像的尺寸一致.
expand 假设旋转是以中心进行旋转,且没有平移.
def rotate(img, angle, resample=False, expand=False, center=None):
    """Rotate a PIL Image by ``angle`` degrees.

    Args:
        img (PIL Image): PIL Image to be rotated.
        angle (float or int): Rotation angle in degrees, counter clockwise.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            Resampling filter handed to ``img.rotate``.
        expand (bool, optional): Optional expansion flag.
        center (2-tuple, optional): Optional center of rotation. Origin is
            the upper left corner. Default is the center of the image.

    Returns:
        The result of ``img.rotate``.
    """
    if _is_pil_image(img):
        return img.rotate(angle, resample, expand, center)
    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
# example:
img_rotate = rotate(img, 60)

# vis: one subplot per (picture, title) pair
panels = [(img, "orig img"), (img_rotate, "rotate img")]
for position, (picture, caption) in enumerate(panels, start=1):
    axis = plt.subplot(1, 2, position)
    axis.imshow(picture)
    axis.axis("off")
    axis.set_title(caption)
plt.show()
2.20 affine(img, angle, translate, scale, shear, resample=0, fillcolor=None)
保持图像中心不变,进行仿射变换.
def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
    """Compute the first two rows of the inverse affine matrix as a flat list of six values.

    Args:
        center: (cx, cy) point that the transform keeps fixed.
        angle: Rotation angle in degrees.
        translate: (tx, ty) translation.
        scale: Overall scale factor; its reciprocal is used for the inverse.
        shear: Shear angle in degrees.

    Returns:
        list: ``[a, b, c, d, e, f]``, later handed to ``img.transform`` with
        ``Image.AFFINE``.
    """
    # As it is explained in PIL.Image.rotate
    # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
    #       RSS is rotation with scale and shear matrix
    #       RSS(a, scale, shear) = [ cos(a)*scale    -sin(a + shear)*scale     0]
    #                              [ sin(a)*scale    cos(a + shear)*scale     0]
    #                              [     0                  0          1]
    # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
    angle = math.radians(angle)
    shear = math.radians(shear)
    scale = 1.0 / scale

    # Inverted rotation matrix with scale and shear
    d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
    matrix = [
        math.cos(angle + shear), math.sin(angle + shear), 0,
        -math.sin(angle), math.cos(angle), 0
    ]
    matrix = [scale / d * m for m in matrix]

    # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
    matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
    matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])

    # Apply center translation: C * RSS^-1 * C^-1 * T^-1
    matrix[2] += center[0]
    matrix[5] += center[1]
    return matrix


def _pillow_supports_fillcolor(version):
    """Return True when the Pillow version string ``version`` has a major number of 5 or more."""
    try:
        return int(str(version).split('.')[0]) >= 5
    except ValueError:
        # unparsable version string: do not pass fillcolor
        return False


def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
    """Apply an affine transformation that keeps the image center invariant.

    Args:
        img (PIL Image): PIL Image to be rotated.
        angle (float or int): rotation angle in degrees between -180 and 180,
            clockwise direction.
        translate (list or tuple of integers): horizontal and vertical
            translations (post-rotation translation)
        scale (float): overall scale
        shear (float): shear angle value in degrees between -180 to 180,
            clockwise direction.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            Resampling filter handed to ``img.transform``.
        fillcolor (int): Optional fill color for the area outside the
            transform in the output image. (Pillow>=5.0.0)

    Returns:
        The result of ``img.transform`` at the input size.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
        "Argument translate should be a list or tuple of length 2"

    assert scale > 0.0, "Argument scale should be positive"

    output_size = img.size
    center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
    # fillcolor needs Pillow >= 5.0.0. Compare the parsed major version: the
    # old test on the first character matched 5.x only, so fillcolor was
    # silently dropped on 6.x and later (including 10.x).
    kwargs = {"fillcolor": fillcolor} if _pillow_supports_fillcolor(PILLOW_VERSION) else {}
    return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)
2.21 to_grayscale(img, num_output_channels=1)
将图像转换为灰度图.
def to_grayscale(img, num_output_channels=1):
    """Convert a PIL Image to grayscale.

    Args:
        img (PIL Image): Image to be converted to grayscale.
        num_output_channels (int): 1 or 3, number of channels in the result.

    Returns:
        PIL Image: Grayscale version of the image.
            if num_output_channels = 1 : returned image is single channel
            if num_output_channels = 3 : returned image is 3 channel with r = g = b
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    if num_output_channels not in (1, 3):
        raise ValueError('num_output_channels should be either 1 or 3')

    gray = img.convert('L')
    if num_output_channels == 1:
        return gray

    # replicate the single luminance plane into the R, G and B channels
    plane = np.array(gray, dtype=np.uint8)
    stacked = np.dstack([plane, plane, plane])
    return Image.fromarray(stacked, 'RGB')
参考链接
https://www.aiuai.cn/aifarm759.html
到此这篇关于Pytorch 图像变换函数集合小结的文章就介绍到这了,更多相关Pytorch 图像变换函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!