import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from skimage import io
import PIL
import os
import mimetypes
import torchvision.transforms as transforms
import glob
from skimage.io import imread
from natsort import natsorted
import re
import numba
from fastai2.vision.all import *
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage
pytorch unfold & fold
Using pytorch's unfold and fold to construct sliding windows manually.
from pdb import set_trace
tensor.unfold
x = torch.arange(48).view(3, 4, 4)
x.shape  # x.view(8,8)
x
print('test')
x.unfold(0, 2, 1).shape
x.unfold(0, 2, 1)
print('exp1')
x.unfold(0, 3, 3).shape
x.unfold(0, 3, 3)
print('exp2')
x.unfold(0, 3, 3).unfold(1, 2, 2).shape
x.unfold(0, 3, 3).unfold(1, 2, 2)
print('exp3')
x.unfold(0, 3, 3).unfold(1, 2, 2).unfold(2, 2, 2).shape
x.unfold(0, 3, 3).unfold(1, 2, 2).unfold(2, 2, 2)
torch.Size([3, 4, 4])
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]],
[[32, 33, 34, 35],
[36, 37, 38, 39],
[40, 41, 42, 43],
[44, 45, 46, 47]]])
test
torch.Size([2, 4, 4, 2])
tensor([[[[ 0, 16],
[ 1, 17],
[ 2, 18],
[ 3, 19]],
[[ 4, 20],
[ 5, 21],
[ 6, 22],
[ 7, 23]],
[[ 8, 24],
[ 9, 25],
[10, 26],
[11, 27]],
[[12, 28],
[13, 29],
[14, 30],
[15, 31]]],
[[[16, 32],
[17, 33],
[18, 34],
[19, 35]],
[[20, 36],
[21, 37],
[22, 38],
[23, 39]],
[[24, 40],
[25, 41],
[26, 42],
[27, 43]],
[[28, 44],
[29, 45],
[30, 46],
[31, 47]]]])
exp1
torch.Size([1, 4, 4, 3])
tensor([[[[ 0, 16, 32],
[ 1, 17, 33],
[ 2, 18, 34],
[ 3, 19, 35]],
[[ 4, 20, 36],
[ 5, 21, 37],
[ 6, 22, 38],
[ 7, 23, 39]],
[[ 8, 24, 40],
[ 9, 25, 41],
[10, 26, 42],
[11, 27, 43]],
[[12, 28, 44],
[13, 29, 45],
[14, 30, 46],
[15, 31, 47]]]])
exp2
torch.Size([1, 2, 4, 3, 2])
tensor([[[[[ 0, 4],
[16, 20],
[32, 36]],
[[ 1, 5],
[17, 21],
[33, 37]],
[[ 2, 6],
[18, 22],
[34, 38]],
[[ 3, 7],
[19, 23],
[35, 39]]],
[[[ 8, 12],
[24, 28],
[40, 44]],
[[ 9, 13],
[25, 29],
[41, 45]],
[[10, 14],
[26, 30],
[42, 46]],
[[11, 15],
[27, 31],
[43, 47]]]]])
exp3
torch.Size([1, 2, 2, 3, 2, 2])
tensor([[[[[[ 0, 1],
[ 4, 5]],
[[16, 17],
[20, 21]],
[[32, 33],
[36, 37]]],
[[[ 2, 3],
[ 6, 7]],
[[18, 19],
[22, 23]],
[[34, 35],
[38, 39]]]],
[[[[ 8, 9],
[12, 13]],
[[24, 25],
[28, 29]],
[[40, 41],
[44, 45]]],
[[[10, 11],
[14, 15]],
[[26, 27],
[30, 31]],
[[42, 43],
[46, 47]]]]]])
temp = torch.randint(0, 10, (3, 5176, 3793))
temp.shape
patches = temp.unfold(0, 3, 3)
patches.shape
test_eq(temp.unfold(0, 3, 3), temp.unfold(0, 3, 4))
patches = patches.unfold(1, 128, 128)
patches.shape
patches = patches.unfold(2, 128, 128)
# test_eq(temp.unfold(0,3,3),temp.unfold(0,3,66))
patches.shape
torch.Size([3, 5176, 3793])
torch.Size([1, 5176, 3793, 3])
torch.Size([1, 40, 3793, 3, 128])
torch.Size([1, 40, 29, 3, 128, 128])
math.floor((5176-128)/128)+1
40
math.floor((3793-128)/128)+1
29
tensor.unfold.rules
important
e.g. for x with x.shape == (a, b), x.unfold(c, d, e) takes d as the window size and e as the step.
After unfold, the size at dimension c becomes math.floor((a - d) / e) + 1 (shown here for c == 0), while the other dimensions keep their sizes, so the result has shape:
(math.floor((a - d) / e) + 1, b, d)
BTW: the last entry is the window size d, which unfold appends as a new trailing dimension.
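A minimal numeric check of this rule (the toy shapes below are illustrative, not from the experiments above):
import math
a, b = 10, 7
d, e = 4, 2  # window size and step
x2 = torch.randn(a, b)
x2.unfold(0, d, e).shape  # torch.Size([4, 7, 4])
(math.floor((a - d) / e) + 1, b, d)  # (4, 7, 4), matching the rule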
torch.nn.unfold and fold
unfold https://pytorch.org/docs/master/generated/torch.nn.Unfold.html#torch.nn.Unfold
inp = torch.randn(1, 3, 10, 12)
w = torch.randn(2, 3, 4, 5)
inp_unf = torch.nn.functional.unfold(inp, (4, 5))
inp_unf.shape
torch.Size([1, 60, 56])
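To see why this [N, C*kh*kw, L] layout is useful, here is the identity from the torch.nn.functional.unfold docs, continuing the tensors above: a convolution is just unfold, a matrix multiply, then a reshape.
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
out = out_unf.view(1, 2, 7, 8)  # 7 = 10-4+1 and 8 = 12-5+1 output positions
(torch.nn.functional.conv2d(inp, w) - out).abs().max()  # ~0: matches a direct convolution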
fold https://pytorch.org/docs/master/generated/torch.nn.Fold.html?highlight=fold#torch.nn.Fold
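One property worth verifying before the image experiment (a small sketch, not in the original post): with overlapping windows, fold sums the overlaps, so fold(unfold(z)) differs from z there; dividing by fold(unfold(ones)) restores z. The reconstruction code further down relies on exactly this normalization.
z = torch.arange(16.).view(1, 1, 4, 4)
unfold_op = torch.nn.Unfold(kernel_size=(2, 2), stride=1)  # stride < kernel -> overlaps
fold_op = torch.nn.Fold(output_size=(4, 4), kernel_size=(2, 2), stride=1)
torch.allclose(fold_op(unfold_op(z)) / fold_op(unfold_op(torch.ones_like(z))), z)  # True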
experiment on an Image
# !wget https://eoimages.gsfc.nasa.gov/images/imagerecords/88000/88094/niobrara_photo_lrg.jpg
patch_size = 512
stride = patch_size
pil2tensor = transforms.ToTensor()
file = Path('niobrara_photo_lrg.jpg')
filename = file.stem
im1 = Image.open(file)
print(im1.shape)
# im1.resize(5120,5120)
im1 = im1.resize((1500, 1500), Image.BILINEAR)
im1
rgb_image = pil2tensor(im1)
rgb_image.shape
(1536, 2048)
torch.Size([3, 1500, 1500])
rgb_image.data.type()
'torch.FloatTensor'
tensor.unfold
patches = rgb_image.data.unfold(0, 3, 3).unfold(1, patch_size, stride).unfold(2, patch_size, stride)
print(patches.shape)
torch.Size([1, 2, 2, 3, 512, 512])
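This 2x2 patch grid agrees with the rule derived earlier (a quick check added here):
math.floor((1500 - patch_size) / stride) + 1  # 2 windows along each spatial dimension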
https://pytorch.org/docs/master/generated/torch.split.html
a = list(patches.shape)
a
torch.from_numpy(np.arange(0, a[1]))
patches[:, torch.from_numpy(np.arange(0, a[1])), :, :, :, :].shape
x = patches[:, torch.from_numpy(np.arange(0, a[1])), :, :, :, :].split(1, dim=1)
x = patches.split(1, dim=1)
# x = patches.split(1, dim=2)
len(x)
x[0].shape
x[1].shape
[1, 2, 2, 3, 512, 512]
tensor([0, 1])
torch.Size([1, 2, 2, 3, 512, 512])
2
torch.Size([1, 1, 2, 3, 512, 512])
torch.Size([1, 1, 2, 3, 512, 512])
to_pil = ToPILImage()
math.floor(1500/512)
2
6000/512
11.71875
x = patches[:, torch.from_numpy(np.arange(0, a[1])), :, :, :, :].split(1, dim=1)
for i in list(np.arange(a[1])):
    y = x[i][:, :, torch.from_numpy(np.arange(0, a[2])), :, :, :].split(1, dim=2)
    for j in list(np.arange(a[2])):
        img = to_pil(y[j].squeeze(0).squeeze(0).squeeze(0))
        img
        # set_trace()
        # save_image(y[j], filename+'-'+str(i)+'-'+str(j)+'.png')
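An alternative sketch (added here, not in the original): since the patch grid is already materialized, reshape can flatten it in one step instead of the nested split() loops; reshape copies when the unfolded view is non-contiguous.
flat = patches.reshape(-1, 3, patch_size, patch_size)  # [4, 3, 512, 512]: one row per patch
for k in range(flat.size(0)):
    img = to_pil(flat[k])
    # img.save(f'{filename}-{k}.png')  # hypothetical output path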
nn.functional.unfold and fold to extract and reconstruct
https://discuss.pytorch.org/t/seemlessly-blending-tensors-together/65235/9
def split_tensor(tensor, tile_size=256):
    mask = torch.ones_like(tensor)
    # use torch.nn.Unfold
    stride = tile_size // 2
    unfold = nn.Unfold(kernel_size=(tile_size, tile_size), stride=stride)
    # Apply to mask and original image
    mask_p = unfold(mask)
    patches = unfold(tensor)

    patches = patches.reshape(3, tile_size, tile_size, -1).permute(3, 0, 1, 2)
    if tensor.is_cuda:
        patches_base = torch.zeros(patches.size(), device=tensor.get_device())
    else:
        patches_base = torch.zeros(patches.size())

    tiles = []
    for t in range(patches.size(0)):
        tiles.append(patches[[t], :, :, :])
    return tiles, mask_p, patches_base, (tensor.size(2), tensor.size(3))

def rebuild_tensor(tensor_list, mask_t, base_tensor, t_size, tile_size=256):
    stride = tile_size // 2
    # base_tensor here is used as a container
    for t, tile in enumerate(tensor_list):
        print(tile.size())
        base_tensor[[t], :, :] = tile
    base_tensor = base_tensor.permute(1, 2, 3, 0).reshape(3 * tile_size * tile_size, base_tensor.size(0)).unsqueeze(0)
    fold = nn.Fold(output_size=(t_size[0], t_size[1]), kernel_size=(tile_size, tile_size), stride=stride)
    # https://discuss.pytorch.org/t/seemlessly-blending-tensors-together/65235/2?u=bowenroom
    output_tensor = fold(base_tensor) / fold(mask_t)
    # output_tensor = fold(base_tensor)
    return output_tensor
# %%time
test_image = 'test_image.jpg'
image_size = 1024
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
input_tensor = Loader(Image.open(file).convert('RGB')).unsqueeze(0).cuda()

# Split image into overlapping tiles
tile_tensors, mask_t, base_tensor, t_size = split_tensor(input_tensor, 660)

# Put tiles back together
output_tensor = rebuild_tensor(tile_tensors, mask_t, base_tensor, t_size, 660)

# Save Output
Image2PIL = transforms.ToPILImage()
print(f'the whole length of the patches is {len(tile_tensors)}')
# show small patches
for i in range(len(tile_tensors)):
    print(f'the current is {i}')
    Image2PIL(tile_tensors[i].cpu().squeeze(0))
print('the reconstruct image')
Image2PIL(output_tensor.cpu().squeeze(0))
# Image2PIL(output_tensor.cpu().squeeze(0)).save('output_image.png')
torch.Size([1, 3, 660, 660])
torch.Size([1, 3, 660, 660])
torch.Size([1, 3, 660, 660])
torch.Size([1, 3, 660, 660])
torch.Size([1, 3, 660, 660])
torch.Size([1, 3, 660, 660])
the whole length of the patches is 6
the current is 0
the current is 1
the current is 2
the current is 3
the current is 4
the current is 5
the reconstruct image
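A hedged sanity check (added, not in the original post): wherever the overlapping tiles cover the input, fold(base)/fold(mask) reproduces it exactly; pixels the tiling never reaches have mask 0 and come out NaN after the division, so compare only the covered region.
covered = ~torch.isnan(output_tensor)
(output_tensor[covered] - input_tensor[covered]).abs().max()  # ~0 up to floating point error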
fastai2.PILImage and PIL.Image
len(tile_tensors)
tile_tensors[0].size()
6
torch.Size([1, 3, 660, 660])
tile_tensors[0].squeeze(0)
tensor([[[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
...,
[0.5647, 0.5451, 0.4902, ..., 0.7490, 0.7569, 0.7412],
[0.5490, 0.5529, 0.4863, ..., 0.7412, 0.7490, 0.7412],
[0.5608, 0.5686, 0.5059, ..., 0.7412, 0.7569, 0.7529]],
[[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
...,
[0.6667, 0.6549, 0.6196, ..., 0.7647, 0.7686, 0.7490],
[0.6431, 0.6510, 0.6039, ..., 0.7451, 0.7529, 0.7412],
[0.6471, 0.6549, 0.6078, ..., 0.7333, 0.7451, 0.7412]],
[[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
...,
[0.3765, 0.3647, 0.3176, ..., 0.8000, 0.8039, 0.7882],
[0.3412, 0.3569, 0.3098, ..., 0.7843, 0.7882, 0.7804],
[0.3373, 0.3608, 0.3255, ..., 0.7765, 0.7882, 0.7843]]],
device='cuda:0')
temp = PILImage(Image2PIL(tile_tensors[0].cpu().squeeze(0)))
temp
temp.shape
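A closing note (hedged, based on fastai2's design rather than this post): PILImage is a thin subclass of PIL.Image.Image, so every PIL method still works on it, while fastai2 patches in tensor-style accessors such as .shape, reported as (height, width).
isinstance(temp, Image.Image)  # True: still a plain PIL image underneath
isinstance(temp, PILImage)     # True: fastai2's subclass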