diff --git a/README.md b/README.md index a68280a..c8de5bc 100644 --- a/README.md +++ b/README.md @@ -51,9 +51,10 @@ python test.py -c checkpoints/fuseformer.pth -v data/DAVIS/JPEGImages/blackswan ## Evaluation You can follow [free-form mask generation scheme](https://github.com/JiahuiYu/generative_inpainting) for synthesizing random masks. -Or just download [our prepared stationary masks](https://drive.google.com/file/d/1lV_EZafayBF0QUM7socbKW7HIxxSaoeU/view?usp=sharing) and unzip it to data folder. +Or just download [our prepared stationary masks](https://drive.google.com/file/d/1wihArvScAFT9hs3KDGSoCEjsEjXpfEOL/view?usp=sharing) and unzip it to data folder. ``` mv random_mask_stationary_w432_h240 data/ +mv random_mask_stationary_youtube_w432_h240 data/ ``` Then you need to download [pre-trained model](https://drive.google.com/file/d/1A-ilDsXZCVhWh2_erApyL7C0jXhaeTjR/view?usp=sharing) for evaluating [VFID](https://github.com/deepmind/kinetics-i3d). @@ -63,7 +64,8 @@ mv i3d_rgb_imagenet.pt checkpoints/ ### Evaluation script ``` -python evaluate.py --model fuseformer --ckpt checkpoints/fuseformer.pth --width 432 --height 240 +python evaluate.py --model fuseformer --ckpt checkpoints/fuseformer.pth --dataset davis --width 432 --height 240 +python evaluate.py --model fuseformer --ckpt checkpoints/fuseformer.pth --dataset youtubevos --width 432 --height 240 ``` ## Citing FuseFormer diff --git a/evaluate.py b/evaluate.py index 5ff0b3b..5c3f502 100644 --- a/evaluate.py +++ b/evaluate.py @@ -31,11 +31,12 @@ from scipy import linalg -parser = argparse.ArgumentParser(description="STTN") +parser = argparse.ArgumentParser(description="FuseFormer") parser.add_argument("-v", "--video", type=str, required=False) parser.add_argument("-m", "--mask", type=str, required=False) parser.add_argument("-c", "--ckpt", type=str, required=True) -parser.add_argument("--model", type=str, default='sttn') +parser.add_argument("--model", type=str, default='fuseformer') +parser.add_argument("--dataset", type=str, default='davis') parser.add_argument("--width", type=int, default=432) parser.add_argument("--height", type=int, default=240) parser.add_argument("--outw", type=int, default=432) @@ -171,17 +172,21 @@ def get_i3d_activations(batched_video, target_endpoint='Logits', flatten=True, g return feat def get_frame_mask_list(args): - #data_root = "./data/YouTubeVOS/" - data_root = "./data/DATASET_DAVIS" - mask_dir = "./data/random_mask_stationary_w432_h240" + if args.dataset == 'davis': + data_root = "./data/DATASET_DAVIS" + mask_dir = "./data/random_mask_stationary_w432_h240" + frame_dir = os.path.join(data_root, "JPEGImages", "480p") + elif args.dataset == 'youtubevos': + data_root = "./data/YouTubeVOS/" + mask_dir = "./data/random_mask_stationary_youtube_w432_h240" + frame_dir = os.path.join(data_root, "test_all_frames", "JPEGImages") + mask_folder = sorted(os.listdir(mask_dir)) mask_list = [os.path.join(mask_dir, name) for name in mask_folder] - - frame_dir = os.path.join(data_root, "JPEGImages", "480p") frame_folder = sorted(os.listdir(frame_dir)) frame_list = [os.path.join(frame_dir, name) for name in frame_folder] - print("[Finish building dataset]") + print("[Finish building dataset {}]".format(args.dataset)) return frame_list, mask_list # sample reference frames from the whole video