Skip to content

Commit

Permalink
add youtubevos stationary masks
Browse files Browse the repository at this point in the history
  • Loading branch information
ruiliu-ai committed Oct 30, 2021
1 parent 7f7d099 commit 1b64b70
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ python test.py -c checkpoints/fuseformer.pth -v data/DAVIS/JPEGImages/blackswan
## Evaluation
You can follow [free-form mask generation scheme](https://github.com/JiahuiYu/generative_inpainting) for synthesizing random masks.

Or just download [our prepared stationary masks](https://drive.google.com/file/d/1lV_EZafayBF0QUM7socbKW7HIxxSaoeU/view?usp=sharing) and unzip it to data folder.
Or just download [our prepared stationary masks](https://drive.google.com/file/d/1wihArvScAFT9hs3KDGSoCEjsEjXpfEOL/view?usp=sharing) and unzip it to data folder.
```
mv random_mask_stationary_w432_h240 data/
mv random_mask_stationary_youtube_w432_h240 data/
```

Then you need to download [pre-trained model](https://drive.google.com/file/d/1A-ilDsXZCVhWh2_erApyL7C0jXhaeTjR/view?usp=sharing) for evaluating [VFID](https://github.com/deepmind/kinetics-i3d).
Expand All @@ -63,7 +64,8 @@ mv i3d_rgb_imagenet.pt checkpoints/

### Evaluation script
```
python evaluate.py --model fuseformer --ckpt checkpoints/fuseformer.pth --width 432 --height 240
python evaluate.py --model fuseformer --ckpt checkpoints/fuseformer.pth --dataset davis --width 432 --height 240
python evaluate.py --model fuseformer --ckpt checkpoints/fuseformer.pth --dataset youtubevos --width 432 --height 240
```

## Citing FuseFormer
Expand Down
21 changes: 13 additions & 8 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@
from scipy import linalg


parser = argparse.ArgumentParser(description="STTN")
parser = argparse.ArgumentParser(description="FuseFormer")
parser.add_argument("-v", "--video", type=str, required=False)
parser.add_argument("-m", "--mask", type=str, required=False)
parser.add_argument("-c", "--ckpt", type=str, required=True)
parser.add_argument("--model", type=str, default='sttn')
parser.add_argument("--model", type=str, default='fuseformer')
parser.add_argument("--dataset", type=str, default='davis')
parser.add_argument("--width", type=int, default=432)
parser.add_argument("--height", type=int, default=240)
parser.add_argument("--outw", type=int, default=432)
Expand Down Expand Up @@ -171,17 +172,21 @@ def get_i3d_activations(batched_video, target_endpoint='Logits', flatten=True, g
return feat

def get_frame_mask_list(args):
#data_root = "./data/YouTubeVOS/"
data_root = "./data/DATASET_DAVIS"
mask_dir = "./data/random_mask_stationary_w432_h240"
if args.dataset == 'davis':
data_root = "./data/DATASET_DAVIS"
mask_dir = "./data/random_mask_stationary_w432_h240"
frame_dir = os.path.join(data_root, "JPEGImages", "480p")
elif args.dataset == 'youtubevos':
data_root = "./data/YouTubeVOS/"
mask_dir = "./data/random_mask_stationary_youtube_w432_h240"
frame_dir = os.path.join(data_root, "test_all_frames", "JPEGImages")

mask_folder = sorted(os.listdir(mask_dir))
mask_list = [os.path.join(mask_dir, name) for name in mask_folder]

frame_dir = os.path.join(data_root, "JPEGImages", "480p")
frame_folder = sorted(os.listdir(frame_dir))
frame_list = [os.path.join(frame_dir, name) for name in frame_folder]

print("[Finish building dataset]")
print("[Finish building dataset {}]".format(args.dataset))
return frame_list, mask_list

# sample reference frames from the whole video
Expand Down

0 comments on commit 1b64b70

Please sign in to comment.