training loader
raywzy committed Mar 28, 2021
1 parent 10f588e commit 57c54d1
Showing 5 changed files with 242 additions and 0 deletions.
16 changes: 16 additions & 0 deletions Global/data/base_data_loader.py
@@ -0,0 +1,16 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        return None
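BaseDataLoader only pins down the interface (initialize and load_data) that the training code relies on; a concrete subclass is expected to wrap an actual torch DataLoader, as CustomDatasetDataLoader does further down in this commit. A minimal illustrative subclass, with the class name and the dataset argument being hypothetical:

import torch.utils.data

from data.base_data_loader import BaseDataLoader

class SimpleDataLoader(BaseDataLoader):  # hypothetical subclass, for illustration only
    def initialize(self, opt, dataset):
        BaseDataLoader.initialize(self, opt)
        # Wrap any torch Dataset; the batch size is read from the options object.
        self.dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, shuffle=True)

    def load_data(self):
        return self.dataloader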



114 changes: 114 additions & 0 deletions Global/data/base_dataset.py
@@ -0,0 +1,114 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass

def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        new_h = new_w = opt.loadSize

    if opt.resize_or_crop == 'scale_width_and_crop':  # scale the shorter side to loadSize
        if w < h:
            new_w = opt.loadSize
            new_h = opt.loadSize * h // w
        else:
            new_h = opt.loadSize
            new_w = opt.loadSize * w // h

    if opt.resize_or_crop == 'crop_only':
        pass  # keep the original size; cropping is handled in get_transform

    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))

    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}

def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Scale(osize, method))
    elif 'scale_width' in opt.resize_or_crop:
        # Passing a single int to Scale resizes the shorter side to 256,
        # which is what scale_width is meant to do here.
        transform_list.append(transforms.Scale(256, method))

    if 'crop' in opt.resize_or_crop:
        if opt.isTrain:
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
        else:
            if opt.test_random_crop:
                transform_list.append(transforms.RandomCrop(opt.fineSize))
            else:
                # At test time, center crop is the default so that ablation
                # results are deterministic.
                transform_list.append(transforms.CenterCrop(opt.fineSize))

    if opt.resize_or_crop == 'none':
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def normalize():
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)

def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)

def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img

def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img
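Taken together, get_params draws one set of random augmentation decisions per image and get_transform turns them into a torchvision pipeline, so paired images can share the same crop and flip. A minimal usage sketch, assuming a torchvision old enough to still provide transforms.Scale (which this file uses) and an options object carrying the fields read above; the Namespace values and the image path are illustrative only:

from argparse import Namespace
from PIL import Image

from data.base_dataset import get_params, get_transform

# Hypothetical options; the real ones come from the project's option parser.
opt = Namespace(resize_or_crop='resize_and_crop', loadSize=286, fineSize=256,
                isTrain=True, no_flip=False)

img = Image.open('example.jpg').convert('RGB')   # placeholder path
params = get_params(opt, img.size)               # one crop position and flip decision
tensor = get_transform(opt, params)(img)         # 3 x fineSize x fineSize tensor in [-1, 1]

Reusing the same params for a second image (for example a paired ground-truth photo) keeps the two crops aligned.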
41 changes: 41 additions & 0 deletions Global/data/custom_dataset_data_loader.py
@@ -0,0 +1,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data
import random
from data.base_data_loader import BaseDataLoader
from data import online_dataset_for_old_photos as dts_ray_bigfile


def CreateDataset(opt):
    dataset = None
    if opt.training_dataset == 'domain_A' or opt.training_dataset == 'domain_B':
        dataset = dts_ray_bigfile.UnPairOldPhotos_SR()
    if opt.training_dataset == 'mapping':
        if opt.random_hole:
            dataset = dts_ray_bigfile.PairOldPhotos_with_hole()
        else:
            dataset = dts_ray_bigfile.PairOldPhotos()
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset

class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads),
            drop_last=True)

    def load_data(self):
        return self.dataloader

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)
9 changes: 9 additions & 0 deletions Global/data/data_loader.py
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader
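CreateDataLoader is the entry point a training script would call: it builds a CustomDatasetDataLoader, which in turn instantiates one of the dataset classes from data.online_dataset_for_old_photos (not part of this commit). A hedged sketch of how it would be driven, assuming opt is the parsed training options (pix2pixHD-style) carrying the fields the loader reads:

from data.data_loader import CreateDataLoader

# `opt` is assumed to come from the project's training option parser; the
# loader code reads training_dataset, random_hole, batchSize, serial_batches,
# nThreads and max_dataset_size from it.
data_loader = CreateDataLoader(opt)        # prints the loader and dataset names
dataset = data_loader.load_data()          # the wrapped torch DataLoader
print('#training images = %d' % len(data_loader))

for i, data in enumerate(dataset):
    # each `data` is one batch produced by the chosen dataset class
    pass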
62 changes: 62 additions & 0 deletions Global/data/image_folder.py
@@ -0,0 +1,62 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import os

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
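ImageFolder here is a flat variant of torchvision's ImageFolder: it recursively collects image paths without requiring class subdirectories and therefore returns no labels. A small sketch of wrapping it in a DataLoader; the directory path and transform are illustrative:

import torch.utils.data
import torchvision.transforms as transforms

from data.image_folder import ImageFolder

# Hypothetical image directory; any nested layout works.
dataset = ImageFolder('/path/to/images',
                      transform=transforms.Compose([
                          transforms.Resize((256, 256)),  # fixed size so images batch cleanly
                          transforms.ToTensor(),
                      ]),
                      return_paths=True)

loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
for imgs, paths in loader:
    # imgs: (batch, 3, 256, 256) float tensor; paths: tuple of source file paths
    break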
