From 77aea806f4f68da1c439177148c16edd15a7595e Mon Sep 17 00:00:00 2001 From: ruiliu-ai Date: Wed, 8 Sep 2021 12:05:09 +0800 Subject: [PATCH] initial version --- README.md | 52 +++++ configs/youtube-vos.json | 33 ++++ core/dataset.py | 69 +++++++ core/dist.py | 47 +++++ core/loss.py | 40 ++++ core/spectral_norm.py | 267 ++++++++++++++++++++++++++ core/trainer.py | 291 ++++++++++++++++++++++++++++ core/utils.py | 213 +++++++++++++++++++++ model/fuseformer.py | 402 +++++++++++++++++++++++++++++++++++++++ test.py | 160 ++++++++++++++++ train.py | 75 ++++++++ 11 files changed, 1649 insertions(+) create mode 100644 configs/youtube-vos.json create mode 100644 core/dataset.py create mode 100644 core/dist.py create mode 100644 core/loss.py create mode 100644 core/spectral_norm.py create mode 100644 core/trainer.py create mode 100644 core/utils.py create mode 100644 model/fuseformer.py create mode 100644 test.py create mode 100644 train.py diff --git a/README.md b/README.md index e647991..c2e9672 100644 --- a/README.md +++ b/README.md @@ -4,11 +4,58 @@ By [Rui Liu](https://ruiliu-ai.github.io), Hanming Deng, Yangyi Huang, Xiaoyu Sh This repo is the official Pytorch implementation of [FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting](https://arxiv.org/abs/2109.02974). +<<<<<<< HEAD +## Introduction + + +## Usage + +### Prerequisites +- Python >= 3.6 +- Pytorch >= 1.0 and corresponding torchvision (https://pytorch.org/) + +### Install +- Clone this repo: +``` +git clone https://github.com/ruiliu-ai/FuseFormer.git +``` +- Install other packages: +``` +cd FuseFormer +pip install -r requirements.txt +``` + +## Training + +### Dataset preparation +Download datasets ([YouTube-VOS](https://competitions.codalab.org/competitions/19544) and [DAVIS](https://davischallenge.org/davis2017/code.html)) into the data folder. +``` +mkdir data +``` + +### Training script +``` +python train.py -c configs/youtube-vos.json +``` + +## Test +Download [pre-trained model](https://drive.google.com/file/d/1BuSE42QAAUoQAJawbr5mMRXcqRRKeELc/view?usp=sharing) into checkpoints folder. +``` +mkdir checkpoints +``` + +### Test script +``` +python test.py -c checkpoints/fuseformer.pth -v data/DAVIS/JPEGImages/blackswan -m data/DAVIS/Annotations/blackswan +``` + +======= Coming soon. ## Introduction +>>>>>>> fab4dcbb9e27bc1ca819b1de0006611433f0965c ## Citing FuseFormer If you find FuseFormer useful in your research, please consider citing: ``` @@ -20,4 +67,9 @@ If you find FuseFormer useful in your research, please consider citing: } ``` +<<<<<<< HEAD +## Acknowledement +This code relies heavily on the video inpainting framework from [spatial-temporal transformer net](https://github.com/researchmm/STTN). 
+======= +>>>>>>> fab4dcbb9e27bc1ca819b1de0006611433f0965c diff --git a/configs/youtube-vos.json b/configs/youtube-vos.json new file mode 100644 index 0000000..92e28df --- /dev/null +++ b/configs/youtube-vos.json @@ -0,0 +1,33 @@ +{ + "seed": 2021, + "save_dir": "checkpoints/", + "data_loader": { + "name": "YouTubeVOS", + "data_root": "./data", + "w": 432, + "h": 240, + "sample_length": 5 + }, + "losses": { + "hole_weight": 1, + "valid_weight": 1, + "adversarial_weight": 0.01, + "GAN_LOSS": "hinge" + }, + "model": { + "net": "fuseformer", + "no_dis": 0 + }, + "trainer": { + "type": "Adam", + "beta1": 0, + "beta2": 0.99, + "lr": 1e-4, + "batch_size": 8, + "num_workers": 2, + "log_freq": 100, + "save_freq": 1e4, + "iterations": 50e4, + "niter": 40e4 + } +} diff --git a/core/dataset.py b/core/dataset.py new file mode 100644 index 0000000..11f6b8a --- /dev/null +++ b/core/dataset.py @@ -0,0 +1,69 @@ +import os +import random +import torch +import numpy as np +import torchvision.transforms.functional as F +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +from PIL import Image +from core.utils import create_random_shape_with_random_motion +from core.utils import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip + + +class Dataset(torch.utils.data.Dataset): + def __init__(self, args: dict, split='train'): + self.args = args + self.split = split + self.sample_length = args['sample_length'] + self.size = self.w, self.h = (args['w'], args['h']) + + if args['name'] == 'YouTubeVOS': + vid_lst_prefix = os.path.join(args['data_root'], args['name'], split+'_all_frames/JPEGImages') + vid_lst = os.listdir(vid_lst_prefix) + self.video_names = [os.path.join(vid_lst_prefix, name) for name in vid_lst] + + self._to_tensors = transforms.Compose([ + Stack(), + ToTorchFormatTensor(), ]) + + def __len__(self): + return len(self.video_names) + + def __getitem__(self, index): + try: + item = self.load_item(index) + except: + print('Loading error in video {}'.format(self.video_names[index])) + item = self.load_item(0) + return item + + def load_item(self, index): + video_name = self.video_names[index] + all_frames = [os.path.join(video_name, name) for name in sorted(os.listdir(video_name))] + all_masks = create_random_shape_with_random_motion( + len(all_frames), imageHeight=self.h, imageWidth=self.w) + ref_index = get_ref_index(len(all_frames), self.sample_length) + # read video frames + frames = [] + masks = [] + for idx in ref_index: + img = Image.open(all_frames[idx]).convert('RGB') + img = img.resize(self.size) + frames.append(img) + masks.append(all_masks[idx]) + if self.split == 'train': + frames = GroupRandomHorizontalFlip()(frames) + # To tensors + frame_tensors = self._to_tensors(frames)*2.0 - 1.0 + mask_tensors = self._to_tensors(masks) + return frame_tensors, mask_tensors + + +def get_ref_index(length, sample_length): + if random.uniform(0, 1) > 0.5: + ref_index = random.sample(range(length), sample_length) + ref_index.sort() + else: + pivot = random.randint(0, length-sample_length) + ref_index = [pivot+i for i in range(sample_length)] + return ref_index diff --git a/core/dist.py b/core/dist.py new file mode 100644 index 0000000..a66f6ea --- /dev/null +++ b/core/dist.py @@ -0,0 +1,47 @@ +import os +import torch + + +def get_world_size(): + """Find OMPI world size without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_SIZE') is not None: + return int(os.environ.get('PMI_SIZE') or 1) + elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: + 
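+        # Set by OpenMPI's mpirun for every launched rank; when neither MPI
+        # variable is present we fall back to the local GPU count below.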
return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) + else: + return torch.cuda.device_count() + + +def get_global_rank(): + """Find OMPI world rank without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_RANK') is not None: + return int(os.environ.get('PMI_RANK') or 0) + elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) + else: + return 0 + + +def get_local_rank(): + """Find OMPI local rank without calling mpi functions + :rtype: int + """ + if os.environ.get('MPI_LOCALRANKID') is not None: + return int(os.environ.get('MPI_LOCALRANKID') or 0) + elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) + else: + return 0 + + +def get_master_ip(): + if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] + elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') + else: + return "127.0.0.1" diff --git a/core/loss.py b/core/loss.py new file mode 100644 index 0000000..0386742 --- /dev/null +++ b/core/loss.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn + +class AdversarialLoss(nn.Module): + r""" + Adversarial loss + https://arxiv.org/abs/1711.10337 + """ + + def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0): + r""" + type = nsgan | lsgan | hinge + """ + super(AdversarialLoss, self).__init__() + self.type = type + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + + if type == 'nsgan': + self.criterion = nn.BCELoss() + elif type == 'lsgan': + self.criterion = nn.MSELoss() + elif type == 'hinge': + self.criterion = nn.ReLU() + + def __call__(self, outputs, is_real, is_disc=None): + if self.type == 'hinge': + if is_disc: + if is_real: + outputs = -outputs + return self.criterion(1 + outputs).mean() + else: + return (-outputs).mean() + else: + labels = (self.real_label if is_real else self.fake_label).expand_as( + outputs) + loss = self.criterion(outputs, labels) + return loss + + diff --git a/core/spectral_norm.py b/core/spectral_norm.py new file mode 100644 index 0000000..632b888 --- /dev/null +++ b/core/spectral_norm.py @@ -0,0 +1,267 @@ +""" +Spectral Normalization from https://arxiv.org/abs/1802.05957 +""" +import torch +from torch.nn.functional import normalize + + +class SpectralNorm(object): + # Invariant before and after each forward call: + # u = normalize(W @ v) + # NB: At initialization, this invariant is not enforced + + _version = 1 + # At version 1: + # made `W` not a buffer, + # added `v` as a buffer, and + # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. 
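+
+    # A minimal usage sketch (the Conv2d layer is only an illustration;
+    # `use_spectral_norm` is defined at the end of this file):
+    #
+    #     conv = use_spectral_norm(nn.Conv2d(64, 128, 3), use_sn=True)
+    #     out = conv(x)  # the forward pre-hook recomputes conv.weight = weight_orig / sigma
+    #
+    # where sigma ~= u^T W v is refreshed by `n_power_iterations` power-iteration
+    # steps on every training-mode forward.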
+ + def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError('Expected n_power_iterations to be positive, but ' + 'got n_power_iterations={}'.format(n_power_iterations)) + self.n_power_iterations = n_power_iterations + self.eps = eps + + def reshape_weight_to_matrix(self, weight): + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute(self.dim, + *[d for d in range(weight_mat.dim()) if d != self.dim]) + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + def compute_weight(self, module, do_power_iteration): + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + weight = getattr(module, self.name + '_orig') + u = getattr(module, self.name + '_u') + v = getattr(module, self.name + '_v') + weight_mat = self.reshape_weight_to_matrix(weight) + + if do_power_iteration: + with torch.no_grad(): + for _ in range(self.n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. 
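+                    # Written out, one iteration below does
+                    #   v <- normalize(W^T u),  u <- normalize(W v),
+                    # after which sigma ~= u^T W v estimates the largest singular value of W.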
+ v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v) + u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u) + if self.n_power_iterations > 0: + # See above on why we need to clone + u = u.clone() + v = v.clone() + + sigma = torch.dot(u, torch.mv(weight_mat, v)) + weight = weight / sigma + return weight + + def remove(self, module): + with torch.no_grad(): + weight = self.compute_weight(module, do_power_iteration=False) + delattr(module, self.name) + delattr(module, self.name + '_u') + delattr(module, self.name + '_v') + delattr(module, self.name + '_orig') + module.register_parameter(self.name, torch.nn.Parameter(weight.detach())) + + def __call__(self, module, inputs): + setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training)) + + def _solve_v_and_rescale(self, weight_mat, u, target_sigma): + # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` + # (the invariant at top of this class) and `u @ W @ v = sigma`. + # This uses pinverse in case W^T W is not invertible. + v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1) + return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) + + @staticmethod + def apply(module, name, n_power_iterations, dim, eps): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError("Cannot register two spectral_norm hooks on " + "the same parameter {}".format(name)) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + + with torch.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + + h, w = weight_mat.size() + # randomly initialize `u` and `v` + u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) + v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) + + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. Instead, we register weight.data as a plain + # attribute. + setattr(module, fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + + module._register_state_dict_hook(SpectralNormStateDictHook(fn)) + module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn)) + return fn + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormLoadStateDictPreHook(object): + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn): + self.fn = fn + + # For state_dict with version None, (assuming that it has gone through at + # least one training forward), we have + # + # u = normalize(W_orig @ v) + # W = W_orig / sigma, where sigma = u @ W_orig @ v + # + # To compute `v`, we solve `W_orig @ x = u`, and let + # v = x / (u @ W_orig @ x) * (W / W_orig). 
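+
+    # The hook below only detects such legacy (pre-version-1) state dicts; the
+    # actual reconstruction of `_v` via `_solve_v_and_rescale` is left commented
+    # out in this copy.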
+ def __call__(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + fn = self.fn + version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None) + if version is None or version < 1: + with torch.no_grad(): + weight_orig = state_dict[prefix + fn.name + '_orig'] + # weight = state_dict.pop(prefix + fn.name) + # sigma = (weight_orig / weight).mean() + weight_mat = fn.reshape_weight_to_matrix(weight_orig) + u = state_dict[prefix + fn.name + '_u'] + # v = fn._solve_v_and_rescale(weight_mat, u, sigma) + # state_dict[prefix + fn.name + '_v'] = v + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormStateDictHook(object): + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn): + self.fn = fn + + def __call__(self, module, state_dict, prefix, local_metadata): + if 'spectral_norm' not in local_metadata: + local_metadata['spectral_norm'] = {} + key = self.fn.name + '.version' + if key in local_metadata['spectral_norm']: + raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key)) + local_metadata['spectral_norm'][key] = self.fn._version + + +def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None): + r"""Applies spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm + eps (float, optional): epsilon for numerical stability in + calculating norms + dim (int, optional): dimension corresponding to number of outputs, + the default is ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with the spectral norm hook + + Example:: + + >>> m = spectral_norm(nn.Linear(20, 40)) + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_u.size() + torch.Size([40]) + + """ + if dim is None: + if isinstance(module, (torch.nn.ConvTranspose1d, + torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d)): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + return module + + +def remove_spectral_norm(module, name='weight'): + r"""Removes the spectral normalization reparameterization from a module. 
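+
+    The normalized weight is computed one last time (without a power-iteration
+    update) and re-registered as a plain parameter, while the auxiliary
+    ``_u``/``_v``/``_orig`` entries and the forward pre-hook are removed.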
+ + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = spectral_norm(nn.Linear(40, 10)) + >>> remove_spectral_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError("spectral_norm of '{}' not found in {}".format( + name, module)) + + +def use_spectral_norm(module, use_sn=False): + if use_sn: + return spectral_norm(module) + return module \ No newline at end of file diff --git a/core/trainer.py b/core/trainer.py new file mode 100644 index 0000000..b92be82 --- /dev/null +++ b/core/trainer.py @@ -0,0 +1,291 @@ +import os +import glob +import logging +import importlib +import numpy as np +from tqdm import tqdm + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +from tensorboardX import SummaryWriter +import torch.distributed as dist + +from core.dataset import Dataset +from core.loss import AdversarialLoss + + +class Trainer(): + def __init__(self, config): + self.config = config + self.epoch = 0 + self.iteration = 0 + + # setup data set and data loader + self.train_dataset = Dataset(config['data_loader'], split='train') + self.train_sampler = None + self.train_args = config['trainer'] + if config['distributed']: + self.train_sampler = DistributedSampler( + self.train_dataset, + num_replicas=config['world_size'], + rank=config['global_rank']) + self.train_loader = DataLoader( + self.train_dataset, + batch_size=self.train_args['batch_size'] // config['world_size'], + shuffle=(self.train_sampler is None), + num_workers=self.train_args['num_workers'], + sampler=self.train_sampler) + + # set loss functions + self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS']) + self.adversarial_loss = self.adversarial_loss.to(self.config['device']) + self.l1_loss = nn.L1Loss() + + # setup models including generator and discriminator + net = importlib.import_module('model.'+config['model']['net']) + self.netG = net.InpaintGenerator() + self.netG = self.netG.to(self.config['device']) + if not self.config['model']['no_dis']: + self.netD = net.Discriminator( + in_channels=3, use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge') + self.netD = self.netD.to(self.config['device']) + self.optimG = torch.optim.Adam( + self.netG.parameters(), + lr=config['trainer']['lr'], + betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])) + if not self.config['model']['no_dis']: + self.optimD = torch.optim.Adam( + self.netD.parameters(), + lr=config['trainer']['lr'], + betas=(self.config['trainer']['beta1'], self.config['trainer']['beta2'])) + self.load() + + if config['distributed']: + self.netG = DDP( + self.netG, + device_ids=[self.config['local_rank']], + output_device=self.config['local_rank'], + broadcast_buffers=True, + find_unused_parameters=True) + if not self.config['model']['no_dis']: + self.netD = DDP( + self.netD, + device_ids=[self.config['local_rank']], + output_device=self.config['local_rank'], + broadcast_buffers=True, + find_unused_parameters=False) + + # set summary writer + self.dis_writer = None + self.gen_writer = None + self.summary = {} + if self.config['global_rank'] == 0 or (not config['distributed']): + self.dis_writer = 
SummaryWriter( + os.path.join(config['save_dir'], 'dis')) + self.gen_writer = SummaryWriter( + os.path.join(config['save_dir'], 'gen')) + + # get current learning rate + def get_lr(self): + return self.optimG.param_groups[0]['lr'] + + # learning rate scheduler, step + def adjust_learning_rate(self): + decay = 0.1**(min(self.iteration, + self.config['trainer']['niter']) // self.config['trainer']['niter']) + new_lr = self.config['trainer']['lr'] * decay + if new_lr != self.get_lr(): + for param_group in self.optimG.param_groups: + param_group['lr'] = new_lr + if not self.config['model']['no_dis']: + for param_group in self.optimD.param_groups: + param_group['lr'] = new_lr + + # add summary + def add_summary(self, writer, name, val): + if name not in self.summary: + self.summary[name] = 0 + self.summary[name] += val + if writer is not None and self.iteration % 100 == 0: + writer.add_scalar(name, self.summary[name]/100, self.iteration) + self.summary[name] = 0 + + # load netG and netD + def load(self): + model_path = self.config['save_dir'] + if os.path.isfile(os.path.join(model_path, 'latest.ckpt')): + latest_epoch = open(os.path.join( + model_path, 'latest.ckpt'), 'r').read().splitlines()[-1] + else: + ckpts = [os.path.basename(i).split('.pth')[0] for i in glob.glob( + os.path.join(model_path, '*.pth'))] + ckpts.sort() + latest_epoch = ckpts[-1] if len(ckpts) > 0 else None + if latest_epoch is not None: + gen_path = os.path.join( + model_path, 'gen_{}.pth'.format(str(latest_epoch).zfill(5))) + dis_path = os.path.join( + model_path, 'dis_{}.pth'.format(str(latest_epoch).zfill(5))) + opt_path = os.path.join( + model_path, 'opt_{}.pth'.format(str(latest_epoch).zfill(5))) + if self.config['global_rank'] == 0: + print('Loading model from {}...'.format(gen_path)) + data = torch.load(gen_path, map_location=self.config['device']) + self.netG.load_state_dict(data['netG']) + if not self.config['model']['no_dis']: + data = torch.load(dis_path, map_location=self.config['device']) + self.netD.load_state_dict(data['netD']) + data = torch.load(opt_path, map_location=self.config['device']) + self.optimG.load_state_dict(data['optimG']) + if not self.config['model']['no_dis']: + self.optimD.load_state_dict(data['optimD']) + self.epoch = data['epoch'] + self.iteration = data['iteration'] + else: + if self.config['global_rank'] == 0: + print( + 'Warnning: There is no trained model found. 
An initialized model will be used.') + + # save parameters every eval_epoch + def save(self, it): + if self.config['global_rank'] == 0: + gen_path = os.path.join( + self.config['save_dir'], 'gen_{}.pth'.format(str(it).zfill(5))) + dis_path = os.path.join( + self.config['save_dir'], 'dis_{}.pth'.format(str(it).zfill(5))) + opt_path = os.path.join( + self.config['save_dir'], 'opt_{}.pth'.format(str(it).zfill(5))) + print('\nsaving model to {} ...'.format(gen_path)) + if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP): + netG = self.netG.module + if not self.config['model']['no_dis']: + netD = self.netD.module + else: + netG = self.netG + if not self.config['model']['no_dis']: + netD = self.netD + torch.save({'netG': netG.state_dict()}, gen_path) + if not self.config['model']['no_dis']: + torch.save({'netD': netD.state_dict()}, dis_path) + torch.save({'epoch': self.epoch, + 'iteration': self.iteration, + 'optimG': self.optimG.state_dict(), + 'optimD': self.optimD.state_dict()}, opt_path) + else: + torch.save({'epoch': self.epoch, + 'iteration': self.iteration, + 'optimG': self.optimG.state_dict()}, opt_path) + os.system('echo {} > {}'.format(str(it).zfill(5), + os.path.join(self.config['save_dir'], 'latest.ckpt'))) + + # train entry + def train(self): + pbar = range(int(self.train_args['iterations'])) + if self.config['global_rank'] == 0: + pbar = tqdm(pbar, initial=self.iteration, dynamic_ncols=True, smoothing=0.01) + + os.makedirs('logs', exist_ok=True) + logging.basicConfig(level=logging.INFO, + format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s', + datefmt='%a, %d %b %Y %H:%M:%S', + filename='logs/{}.log'.format(self.config['save_dir'].split('/')[-1]), + filemode='w') + + while True: + self.epoch += 1 + if self.config['distributed']: + self.train_sampler.set_epoch(self.epoch) + + self._train_epoch(pbar) + if self.iteration > self.train_args['iterations']: + break + print('\nEnd training....') + + # process input and calculate loss every training epoch + def _train_epoch(self, pbar): + device = self.config['device'] + + for frames, masks in self.train_loader: + self.adjust_learning_rate() + self.iteration += 1 + + frames, masks = frames.to(device), masks.to(device) + b, t, c, h, w = frames.size() + masked_frame = (frames * (1 - masks).float()) + pred_img = self.netG(masked_frame) + frames = frames.view(b*t, c, h, w) + masks = masks.view(b*t, 1, h, w) + comp_img = frames*(1.-masks) + masks*pred_img + + gen_loss = 0 + dis_loss = 0 + + if not self.config['model']['no_dis']: + # discriminator adversarial loss + real_vid_feat = self.netD(frames) + fake_vid_feat = self.netD(comp_img.detach()) + dis_real_loss = self.adversarial_loss(real_vid_feat, True, True) + dis_fake_loss = self.adversarial_loss(fake_vid_feat, False, True) + dis_loss += (dis_real_loss + dis_fake_loss) / 2 + self.add_summary( + self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item()) + self.add_summary( + self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item()) + self.optimD.zero_grad() + dis_loss.backward() + self.optimD.step() + + # generator adversarial loss + gen_vid_feat = self.netD(comp_img) + gan_loss = self.adversarial_loss(gen_vid_feat, True, False) + gan_loss = gan_loss * self.config['losses']['adversarial_weight'] + gen_loss += gan_loss + self.add_summary( + self.gen_writer, 'loss/gan_loss', gan_loss.item()) + + # generator l1 loss + hole_loss = self.l1_loss(pred_img*masks, frames*masks) + hole_loss = hole_loss / torch.mean(masks) * 
self.config['losses']['hole_weight'] + gen_loss += hole_loss + self.add_summary( + self.gen_writer, 'loss/hole_loss', hole_loss.item()) + + valid_loss = self.l1_loss(pred_img*(1-masks), frames*(1-masks)) + valid_loss = valid_loss / torch.mean(1-masks) * self.config['losses']['valid_weight'] + gen_loss += valid_loss + self.add_summary( + self.gen_writer, 'loss/valid_loss', valid_loss.item()) + + self.optimG.zero_grad() + gen_loss.backward() + self.optimG.step() + + # console logs + if self.config['global_rank'] == 0: + pbar.update(1) + if not self.config['model']['no_dis']: + pbar.set_description(( + f"d: {dis_loss.item():.3f}; g: {gan_loss.item():.3f};" + f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}") + ) + else: + pbar.set_description(( + f"hole: {hole_loss.item():.3f}; valid: {valid_loss.item():.3f}") + ) + + if self.iteration % self.train_args['log_freq'] == 0: + if not self.config['model']['no_dis']: + logging.info('[Iter {}] d: {:.4f}; g: {:.4f}; hole: {:.4f}; valid: {:.4f}'.format(self.iteration, dis_loss.item(), gan_loss.item(), hole_loss.item(), valid_loss.item())) + else: + logging.info('[Iter {}] hole: {:.4f}; valid: {:.4f}'.format(self.iteration, hole_loss.item(), valid_loss.item())) + # saving models + if self.iteration % self.train_args['save_freq'] == 0: + self.save(int(self.iteration//self.train_args['save_freq'])) + if self.iteration > self.train_args['iterations']: + break + diff --git a/core/utils.py b/core/utils.py new file mode 100644 index 0000000..47b2910 --- /dev/null +++ b/core/utils.py @@ -0,0 +1,213 @@ +import os +import cv2 +import random +import numpy as np +from PIL import Image, ImageOps + +import torch +import matplotlib +import matplotlib.patches as patches +from matplotlib.path import Path +from matplotlib import pyplot as plt +matplotlib.use('agg') + + +# ########################################################################### +# ########################################################################### + + +class GroupRandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + + def __init__(self, is_flow=False): + self.is_flow = is_flow + + def __call__(self, img_group, is_flow=False): + v = random.random() + if v < 0.5: + ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] + if self.is_flow: + for i in range(0, len(ret), 2): + # invert flow pixel values when flipping + ret[i] = ImageOps.invert(ret[i]) + return ret + else: + return img_group + + +class Stack(object): + def __init__(self, roll=False): + self.roll = roll + + def __call__(self, img_group): + mode = img_group[0].mode + if mode == '1': + img_group = [img.convert('L') for img in img_group] + mode = 'L' + if mode == 'L': + return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) + elif mode == 'RGB': + if self.roll: + return np.stack([np.array(x)[:, :, ::-1] for x in img_group], axis=2) + else: + return np.stack(img_group, axis=2) + else: + raise NotImplementedError(f"Image mode {mode}") + + +class ToTorchFormatTensor(object): + """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] + to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ + + def __init__(self, div=True): + self.div = div + + def __call__(self, pic): + if isinstance(pic, np.ndarray): + # numpy img: [L, C, H, W] + img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() + else: + # handle PIL Image + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + img = 
img.view(pic.size[1], pic.size[0], len(pic.mode)) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255) if self.div else img.float() + return img + + +# ########################################## +# ########################################## + +def create_random_shape_with_random_motion(video_length, imageHeight=240, imageWidth=432): + # get a random shape + height = random.randint(imageHeight//3, imageHeight-1) + width = random.randint(imageWidth//3, imageWidth-1) + edge_num = random.randint(6, 8) + ratio = random.randint(6, 8)/10 + region = get_random_shape( + edge_num=edge_num, ratio=ratio, height=height, width=width) + region_width, region_height = region.size + # get random position + x, y = random.randint( + 0, imageHeight-region_height), random.randint(0, imageWidth-region_width) + velocity = get_random_velocity(max_speed=3) + m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y+region.size[0], x+region.size[1])) + masks = [m.convert('L')] + # return fixed masks + if random.uniform(0, 1) > 0.5: + return masks*video_length + # return moving masks + for _ in range(video_length-1): + x, y, velocity = random_move_control_points( + x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3) + m = Image.fromarray( + np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y+region.size[0], x+region.size[1])) + masks.append(m.convert('L')) + return masks + + +def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240): + ''' + There is the initial point and 3 points per cubic bezier curve. + Thus, the curve will only pass though n points, which will be the sharp edges. + The other 2 modify the shape of the bezier curve. 
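+    The Path built below therefore uses edge_num*3 + 1 control points in total.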
+ edge_num, Number of possibly sharp edges + points_num, number of points in the Path + ratio, (0, 1) magnitude of the perturbation from the unit circle, + ''' + points_num = edge_num*3 + 1 + angles = np.linspace(0, 2*np.pi, points_num) + codes = np.full(points_num, Path.CURVE4) + codes[0] = Path.MOVETO + # Using this instad of Path.CLOSEPOLY avoids an innecessary straight line + verts = np.stack((np.cos(angles), np.sin(angles))).T * \ + (2*ratio*np.random.random(points_num)+1-ratio)[:, None] + verts[-1, :] = verts[0, :] + path = Path(verts, codes) + # draw paths into images + fig = plt.figure() + ax = fig.add_subplot(111) + patch = patches.PathPatch(path, facecolor='black', lw=2) + ax.add_patch(patch) + ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1) + ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1) + ax.axis('off') # removes the axis to leave only the shape + fig.canvas.draw() + # convert plt images into numpy images + data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,))) + plt.close(fig) + # postprocess + data = cv2.resize(data, (width, height))[:, :, 0] + data = (1 - np.array(data > 0).astype(np.uint8))*255 + corrdinates = np.where(data > 0) + xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max( + corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1]) + region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax)) + return region + + +def random_accelerate(velocity, maxAcceleration, dist='uniform'): + speed, angle = velocity + d_speed, d_angle = maxAcceleration + if dist == 'uniform': + speed += np.random.uniform(-d_speed, d_speed) + angle += np.random.uniform(-d_angle, d_angle) + elif dist == 'guassian': + speed += np.random.normal(0, d_speed / 2) + angle += np.random.normal(0, d_angle / 2) + else: + raise NotImplementedError( + f'Distribution type {dist} is not supported.') + return (speed, angle) + + +def get_random_velocity(max_speed=3, dist='uniform'): + if dist == 'uniform': + speed = np.random.uniform(max_speed) + elif dist == 'guassian': + speed = np.abs(np.random.normal(0, max_speed / 2)) + else: + raise NotImplementedError( + f'Distribution type {dist} is not supported.') + angle = np.random.uniform(0, 2 * np.pi) + return (speed, angle) + + +def random_move_control_points(X, Y, imageHeight, imageWidth, lineVelocity, region_size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3): + region_width, region_height = region_size + speed, angle = lineVelocity + X += int(speed * np.cos(angle)) + Y += int(speed * np.sin(angle)) + lineVelocity = random_accelerate( + lineVelocity, maxLineAcceleration, dist='guassian') + if ((X > imageHeight - region_height) or (X < 0) or (Y > imageWidth - region_width) or (Y < 0)): + lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian') + new_X = np.clip(X, 0, imageHeight - region_height) + new_Y = np.clip(Y, 0, imageWidth - region_width) + return new_X, new_Y, lineVelocity + + + +# ############################################## +# ############################################## + +if __name__ == '__main__': + + trials = 10 + for _ in range(trials): + video_length = 10 + # The returned masks are either stationary (50%) or moving (50%) + masks = create_random_shape_with_random_motion( + video_length, imageHeight=240, imageWidth=432) + + for m in masks: + cv2.imshow('mask', np.array(m)) + cv2.waitKey(500) + diff --git a/model/fuseformer.py b/model/fuseformer.py new file mode 100644 index 0000000..6f483b7 --- /dev/null +++ b/model/fuseformer.py @@ -0,0 
+1,402 @@ +''' Fuseformer for Video Inpainting +''' +import numpy as np +import time +import math +from functools import reduce +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +from core.spectral_norm import spectral_norm as _spectral_norm + + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print('Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' % (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + + def init_func(m): + classname = m.__class__.__name__ + if classname.find('InstanceNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight.data, 1.0) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + + +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + self.group = [1, 2, 4, 8, 1] + self.layers = nn.ModuleList([ + nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 256 , kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1), + nn.LeakyReLU(0.2, inplace=True) + ]) + + def forward(self, x): + bt, c, h, w = x.size() + h, w = h//4, w//4 + out = x + for i, layer in enumerate(self.layers): + if i == 8: + x0 = out + if i > 8 and i % 2 == 0: + g = self.group[(i - 8) // 2] + x = x0.view(bt, g, -1, h, w) + o = out.view(bt, g, -1, h, w) + out = 
torch.cat([x, o], 2).view(bt, -1, h, w) + out = layer(out) + return out + + +class InpaintGenerator(BaseNetwork): + def __init__(self, init_weights=True): + super(InpaintGenerator, self).__init__() + channel = 256 + hidden = 512 + stack_num = 8 + num_head = 4 + kernel_size = (7, 7) + padding = (3, 3) + stride = (3, 3) + output_size = (60, 108) + blocks = [] + dropout = 0. + t2t_params = {'kernel_size': kernel_size, 'stride': stride, 'padding': padding, 'output_size': output_size} + n_vecs = 1 + for i, d in enumerate(kernel_size): + n_vecs *= int((output_size[i] + 2 * padding[i] - (d - 1) - 1) / stride[i] + 1) + for _ in range(stack_num): + blocks.append(TransformerBlock(hidden=hidden, num_head=num_head, dropout=dropout, n_vecs=n_vecs, + t2t_params=t2t_params)) + self.transformer = nn.Sequential(*blocks) + self.ss = SoftSplit(channel // 2, hidden, kernel_size, stride, padding, dropout=dropout) + self.add_pos_emb = AddPosEmb(n_vecs, hidden) + self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size, stride, padding) + + self.encoder = Encoder() + + # decoder: decode frames from features + self.decoder = nn.Sequential( + deconv(channel // 2, 128, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1) + ) + + if init_weights: + self.init_weights() + + def forward(self, masked_frames): + # extracting features + b, t, c, h, w = masked_frames.size() + time0 = time.time() + enc_feat = self.encoder(masked_frames.view(b * t, c, h, w)) + _, c, h, w = enc_feat.size() + trans_feat = self.ss(enc_feat, b) + trans_feat = self.add_pos_emb(trans_feat) + trans_feat = self.transformer(trans_feat) + trans_feat = self.sc(trans_feat, t) + enc_feat = enc_feat + trans_feat + output = self.decoder(enc_feat) + output = torch.tanh(output) + return output + + +class deconv(nn.Module): + def __init__(self, input_channel, output_channel, kernel_size=3, padding=0): + super().__init__() + self.conv = nn.Conv2d(input_channel, output_channel, + kernel_size=kernel_size, stride=1, padding=padding) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2, mode='bilinear', + align_corners=True) + return self.conv(x) + + +# ############################################################################# +# ############################# Transformer ################################## +# ############################################################################# + + +class Attention(nn.Module): + """ + Compute 'Scaled Dot Product Attention + """ + + def __init__(self, p=0.1): + super(Attention, self).__init__() + self.dropout = nn.Dropout(p=p) + + def forward(self, query, key, value, m=None): + scores = torch.matmul(query, key.transpose(-2, -1) + ) / math.sqrt(query.size(-1)) + if m is not None: + scores.masked_fill_(m, -1e9) + p_attn = F.softmax(scores, dim=-1) + p_attn = self.dropout(p_attn) + p_val = torch.matmul(p_attn, value) + return p_val, p_attn + + +class AddPosEmb(nn.Module): + def __init__(self, n, c): + super(AddPosEmb, self).__init__() + self.pos_emb = nn.Parameter(torch.zeros(1, 1, n, c).float().normal_(mean=0, std=0.02), requires_grad=True) + self.num_vecs = n + + def forward(self, x): + b, n, c = x.size() + x = x.view(b, -1, self.num_vecs, c) + x = x + self.pos_emb + x = x.view(b, n, c) + return x + + +class SoftSplit(nn.Module): + def __init__(self, channel, hidden, 
kernel_size, stride, padding, dropout=0.1): + super(SoftSplit, self).__init__() + self.kernel_size = kernel_size + self.t2t = nn.Unfold(kernel_size=kernel_size, stride=stride, padding=padding) + c_in = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(c_in, hidden) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, x, b): + feat = self.t2t(x) + feat = feat.permute(0, 2, 1) + feat = self.embedding(feat) + feat = feat.view(b, -1, feat.size(2)) + feat = self.dropout(feat) + return feat + + +class SoftComp(nn.Module): + def __init__(self, channel, hidden, output_size, kernel_size, stride, padding): + super(SoftComp, self).__init__() + self.relu = nn.LeakyReLU(0.2, inplace=True) + c_out = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(hidden, c_out) + self.t2t = torch.nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding) + h, w = output_size + self.bias = nn.Parameter(torch.zeros((channel, h, w), dtype=torch.float32), requires_grad=True) + + def forward(self, x, t): + feat = self.embedding(x) + b, n, c = feat.size() + feat = feat.view(b * t, -1, c).permute(0, 2, 1) + feat = self.t2t(feat) + self.bias[None] + return feat + + +class MultiHeadedAttention(nn.Module): + """ + Take in model size and number of heads. + """ + + def __init__(self, d_model, head, p=0.1): + super().__init__() + self.query_embedding = nn.Linear(d_model, d_model) + self.value_embedding = nn.Linear(d_model, d_model) + self.key_embedding = nn.Linear(d_model, d_model) + self.output_linear = nn.Linear(d_model, d_model) + self.attention = Attention(p=p) + self.head = head + + def forward(self, x): + b, n, c = x.size() + c_h = c // self.head + key = self.key_embedding(x) + key = key.view(b, n, self.head, c_h).permute(0, 2, 1, 3) + query = self.query_embedding(x) + query = query.view(b, n, self.head, c_h).permute(0, 2, 1, 3) + value = self.value_embedding(x) + value = value.view(b, n, self.head, c_h).permute(0, 2, 1, 3) + att, _ = self.attention(query, key, value) + att = att.permute(0, 2, 1, 3).contiguous().view(b, n, c) + output = self.output_linear(att) + return output + + +class FeedForward(nn.Module): + def __init__(self, d_model, p=0.1): + super(FeedForward, self).__init__() + # We set d_ff as a default to 2048 + self.conv = nn.Sequential( + nn.Linear(d_model, d_model * 4), + nn.ReLU(inplace=True), + nn.Dropout(p=p), + nn.Linear(d_model * 4, d_model), + nn.Dropout(p=p)) + + def forward(self, x): + x = self.conv(x) + return x + + +class FusionFeedForward(nn.Module): + def __init__(self, d_model, p=0.1, n_vecs=None, t2t_params=None): + super(FusionFeedForward, self).__init__() + # We set d_ff as a default to 1960 + hd = 1960 + self.conv1 = nn.Sequential( + nn.Linear(d_model, hd)) + self.conv2 = nn.Sequential( + nn.ReLU(inplace=True), + nn.Dropout(p=p), + nn.Linear(hd, d_model), + nn.Dropout(p=p)) + assert t2t_params is not None and n_vecs is not None + tp = t2t_params.copy() + self.fold = nn.Fold(**tp) + del tp['output_size'] + self.unfold = nn.Unfold(**tp) + self.n_vecs = n_vecs + + def forward(self, x): + x = self.conv1(x) + b, n, c = x.size() + normalizer = x.new_ones(b, n, 49).view(-1, self.n_vecs, 49).permute(0, 2, 1) + x = self.unfold(self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) / self.fold(normalizer)).permute(0, 2, + 1).contiguous().view( + b, n, c) + x = self.conv2(x) + return x + + +class TransformerBlock(nn.Module): + """ + Transformer = MultiHead_Attention + Feed_Forward with sublayer connection 
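+    Pre-LayerNorm residual form: x = input + Dropout(MHA(LN(input))), then
+    out = x + FusionFeedForward(LN(x)), where FusionFeedForward folds the token
+    vectors back onto the frame plane (averaging overlapping patches) before
+    unfolding them again.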
+ """ + + def __init__(self, hidden=128, num_head=4, dropout=0.1, n_vecs=None, t2t_params=None): + super().__init__() + self.attention = MultiHeadedAttention(d_model=hidden, head=num_head, p=dropout) + self.ffn = FusionFeedForward(hidden, p=dropout, n_vecs=n_vecs, t2t_params=t2t_params) + self.norm1 = nn.LayerNorm(hidden) + self.norm2 = nn.LayerNorm(hidden) + self.dropout = nn.Dropout(p=dropout) + + def forward(self, input): + x = self.norm1(input) + x = input + self.dropout(self.attention(x)) + y = self.norm2(x) + x = x + self.ffn(y) + return x + + +# ###################################################################### +# ###################################################################### + + +class Discriminator(BaseNetwork): + def __init__(self, in_channels=3, use_sigmoid=False, use_spectral_norm=True, init_weights=True): + super(Discriminator, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 32 + + self.conv = nn.Sequential( + spectral_norm( + nn.Conv3d(in_channels=in_channels, out_channels=nf * 1, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=1, bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(64, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf * 1, nf * 2, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(128, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf * 2, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm(nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), + padding=(1, 2, 2), bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), + stride=(1, 2, 2), padding=(1, 2, 2)) + ) + + if init_weights: + self.init_weights() + + def forward(self, xs): + # T, C, H, W = xs.shape + xs_t = torch.transpose(xs, 0, 1) + xs_t = xs_t.unsqueeze(0) # B, C, T, H, W + feat = self.conv(xs_t) + if self.use_sigmoid: + feat = torch.sigmoid(feat) + out = torch.transpose(feat, 1, 2) # B, T, C, H, W + return out + + +def spectral_norm(module, mode=True): + if mode: + return _spectral_norm(module) + return module diff --git a/test.py b/test.py new file mode 100644 index 0000000..794ad2a --- /dev/null +++ b/test.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +import cv2 +from PIL import Image +import numpy as np +import importlib +import os +import argparse +import torch +import torch.nn as nn +from torchvision import transforms + +from core.utils import Stack, ToTorchFormatTensor + +parser = argparse.ArgumentParser(description="FuseFormer") +parser.add_argument("-v", "--video", type=str, required=True) +parser.add_argument("-m", "--mask", type=str, required=True) +parser.add_argument("-c", "--ckpt", type=str, required=True) +parser.add_argument("--model", type=str, default='fuseformer') +parser.add_argument("--width", type=int, default=432) +parser.add_argument("--height", type=int, default=240) +parser.add_argument("--outw", type=int, 
default=432) +parser.add_argument("--outh", type=int, default=240) +parser.add_argument("--step", type=int, default=10) +parser.add_argument("--num_ref", type=int, default=-1) +parser.add_argument("--neighbor_stride", type=int, default=5) +parser.add_argument("--savefps", type=int, default=24) +parser.add_argument("--use_mp4", action='store_true') +args = parser.parse_args() + + +w, h = args.width, args.height +ref_length = args.step # ref_step +num_ref = args.num_ref +neighbor_stride = args.neighbor_stride +default_fps = args.savefps + +_to_tensors = transforms.Compose([ + Stack(), + ToTorchFormatTensor()]) + + +# sample reference frames from the whole video +def get_ref_index(f, neighbor_ids, length): + ref_index = [] + if num_ref == -1: + for i in range(0, length, ref_length): + if not i in neighbor_ids: + ref_index.append(i) + else: + start_idx = max(0, f - ref_length * (num_ref//2)) + end_idx = min(length, f + ref_length * (num_ref//2)) + for i in range(start_idx, end_idx+1, ref_length): + if not i in neighbor_ids: + if len(ref_index) > num_ref: + #if len(ref_index) >= 5-len(neighbor_ids): + break + ref_index.append(i) + return ref_index + + +# read frame-wise masks +def read_mask(mpath): + masks = [] + mnames = os.listdir(mpath) + mnames.sort() + for m in mnames: + m = Image.open(os.path.join(mpath, m)) + m = m.resize((w, h), Image.NEAREST) + m = np.array(m.convert('L')) + m = np.array(m > 0).astype(np.uint8) + m = cv2.dilate(m, cv2.getStructuringElement( + cv2.MORPH_CROSS, (3, 3)), iterations=4) + masks.append(Image.fromarray(m*255)) + return masks + + +# read frames from video +def read_frame_from_videos(args): + vname = args.video + frames = [] + if args.use_mp4: + vidcap = cv2.VideoCapture(vname) + success, image = vidcap.read() + count = 0 + while success: + image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + frames.append(image.resize((w,h))) + success, image = vidcap.read() + count += 1 + else: + lst = os.listdir(vname) + lst.sort() + fr_lst = [vname+'/'+name for name in lst] + for fr in fr_lst: + image = cv2.imread(fr) + image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + frames.append(image.resize((w,h))) + return frames + + +def main_worker(): + # set up models + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + net = importlib.import_module('model.' 
+ args.model) + model = net.InpaintGenerator().to(device) + model_path = args.ckpt + data = torch.load(args.ckpt, map_location=device) + model.load_state_dict(data) + print('loading from: {}'.format(args.ckpt)) + model.eval() + + # prepare datset, encode all frames into deep space + frames = read_frame_from_videos(args) + video_length = len(frames) + imgs = _to_tensors(frames).unsqueeze(0)*2-1 + frames = [np.array(f).astype(np.uint8) for f in frames] + + masks = read_mask(args.mask) + binary_masks = [np.expand_dims((np.array(m) != 0).astype(np.uint8), 2) for m in masks] + masks = _to_tensors(masks).unsqueeze(0) + imgs, masks = imgs.to(device), masks.to(device) + comp_frames = [None]*video_length + print('loading videos and masks from: {}'.format(args.video)) + + # completing holes by spatial-temporal transformers + for f in range(0, video_length, neighbor_stride): + neighbor_ids = [i for i in range(max(0, f-neighbor_stride), min(video_length, f+neighbor_stride+1))] + ref_ids = get_ref_index(f, neighbor_ids, video_length) + print(f, len(neighbor_ids), len(ref_ids)) + len_temp = len(neighbor_ids) + len(ref_ids) + selected_imgs = imgs[:1, neighbor_ids+ref_ids, :, :, :] + selected_masks = masks[:1, neighbor_ids+ref_ids, :, :, :] + with torch.no_grad(): + masked_imgs = selected_imgs*(1-selected_masks) + pred_img = model(masked_imgs) + pred_img = (pred_img + 1) / 2 + pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy()*255 + for i in range(len(neighbor_ids)): + idx = neighbor_ids[i] + img = np.array(pred_img[i]).astype( + np.uint8)*binary_masks[idx] + frames[idx] * (1-binary_masks[idx]) + if comp_frames[idx] is None: + comp_frames[idx] = img + else: + comp_frames[idx] = comp_frames[idx].astype( + np.float32)*0.5 + img.astype(np.float32)*0.5 + name = args.video.strip().split('/')[-1] + writer = cv2.VideoWriter(f"{name}_result.mp4", cv2.VideoWriter_fourcc(*"mp4v"), default_fps, (args.outw, args.outh)) + for f in range(video_length): + comp = np.array(comp_frames[f]).astype( + np.uint8)*binary_masks[f] + frames[f] * (1-binary_masks[f]) + if w != args.outw: + comp = cv2.resize(comp, (args.outw, args.outh), interpolation=cv2.INTER_LINEAR) + writer.write(cv2.cvtColor(np.array(comp).astype(np.uint8), cv2.COLOR_BGR2RGB)) + writer.release() + print('Finish in {}'.format(f"{name}_result.mp4")) + + +if __name__ == '__main__': + main_worker() diff --git a/train.py b/train.py new file mode 100644 index 0000000..821891f --- /dev/null +++ b/train.py @@ -0,0 +1,75 @@ +import os +import json +import argparse +import datetime +import numpy as np +from shutil import copyfile +import torch +import torch.multiprocessing as mp + +from core.trainer import Trainer +from core.dist import ( + get_world_size, + get_local_rank, + get_global_rank, + get_master_ip, +) + +parser = argparse.ArgumentParser(description='FuseFormer') +parser.add_argument('-c', '--config', default='configs/youtube-vos.json', type=str) +parser.add_argument('-p', '--port', default='23455', type=str) +args = parser.parse_args() + + +def main_worker(rank, config): + if 'local_rank' not in config: + config['local_rank'] = config['global_rank'] = rank + if config['distributed']: + torch.cuda.set_device(int(config['local_rank'])) + torch.distributed.init_process_group(backend='nccl', + init_method=config['init_method'], + world_size=config['world_size'], + rank=config['global_rank'], + group_name='mtorch' + ) + print('using GPU {}-{} for training'.format( + int(config['global_rank']), int(config['local_rank']))) + + config['save_dir'] = 
os.path.join(config['save_dir'], '{}_{}'.format(config['model']['net'], + os.path.basename(args.config).split('.')[0])) + if torch.cuda.is_available(): + config['device'] = torch.device("cuda:{}".format(config['local_rank'])) + else: + config['device'] = 'cpu' + + if (not config['distributed']) or config['global_rank'] == 0: + os.makedirs(config['save_dir'], exist_ok=True) + config_path = os.path.join( + config['save_dir'], args.config.split('/')[-1]) + if not os.path.isfile(config_path): + copyfile(args.config, config_path) + print('[**] create folder {}'.format(config['save_dir'])) + + trainer = Trainer(config) + trainer.train() + + +if __name__ == "__main__": + + # loading configs + config = json.load(open(args.config)) + + # setting distributed configurations + config['world_size'] = get_world_size() + config['init_method'] = f"tcp://{get_master_ip()}:{args.port}" + config['distributed'] = True if config['world_size'] > 1 else False + + # setup distributed parallel training environments + if get_master_ip() == "127.0.0.1": + # manually launch distributed processes + mp.spawn(main_worker, nprocs=config['world_size'], args=(config,)) + else: + # multiple processes have been launched by openmpi + config['local_rank'] = get_local_rank() + config['global_rank'] = get_global_rank() + main_worker(-1, config)
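+
+# Example launch (single node; the config path matches the README):
+#   python train.py -c configs/youtube-vos.json
+# With several visible GPUs get_world_size() > 1, so one worker per GPU is
+# spawned via mp.spawn; with a single GPU the script runs non-distributed on
+# cuda:0.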