Skip to content

Commit

Permalink
initial version
Browse files Browse the repository at this point in the history
  • Loading branch information
ruiliu-ai committed Sep 8, 2021
1 parent fab4dcb commit 77aea80
Show file tree
Hide file tree
Showing 11 changed files with 1,649 additions and 0 deletions.
52 changes: 52 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,58 @@ By [Rui Liu](https://ruiliu-ai.github.io), Hanming Deng, Yangyi Huang, Xiaoyu Sh

This repo is the official PyTorch implementation of [FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting](https://arxiv.org/abs/2109.02974).

## Introduction
<img src='imgs/pipeline.png' width="900px">

## Usage

### Prerequisites
- Python >= 3.6
- PyTorch >= 1.0 and the corresponding torchvision (https://pytorch.org/)

### Install
- Clone this repo:
```
git clone https://github.com/ruiliu-ai/FuseFormer.git
```
- Install other packages:
```
cd FuseFormer
pip install -r requirements.txt
```

## Training

### Dataset preparation
Download the datasets ([YouTube-VOS](https://competitions.codalab.org/competitions/19544) and [DAVIS](https://davischallenge.org/davis2017/code.html)) into the data folder.
```
mkdir data
```

### Training script
```
python train.py -c configs/youtube-vos.json
```

## Test
Download the [pre-trained model](https://drive.google.com/file/d/1BuSE42QAAUoQAJawbr5mMRXcqRRKeELc/view?usp=sharing) into the checkpoints folder.
```
mkdir checkpoints
```

### Test script
```
python test.py -c checkpoints/fuseformer.pth -v data/DAVIS/JPEGImages/blackswan -m data/DAVIS/Annotations/blackswan
```

## Citing FuseFormer
If you find FuseFormer useful in your research, please consider citing:
```
Expand All @@ -20,4 +67,9 @@ If you find FuseFormer useful in your research, please consider citing:
}
```

## Acknowledgement
This code relies heavily on the video inpainting framework from [spatial-temporal transformer net](https://github.com/researchmm/STTN).
33 changes: 33 additions & 0 deletions configs/youtube-vos.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"seed": 2021,
"save_dir": "checkpoints/",
"data_loader": {
"name": "YouTubeVOS",
"data_root": "./data",
"w": 432,
"h": 240,
"sample_length": 5
},
"losses": {
"hole_weight": 1,
"valid_weight": 1,
"adversarial_weight": 0.01,
"GAN_LOSS": "hinge"
},
"model": {
"net": "fuseformer",
"no_dis": 0
},
"trainer": {
"type": "Adam",
"beta1": 0,
"beta2": 0.99,
"lr": 1e-4,
"batch_size": 8,
"num_workers": 2,
"log_freq": 100,
"save_freq": 1e4,
"iterations": 50e4,
"niter": 40e4
}
}
69 changes: 69 additions & 0 deletions core/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import random
import torch
import numpy as np
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
from core.utils import create_random_shape_with_random_motion
from core.utils import Stack, ToTorchFormatTensor, GroupRandomHorizontalFlip


class Dataset(torch.utils.data.Dataset):
    """Video dataset that returns a short clip of frames plus generated masks.

    Each item corresponds to one video directory.  ``sample_length`` frames
    are picked from it (see ``get_ref_index``), resized to ``(w, h)`` and
    paired with masks produced by
    ``create_random_shape_with_random_motion``.

    Args:
        args: Data-loader configuration.  The keys read here are ``name``,
            ``data_root``, ``w``, ``h`` and ``sample_length``.
        split: Dataset split name; it is used to build the frame directory
            path and, when equal to ``'train'``, enables random horizontal
            flipping of the frames.

    Raises:
        ValueError: If ``args['name']`` is not a supported dataset.
    """

    def __init__(self, args: dict, split='train'):
        self.args = args
        self.split = split
        self.sample_length = args['sample_length']
        # (width, height) order, which is what PIL's ``Image.resize`` expects.
        self.size = self.w, self.h = (args['w'], args['h'])

        if args['name'] == 'YouTubeVOS':
            vid_lst_prefix = os.path.join(args['data_root'], args['name'], split+'_all_frames/JPEGImages')
            vid_lst = os.listdir(vid_lst_prefix)
            self.video_names = [os.path.join(vid_lst_prefix, name) for name in vid_lst]
        else:
            # Fail early: without a video list every later call would die
            # with an unrelated AttributeError on ``self.video_names``.
            raise ValueError(
                'Unsupported dataset name: {!r}'.format(args['name']))

        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(), ])

    def __len__(self):
        """Return the number of video directories found for this split."""
        return len(self.video_names)

    def __getitem__(self, index):
        """Load the sample at ``index``, falling back to sample 0 on failure.

        Only ``Exception`` subclasses trigger the fallback, so
        ``KeyboardInterrupt`` and ``SystemExit`` still propagate and can
        stop a run.
        """
        try:
            item = self.load_item(index)
        except Exception:
            print('Loading error in video {}'.format(self.video_names[index]))
            item = self.load_item(0)
        return item

    def load_item(self, index):
        """Read one clip and return ``(frame_tensors, mask_tensors)``.

        Frames are opened as RGB and resized to ``self.size``.  A mask is
        generated for every frame of the video and the ones at the sampled
        indices are kept.
        """
        video_name = self.video_names[index]
        all_frames = [os.path.join(video_name, name) for name in sorted(os.listdir(video_name))]
        all_masks = create_random_shape_with_random_motion(
            len(all_frames), imageHeight=self.h, imageWidth=self.w)
        ref_index = get_ref_index(len(all_frames), self.sample_length)
        # read video frames
        frames = []
        masks = []
        for idx in ref_index:
            img = Image.open(all_frames[idx]).convert('RGB')
            img = img.resize(self.size)
            frames.append(img)
            masks.append(all_masks[idx])
        if self.split == 'train':
            # Only the frames are flipped; the masks are left untouched.
            frames = GroupRandomHorizontalFlip()(frames)
        # To tensors
        # NOTE(review): presumably rescales [0, 1] tensors to [-1, 1];
        # confirm against ToTorchFormatTensor in core.utils.
        frame_tensors = self._to_tensors(frames)*2.0 - 1.0
        mask_tensors = self._to_tensors(masks)
        return frame_tensors, mask_tensors


def get_ref_index(length, sample_length):
    """Choose ``sample_length`` ascending frame indices from ``range(length)``.

    One random draw selects the strategy: either distinct indices scattered
    over the whole video, or a run of consecutive indices starting at a
    random position.
    """
    scattered = random.uniform(0, 1) > 0.5
    if scattered:
        return sorted(random.sample(range(length), sample_length))
    start = random.randint(0, length-sample_length)
    return list(range(start, start + sample_length))
47 changes: 47 additions & 0 deletions core/dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
import torch


def get_world_size():
    """Return the world size without calling MPI functions.

    ``PMI_SIZE`` is consulted first, then ``OMPI_COMM_WORLD_SIZE``.  A
    variable that is set but empty counts as 1.  When neither is set, the
    number of visible CUDA devices is returned.

    :rtype: int
    """
    for var_name in ('PMI_SIZE', 'OMPI_COMM_WORLD_SIZE'):
        raw = os.environ.get(var_name)
        if raw is not None:
            return int(raw or 1)
    return torch.cuda.device_count()


def get_global_rank():
    """Return the global rank without calling MPI functions.

    ``PMI_RANK`` is consulted first, then ``OMPI_COMM_WORLD_RANK``.  A
    variable that is set but empty counts as 0, and 0 is also returned when
    neither variable is set.

    :rtype: int
    """
    for var_name in ('PMI_RANK', 'OMPI_COMM_WORLD_RANK'):
        raw = os.environ.get(var_name)
        if raw is not None:
            return int(raw or 0)
    return 0


def get_local_rank():
    """Return the node-local rank without calling MPI functions.

    ``MPI_LOCALRANKID`` is consulted first, then
    ``OMPI_COMM_WORLD_LOCAL_RANK``.  A variable that is set but empty counts
    as 0, and 0 is also returned when neither variable is set.

    :rtype: int
    """
    for var_name in ('MPI_LOCALRANKID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
        raw = os.environ.get(var_name)
        if raw is not None:
            return int(raw or 0)
    return 0


def get_master_ip():
    """Return the master node address taken from the environment.

    ``AZ_BATCH_MASTER_NODE`` wins when set; only the text before its first
    ``:`` is returned.  Otherwise ``AZ_BATCHAI_MPI_MASTER_NODE`` is returned
    unchanged.  With neither variable set the loopback address is used.
    """
    batch_master = os.environ.get('AZ_BATCH_MASTER_NODE')
    if batch_master is not None:
        # Keep only the part before the first ':'.
        return batch_master.partition(':')[0]

    mpi_master = os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
    if mpi_master is not None:
        return mpi_master

    return "127.0.0.1"
40 changes: 40 additions & 0 deletions core/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn

class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337

    Supports three formulations selected by ``type``: ``nsgan``
    (``nn.BCELoss``), ``lsgan`` (``nn.MSELoss``) and ``hinge``.
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge

        ``target_real_label`` / ``target_fake_label`` are the target values
        used by the ``nsgan`` and ``lsgan`` criteria; they are stored as
        buffers so they move with the module across devices.

        Raises ``ValueError`` for any other ``type``.
        """
        super(AdversarialLoss, self).__init__()

        criteria = {
            'nsgan': nn.BCELoss,
            'lsgan': nn.MSELoss,
            'hinge': nn.ReLU,
        }
        # Reject unknown types here; otherwise ``self.criterion`` would be
        # missing and the failure would only surface at the first loss call.
        if type not in criteria:
            raise ValueError(
                'Unsupported adversarial loss type: {!r} '
                '(expected nsgan | lsgan | hinge)'.format(type))

        self.type = type
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.criterion = criteria[type]()

    def forward(self, outputs, is_real, is_disc=None):
        r"""
        Compute the loss for discriminator ``outputs``.

        outputs: tensor of discriminator predictions.
        is_real: whether ``outputs`` should be scored against the real target.
        is_disc: only used by ``hinge``; truthy selects the discriminator
            form ``relu(1 -/+ outputs).mean()``, falsy selects the generator
            form ``(-outputs).mean()``, which ignores ``is_real``.

        Defined as ``forward`` so invocation goes through
        ``nn.Module.__call__``; ``loss(outputs, is_real, is_disc)`` works as
        before.
        """
        if self.type == 'hinge':
            if is_disc:
                if is_real:
                    # relu(1 - D(x)) for real samples, relu(1 + D(x)) for fake.
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            else:
                return (-outputs).mean()
        else:
            labels = (self.real_label if is_real else self.fake_label).expand_as(
                outputs)
            loss = self.criterion(outputs, labels)
            return loss


Loading

0 comments on commit 77aea80

Please sign in to comment.