forked from microsoft/Bringing-Old-Photos-Back-to-Life
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
242 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
class BaseDataLoader(): | ||
def __init__(self): | ||
pass | ||
|
||
def initialize(self, opt): | ||
self.opt = opt | ||
pass | ||
|
||
def load_data(): | ||
return None | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
import torch.utils.data as data | ||
from PIL import Image | ||
import torchvision.transforms as transforms | ||
import numpy as np | ||
import random | ||
|
||
class BaseDataset(data.Dataset): | ||
def __init__(self): | ||
super(BaseDataset, self).__init__() | ||
|
||
def name(self): | ||
return 'BaseDataset' | ||
|
||
def initialize(self, opt): | ||
pass | ||
|
||
def get_params(opt, size): | ||
w, h = size | ||
new_h = h | ||
new_w = w | ||
if opt.resize_or_crop == 'resize_and_crop': | ||
new_h = new_w = opt.loadSize | ||
|
||
if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256 | ||
|
||
if w<h: | ||
new_w = opt.loadSize | ||
new_h = opt.loadSize * h // w | ||
else: | ||
new_h=opt.loadSize | ||
new_w = opt.loadSize * w // h | ||
|
||
if opt.resize_or_crop=='crop_only': | ||
pass | ||
|
||
|
||
x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) | ||
y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) | ||
|
||
flip = random.random() > 0.5 | ||
return {'crop_pos': (x, y), 'flip': flip} | ||
|
||
def get_transform(opt, params, method=Image.BICUBIC, normalize=True): | ||
transform_list = [] | ||
if 'resize' in opt.resize_or_crop: | ||
osize = [opt.loadSize, opt.loadSize] | ||
transform_list.append(transforms.Scale(osize, method)) | ||
elif 'scale_width' in opt.resize_or_crop: | ||
# transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) ## Here , We want the shorter side to match 256, and Scale will finish it. | ||
transform_list.append(transforms.Scale(256,method)) | ||
|
||
if 'crop' in opt.resize_or_crop: | ||
if opt.isTrain: | ||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) | ||
else: | ||
if opt.test_random_crop: | ||
transform_list.append(transforms.RandomCrop(opt.fineSize)) | ||
else: | ||
transform_list.append(transforms.CenterCrop(opt.fineSize)) | ||
|
||
## when testing, for ablation study, choose center_crop directly. | ||
|
||
|
||
|
||
if opt.resize_or_crop == 'none': | ||
base = float(2 ** opt.n_downsample_global) | ||
if opt.netG == 'local': | ||
base *= (2 ** opt.n_local_enhancers) | ||
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) | ||
|
||
if opt.isTrain and not opt.no_flip: | ||
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) | ||
|
||
transform_list += [transforms.ToTensor()] | ||
|
||
if normalize: | ||
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), | ||
(0.5, 0.5, 0.5))] | ||
return transforms.Compose(transform_list) | ||
|
||
def normalize(): | ||
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
|
||
def __make_power_2(img, base, method=Image.BICUBIC): | ||
ow, oh = img.size | ||
h = int(round(oh / base) * base) | ||
w = int(round(ow / base) * base) | ||
if (h == oh) and (w == ow): | ||
return img | ||
return img.resize((w, h), method) | ||
|
||
def __scale_width(img, target_width, method=Image.BICUBIC): | ||
ow, oh = img.size | ||
if (ow == target_width): | ||
return img | ||
w = target_width | ||
h = int(target_width * oh / ow) | ||
return img.resize((w, h), method) | ||
|
||
def __crop(img, pos, size): | ||
ow, oh = img.size | ||
x1, y1 = pos | ||
tw = th = size | ||
if (ow > tw or oh > th): | ||
return img.crop((x1, y1, x1 + tw, y1 + th)) | ||
return img | ||
|
||
def __flip(img, flip): | ||
if flip: | ||
return img.transpose(Image.FLIP_LEFT_RIGHT) | ||
return img |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
import torch.utils.data | ||
import random | ||
from data.base_data_loader import BaseDataLoader | ||
from data import online_dataset_for_old_photos as dts_ray_bigfile | ||
|
||
|
||
def CreateDataset(opt): | ||
dataset = None | ||
if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B': | ||
dataset = dts_ray_bigfile.UnPairOldPhotos_SR() | ||
if opt.training_dataset=='mapping': | ||
if opt.random_hole: | ||
dataset = dts_ray_bigfile.PairOldPhotos_with_hole() | ||
else: | ||
dataset = dts_ray_bigfile.PairOldPhotos() | ||
print("dataset [%s] was created" % (dataset.name())) | ||
dataset.initialize(opt) | ||
return dataset | ||
|
||
class CustomDatasetDataLoader(BaseDataLoader): | ||
def name(self): | ||
return 'CustomDatasetDataLoader' | ||
|
||
def initialize(self, opt): | ||
BaseDataLoader.initialize(self, opt) | ||
self.dataset = CreateDataset(opt) | ||
self.dataloader = torch.utils.data.DataLoader( | ||
self.dataset, | ||
batch_size=opt.batchSize, | ||
shuffle=not opt.serial_batches, | ||
num_workers=int(opt.nThreads), | ||
drop_last=True) | ||
|
||
def load_data(self): | ||
return self.dataloader | ||
|
||
def __len__(self): | ||
return min(len(self.dataset), self.opt.max_dataset_size) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
def CreateDataLoader(opt): | ||
from data.custom_dataset_data_loader import CustomDatasetDataLoader | ||
data_loader = CustomDatasetDataLoader() | ||
print(data_loader.name()) | ||
data_loader.initialize(opt) | ||
return data_loader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
import torch.utils.data as data | ||
from PIL import Image | ||
import os | ||
|
||
IMG_EXTENSIONS = [ | ||
'.jpg', '.JPG', '.jpeg', '.JPEG', | ||
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' | ||
] | ||
|
||
|
||
def is_image_file(filename): | ||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) | ||
|
||
|
||
def make_dataset(dir): | ||
images = [] | ||
assert os.path.isdir(dir), '%s is not a valid directory' % dir | ||
|
||
for root, _, fnames in sorted(os.walk(dir)): | ||
for fname in fnames: | ||
if is_image_file(fname): | ||
path = os.path.join(root, fname) | ||
images.append(path) | ||
|
||
return images | ||
|
||
|
||
def default_loader(path): | ||
return Image.open(path).convert('RGB') | ||
|
||
|
||
class ImageFolder(data.Dataset): | ||
|
||
def __init__(self, root, transform=None, return_paths=False, | ||
loader=default_loader): | ||
imgs = make_dataset(root) | ||
if len(imgs) == 0: | ||
raise(RuntimeError("Found 0 images in: " + root + "\n" | ||
"Supported image extensions are: " + | ||
",".join(IMG_EXTENSIONS))) | ||
|
||
self.root = root | ||
self.imgs = imgs | ||
self.transform = transform | ||
self.return_paths = return_paths | ||
self.loader = loader | ||
|
||
def __getitem__(self, index): | ||
path = self.imgs[index] | ||
img = self.loader(path) | ||
if self.transform is not None: | ||
img = self.transform(img) | ||
if self.return_paths: | ||
return img, path | ||
else: | ||
return img | ||
|
||
def __len__(self): | ||
return len(self.imgs) |