pytorch中的nn.Unfold()函数和fold()函数解读
作者:halo_wm
pytorch的nn.Unfold()函数和fold()函数
1.nn.Unfold()函数
描述:pytorch中的nn.Unfold()函数,在图像处理领域,经常需要用到卷积操作,但是有时我们只需要在图片上进行滑动的窗口操作,将图片切割成patch,而不需要进行卷积核和图片值的卷积乘法操作。
这是就需要用到nn.Unfold()函数,该函数是从一个batch图片中,提取出滑动的局部区域块,也就是卷积操作中的提取kernel filter对应的滑动窗口。
torch.nn.Unfold(kernel_size,dilation=1,paddding=0,stride=1)
该函数的输入是(bs,c,h,w),其中bs为batch-size,C是channel的个数。
而该函数的输出是(bs,Cxkernel_size[0]xkernel_size[1],L)其中L是特征图或者图片的尺寸根据kernel_size的长宽滑动裁剪后得到的多个patch的数量。
import torch.nn as nn import torch batches_img=torch.rand(1,2,4,4)#模拟图片数据(bs,2,4,4),通道数C为2 print("batches_img:\n",batches_img) nn_Unfold=nn.Unfold(kernel_size=(2,2),dilation=1,padding=0,stride=2) patche_img=nn_Unfold(batches_img) print("patche_img.shape:",patche_img.shape) print("patch_img:\n",patche_img)
该方法的主要应用场景是将图片切割成不同的patch,配合一下代码实现
#上面的代码能够获取到patch_img,(bs,C*K*K,L),L代表的是将每张图片分割成多少块 reshape_patche_img=patche_img.view(batches_img.shape[0],batches_img.shape[1],2,2,-1) print(reshape_patche_img.shape)#[bs, C, k, k, L] reshape_patche_img=reshape_patche_img.permute(0,4,1,2,3)#[N, L, C, k, k] print(reshape_patche_img.shape)
结果:
2.nn.Fold()函数
该函数是nn.Unfold()函数的逆操作。
fold = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=2) inputs_restore = fold(patches) print(inputs_restore) print(inputs_restore.size())
nn.functional.fold/unfold
作用:fold和unfold的作用恰好相反,unfold是用一个滑窗来提取图像中的像素值,类似于卷积操作,但是只提取不计算,fold恰好相反将滑窗提取的值返回为一个图像
nn.functional.unfold(input, kernel_size, dilation=1, padding=0, stride=1)
- input: 输入tensor
- kernel_size: 提取时的滑窗大小
- dilation: 滑窗是否有空洞
- padding: 是否对原图进行填充
- stride: 滑窗移动的步长
下面举一个例子直观解释
x = torch.Tensor([[[[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [ 13, 14, 15, 16]]]]) x = F.unfold(x, kernel_size=(2, 2), padding=0, stride=2) print(x) print(x.size()) ''' tensor([[[ 1., 3., 9., 11.], [ 2., 4., 10., 12.], [ 5., 7., 13., 15.], [ 6., 8., 14., 16.]]]) torch.Size([1, 4, 4]) '''
执行过程很简单,用一个 2 × 2 的窗在图上滑动,步长为2,第一次覆盖的内容为1256,第二次为3478,以此类推,每次滑窗的结果用一个列向量表示,列数就是滑窗提取的次数。
如果我们要得到每次滑窗的结果,例如第一次提取的结果,用表达式x[:,:,0]
即可
nn.functional.fold(input, output_size, kernel_size, dilation=1, padding=0, stride=1)
- input: 输入tensor
- output_size: 输出图像的大小(必须指定)
- kernel_size: 在图像中填充的形状
- dilation: 滑窗是否有空洞
- padding: 是否对原图进行填充
- stride: 存放窗的tensor时移动的步长
网上很少有讲这个函数的,都说是unfold的逆过程,我们依然用几个例子来对其进行详细的解释
1. 第一个例子
x = torch.Tensor([[[[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [ 13, 14, 15, 16]]]]) x = F.unfold(x, kernel_size=(2, 2), padding=0, stride=2) x = F.fold(x, output_size=(4,4), kernel_size=(2,2), padding=0, stride=2) ''' tensor([[[ 1., 3., 9., 11.], [ 2., 4., 10., 12.], [ 5., 7., 13., 15.], [ 6., 8., 14., 16.]]]) torch.Size([1, 4, 4]) tensor([[[[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 9., 10., 11., 12.], [13., 14., 15., 16.]]]]) torch.Size([1, 1, 4, 4]) '''
fold函数是如何执行的呢,他会提取unfold函数的每一列,首先提取1256这一列,然后根据kernel_size的大小将1256重新resize并填到output的第一个位置,如下
[[ 1., 2., 0., 0.], [ 5., 6., 0., 0.], [ 0., 0., 0., 0.], [ 0., 0., 0., 0.]]
随后提取第二列2.6.10.14,resize为 2 × 2 2\times 2 2×2的形状,根据步长为2添加到output的下一个位置,并以此类推
[[ 1., 2., 3., 4.], [ 5., 6., 7., 8.], [ 0., 0., 0., 0.], [ 0., 0., 0., 0.]]
注意:output,kernel以及stride必须满足一定的关系(参考文档)
知道原理以后我们可以自由操作上述tensor,但是注意,如果步长等设置不合适的话,最后的结果是有overlap的,下面我们展示两个例子
2. 第二个例子
自由操作tensor
x = torch.Tensor([[[[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [ 13, 14, 15, 16]]]]) x = F.unfold(x, kernel_size=(2, 2), padding=0, stride=2) x = F.fold(x, output_size=(4,4), kernel_size=(4,1), padding=0, stride=1) ''' tensor([[[ 1., 3., 9., 11.], [ 2., 4., 10., 12.], [ 5., 7., 13., 15.], [ 6., 8., 14., 16.]]]) torch.Size([1, 4, 4]) # tensor又变回了原来的样子 tensor([[[[ 1., 3., 9., 11.], [ 2., 4., 10., 12.], [ 5., 7., 13., 15.], [ 6., 8., 14., 16.]]]]) torch.Size([1, 1, 4, 4]) '''
overlap的情况
根据上述讲的可以自己推一下
x = torch.Tensor([[[[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [ 13, 14, 15, 16]]]]) x = F.unfold(x, kernel_size=(2, 2), padding=0, stride=2) x = F.fold(x, output_size=(3,3), kernel_size=(2,2), padding=0, stride=1) ''' tensor([[[ 1., 3., 9., 11.], [ 2., 4., 10., 12.], [ 5., 7., 13., 15.], [ 6., 8., 14., 16.]]]) torch.Size([1, 4, 4]) tensor([[[[ 1., 5., 4.], [14., 34., 20.], [13., 29., 16.]]]]) torch.Size([1, 1, 3, 3]) '''
3. kernel size小于列向量的情况
上面讲了,fold每次都会对列向量进行提取,之前的例子都是kernel size等于列向量,如果我们的kernel size小于列向量就会出现以下情况
x = torch.Tensor([[[[ 1, 2, 3, 4], [ 5, 6, 7, 8], [ 9, 10, 11, 12], [ 13, 14, 15, 16]]]]) x = F.unfold(x, kernel_size=(2, 2), padding=0, stride=2) x = F.fold(x, output_size=(2,2), kernel_size=(1,1), padding=0, stride=1) ''' tensor([[[ 1., 3., 9., 11.], [ 2., 4., 10., 12.], [ 5., 7., 13., 15.], [ 6., 8., 14., 16.]]]) torch.Size([1, 4, 4]) tensor([[[[ 1., 3.], [ 9., 11.]], [[ 2., 4.], [10., 12.]], [[ 5., 7.], [13., 15.]], [[ 6., 8.], [14., 16.]]]]) torch.Size([1, 4, 2, 2]) '''
解释一下,我们第一次提取的应该是1256,但是由于我们的kernel太小了, 1 × 1 = 1 ,只能提取一个元素,因此就是1,我们的output size是 2 × 2 ,步长为1,所以第一次提取的结果如下
[[ 1, 0], [ 0, 0]]
第二次提取时,就需要移动了,提取的不是列向量中的2,而是横向移动的3,接着放到刚才那个元素后面
[[ 1, 3], [ 0, 0]]
之后的过程以此类推,直到我们提取到11,这时我们的行向量提取完了,但是列向量没有,所以我们从第二列开始重复刚才的过程即可,可以看到最终我们输出向量大小为[1,4,2,2]
,4就是我们提取了4次行向量,两个2就是每次提取的大小(即output size)
最后加一个复杂的具有padding的例子
padding就是在对tensor进行操作之前在tensor四周补0或其他的值。例子中仅对unfold进行padding,如果对fold进行padding也同理
x = torch.Tensor([[[[ 1, 2, 3, 4, 5, 6, 7, 8], [ 9, 10, 11, 12, 13, 14, 15, 16], [ 17, 18, 19, 20, 21, 22, 23, 24], [ 25, 26, 27, 28, 29, 30, 31, 32], [ 33, 34, 35, 36, 37, 38, 39, 40], [ 41, 42, 43, 44, 45, 46, 47, 48], [ 49, 50, 51, 52, 53, 54, 55, 56], [ 57, 58, 59, 60, 61, 62, 63, 64]]]]) x = F.unfold(x, kernel_size=(6,6), padding=1, stride=4) x = F.fold(x, output_size=(12,12), kernel_size=(6,6), padding=0, stride=6) ''' tensor([[[ 0., 0., 0., 28.], [ 0., 0., 25., 29.], [ 0., 0., 26., 30.], [ 0., 0., 27., 31.], [ 0., 0., 28., 32.], [ 0., 0., 29., 0.], [ 0., 4., 0., 36.], [ 1., 5., 33., 37.], [ 2., 6., 34., 38.], [ 3., 7., 35., 39.], [ 4., 8., 36., 40.], [ 5., 0., 37., 0.], [ 0., 12., 0., 44.], [ 9., 13., 41., 45.], [10., 14., 42., 46.], [11., 15., 43., 47.], [12., 16., 44., 48.], [13., 0., 45., 0.], [ 0., 20., 0., 52.], [17., 21., 49., 53.], [18., 22., 50., 54.], [19., 23., 51., 55.], [20., 24., 52., 56.], [21., 0., 53., 0.], [ 0., 28., 0., 60.], [25., 29., 57., 61.], [26., 30., 58., 62.], [27., 31., 59., 63.], [28., 32., 60., 64.], [29., 0., 61., 0.], [ 0., 36., 0., 0.], [33., 37., 0., 0.], [34., 38., 0., 0.], [35., 39., 0., 0.], [36., 40., 0., 0.], [37., 0., 0., 0.]]]) torch.Size([1, 36, 4]) tensor([[[[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [ 0., 1., 2., 3., 4., 5., 4., 5., 6., 7., 8., 0.], [ 0., 9., 10., 11., 12., 13., 12., 13., 14., 15., 16., 0.], [ 0., 17., 18., 19., 20., 21., 20., 21., 22., 23., 24., 0.], [ 0., 25., 26., 27., 28., 29., 28., 29., 30., 31., 32., 0.], [ 0., 33., 34., 35., 36., 37., 36., 37., 38., 39., 40., 0.], [ 0., 25., 26., 27., 28., 29., 28., 29., 30., 31., 32., 0.], [ 0., 33., 34., 35., 36., 37., 36., 37., 38., 39., 40., 0.], [ 0., 41., 42., 43., 44., 45., 44., 45., 46., 47., 48., 0.], [ 0., 49., 50., 51., 52., 53., 52., 53., 54., 55., 56., 0.], [ 0., 57., 58., 59., 60., 61., 60., 61., 62., 63., 64., 0.], [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]]) torch.Size([1, 1, 12, 12]) '''
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。