-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
1,649 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
{ | ||
"seed": 2021, | ||
"save_dir": "checkpoints/", | ||
"data_loader": { | ||
"name": "YouTubeVOS", | ||
"data_root": "./data", | ||
"w": 432, | ||
"h": 240, | ||
"sample_length": 5 | ||
}, | ||
"losses": { | ||
"hole_weight": 1, | ||
"valid_weight": 1, | ||
"adversarial_weight": 0.01, | ||
"GAN_LOSS": "hinge" | ||
}, | ||
"model": { | ||
"net": "fuseformer", | ||
"no_dis": 0 | ||
}, | ||
"trainer": { | ||
"type": "Adam", | ||
"beta1": 0, | ||
"beta2": 0.99, | ||
"lr": 1e-4, | ||
"batch_size": 8, | ||
"num_workers": 2, | ||
"log_freq": 100, | ||
"save_freq": 1e4, | ||
"iterations": 50e4, | ||
"niter": 40e4 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import os | ||
import random | ||
import torch | ||
import numpy as np | ||
import torchvision.transforms.functional as F | ||
import torchvision.transforms as transforms | ||
from torch.utils.data import DataLoader | ||
from PIL import Image | ||
from core.utils import create_random_shape_with_random_motion | ||
from core.utils import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip | ||
|
||
|
||
class Dataset(torch.utils.data.Dataset): | ||
def __init__(self, args: dict, split='train'): | ||
self.args = args | ||
self.split = split | ||
self.sample_length = args['sample_length'] | ||
self.size = self.w, self.h = (args['w'], args['h']) | ||
|
||
if args['name'] == 'YouTubeVOS': | ||
vid_lst_prefix = os.path.join(args['data_root'], args['name'], split+'_all_frames/JPEGImages') | ||
vid_lst = os.listdir(vid_lst_prefix) | ||
self.video_names = [os.path.join(vid_lst_prefix, name) for name in vid_lst] | ||
|
||
self._to_tensors = transforms.Compose([ | ||
Stack(), | ||
ToTorchFormatTensor(), ]) | ||
|
||
def __len__(self): | ||
return len(self.video_names) | ||
|
||
def __getitem__(self, index): | ||
try: | ||
item = self.load_item(index) | ||
except: | ||
print('Loading error in video {}'.format(self.video_names[index])) | ||
item = self.load_item(0) | ||
return item | ||
|
||
def load_item(self, index): | ||
video_name = self.video_names[index] | ||
all_frames = [os.path.join(video_name, name) for name in sorted(os.listdir(video_name))] | ||
all_masks = create_random_shape_with_random_motion( | ||
len(all_frames), imageHeight=self.h, imageWidth=self.w) | ||
ref_index = get_ref_index(len(all_frames), self.sample_length) | ||
# read video frames | ||
frames = [] | ||
masks = [] | ||
for idx in ref_index: | ||
img = Image.open(all_frames[idx]).convert('RGB') | ||
img = img.resize(self.size) | ||
frames.append(img) | ||
masks.append(all_masks[idx]) | ||
if self.split == 'train': | ||
frames = GroupRandomHorizontalFlip()(frames) | ||
# To tensors | ||
frame_tensors = self._to_tensors(frames)*2.0 - 1.0 | ||
mask_tensors = self._to_tensors(masks) | ||
return frame_tensors, mask_tensors | ||
|
||
|
||
def get_ref_index(length, sample_length): | ||
if random.uniform(0, 1) > 0.5: | ||
ref_index = random.sample(range(length), sample_length) | ||
ref_index.sort() | ||
else: | ||
pivot = random.randint(0, length-sample_length) | ||
ref_index = [pivot+i for i in range(sample_length)] | ||
return ref_index |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import os | ||
import torch | ||
|
||
|
||
def get_world_size(): | ||
"""Find OMPI world size without calling mpi functions | ||
:rtype: int | ||
""" | ||
if os.environ.get('PMI_SIZE') is not None: | ||
return int(os.environ.get('PMI_SIZE') or 1) | ||
elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: | ||
return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) | ||
else: | ||
return torch.cuda.device_count() | ||
|
||
|
||
def get_global_rank(): | ||
"""Find OMPI world rank without calling mpi functions | ||
:rtype: int | ||
""" | ||
if os.environ.get('PMI_RANK') is not None: | ||
return int(os.environ.get('PMI_RANK') or 0) | ||
elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: | ||
return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) | ||
else: | ||
return 0 | ||
|
||
|
||
def get_local_rank(): | ||
"""Find OMPI local rank without calling mpi functions | ||
:rtype: int | ||
""" | ||
if os.environ.get('MPI_LOCALRANKID') is not None: | ||
return int(os.environ.get('MPI_LOCALRANKID') or 0) | ||
elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: | ||
return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) | ||
else: | ||
return 0 | ||
|
||
|
||
def get_master_ip(): | ||
if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: | ||
return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] | ||
elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: | ||
return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') | ||
else: | ||
return "127.0.0.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
class AdversarialLoss(nn.Module): | ||
r""" | ||
Adversarial loss | ||
https://arxiv.org/abs/1711.10337 | ||
""" | ||
|
||
def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0): | ||
r""" | ||
type = nsgan | lsgan | hinge | ||
""" | ||
super(AdversarialLoss, self).__init__() | ||
self.type = type | ||
self.register_buffer('real_label', torch.tensor(target_real_label)) | ||
self.register_buffer('fake_label', torch.tensor(target_fake_label)) | ||
|
||
if type == 'nsgan': | ||
self.criterion = nn.BCELoss() | ||
elif type == 'lsgan': | ||
self.criterion = nn.MSELoss() | ||
elif type == 'hinge': | ||
self.criterion = nn.ReLU() | ||
|
||
def __call__(self, outputs, is_real, is_disc=None): | ||
if self.type == 'hinge': | ||
if is_disc: | ||
if is_real: | ||
outputs = -outputs | ||
return self.criterion(1 + outputs).mean() | ||
else: | ||
return (-outputs).mean() | ||
else: | ||
labels = (self.real_label if is_real else self.fake_label).expand_as( | ||
outputs) | ||
loss = self.criterion(outputs, labels) | ||
return loss | ||
|
||
|
Oops, something went wrong.