
Commit 4438dde
Make things work internally and recover performance
PiperOrigin-RevId: 480146309
Change-Id: I981b8b8d2ebc2646f83505457e3eb86afd757389
cdoersch authored and copybara-github committed Oct 10, 2022
1 parent 8a7e4d0 commit 4438dde
Showing 5 changed files with 185 additions and 118 deletions.
167 changes: 104 additions & 63 deletions evaluation_datasets.py
@@ -17,11 +17,11 @@

import csv
import functools
import glob
import os
from os import path
import pickle
import random
from typing import Iterable, Optional, Mapping, Union
from typing import Iterable, Mapping, Union

from absl import logging
import chex
@@ -41,17 +41,18 @@
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]


def sample_and_pad(
# TODO(doersch): can we remove the jax dependency?
def sample_queries_strided(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
query_stride: int = 5,
num_frames: Optional[int] = None,
) -> Mapping[str, chex.Array]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, sample queries.
Optionally, pad the sequences by replicating the final frame.
Given a set of frames and tracks with no query points, sample queries
strided every query_stride frames, ignoring points that are not visible
at the selected frames.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
@@ -62,27 +63,21 @@ def sample_and_pad(
-1 and 1.
query_stride: When sampling query points, search for un-occluded points
every query_stride frames and convert each one into a query.
num_frames: If specified, pad the videos to this length by duplicating the
last frame, including the positions and occlusions.
Returns:
A dict with the keys:
video: Video tensor of shape [1, num_frames, height, width, 3]
video: Video tensor of shape [1, n_frames, height, width, 3].
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1]
each point is [t, y, x] scaled to the range [-1, 1].
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1]
each point is [x, y] scaled to the range [-1, 1].
trackgroup: Index of the original track that each query point was
sampled from. This is useful for visualization
pad_extra_frames: the number of pad frames that were added to reach
num_frames.
sampled from. This is useful for visualization.
"""
tracks = []
occs = []
queries = []
trackgroups = []
if num_frames is None:
num_frames = target_occluded.shape[1]
total = 0
trackgroup = np.arange(target_occluded.shape[0])
for i in range(0, target_occluded.shape[1], query_stride):
@@ -100,24 +95,63 @@ def sample_and_pad(
trackgroups.append(trackgroup[mask])
total += np.array(jnp.sum(target_occluded[:, i] == 0))

def frame_pad(x):
pads = [(0, 0)] * x.ndim
pads[1] = (0, num_frames - x.shape[1]) # pylint: disable=cell-var-from-loop
return jnp.pad(x, pads, mode='edge')

return {
'video':
frame_pad(frames[jnp.newaxis, ...]),
frames[jnp.newaxis, ...],
'query_points':
jnp.concatenate(queries, axis=0)[jnp.newaxis, :, ...],
jnp.concatenate(queries, axis=0)[jnp.newaxis, ...],
'target_points':
frame_pad(jnp.concatenate(tracks, axis=0))[jnp.newaxis, ...],
jnp.concatenate(tracks, axis=0)[jnp.newaxis, ...],
'occluded':
frame_pad(jnp.concatenate(occs, axis=0))[jnp.newaxis, ...],
jnp.concatenate(occs, axis=0)[jnp.newaxis, ...],
'trackgroup':
jnp.concatenate(trackgroups, axis=0)[jnp.newaxis, :, ...],
'pad_extra_frames':
num_frames - target_occluded.shape[1],
jnp.concatenate(trackgroups, axis=0)[jnp.newaxis, ...],
}


def sample_queries_first(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
) -> Mapping[str, chex.Array]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, use the first
visible point in each track as the query.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1]
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1]
"""

valid = np.sum(~target_occluded, axis=1) > 0
target_points = target_points[valid, :]
target_occluded = target_occluded[valid, :]

query_points = []
for i in range(target_points.shape[0]):
index = np.where(target_occluded[i] == 0)[0][0]
x, y = target_points[i, index, 0], target_points[i, index, 1]
query_points.append(np.array([index, y, x])) # [t, y, x]
query_points = np.stack(query_points, axis=0)

return {
'video': frames[np.newaxis, ...],
'query_points': query_points[np.newaxis, ...],
'target_points': target_points[np.newaxis, ...],
'occluded': target_occluded[np.newaxis, ...],
}
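
For context, a minimal usage sketch of the two samplers above, run alongside the functions in this file (or imported from it). The toy occlusion pattern and point values are invented for illustration; shapes follow the docstrings:

import numpy as np

# Toy setup: 3 tracks over 10 frames, shapes as in the docstrings above.
n_tracks, n_frames = 3, 10
target_occluded = np.zeros([n_tracks, n_frames], dtype=bool)
target_occluded[0, :5] = True  # track 0 is hidden for the first 5 frames
target_points = np.random.rand(n_tracks, n_frames, 2)   # [x, y] in [0, 1]
frames = np.zeros([n_frames, 64, 64, 3], np.float32)    # scaled to [-1, 1]

strided = sample_queries_strided(target_occluded, target_points, frames)
first = sample_queries_first(target_occluded, target_points, frames)

# Strided sampling emits one query per track per visible strided frame
# (here frames 0 and 5), so it generally yields more queries than taking
# only the first visible point of each track.
print(strided['query_points'].shape)  # (1, 5, 3) per the docstring
print(first['query_points'].shape)    # (1, 3, 3)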


@@ -132,7 +166,7 @@ def create_jhmdb_dataset(jhmdb_path: str) -> Iterable[DatasetElement]:
continue

video_folder = '_'.join(file.split('_')[:-2])
for video in open(path.join(gt_dir, 'splits', file), 'r'):
for video in tf.io.gfile.GFile(path.join(gt_dir, 'splits', file), 'r'):
video, traintest = video.split()
video, _ = video.split('.')

@@ -153,14 +187,14 @@ def create_jhmdb_dataset(jhmdb_path: str) -> Iterable[DatasetElement]:
logging.info('skip %s', video)
continue

gt_pose = io.loadmat(open(joints, 'rb'))['pos_img']
gt_pose = io.loadmat(tf.io.gfile.GFile(joints, 'rb'))['pos_img']
gt_pose = np.transpose(gt_pose, [1, 2, 0])
frames = path.join(gt_dir, 'Rename_Images', video, '*.png')
framefil = glob.glob(frames)
framefil = tf.io.gfile.glob(frames)
framefil.sort()

def read_frame(f):
im = Image.open(open(f, 'rb'))
im = Image.open(tf.io.gfile.GFile(f, 'rb'))
im = im.convert('RGB')
return np.array(im.getdata()).reshape([im.size[1], im.size[0], 3])

@@ -260,7 +294,9 @@ def create_kubric_eval_dataset(mode: str) -> Iterable[DatasetElement]:
yield {'kubric': data}


def create_davis_dataset(davis_points_path: str) -> Iterable[DatasetElement]:
def create_davis_dataset(
davis_points_path: str,
query_mode: str = 'strided') -> Iterable[DatasetElement]:
"""Dataset for evaluating performance on DAVIS data."""
pickle_path = davis_points_path

@@ -285,14 +321,24 @@ def create_davis_dataset(davis_points_path: str) -> Iterable[DatasetElement]:
frames = frames.astype(np.float32) / 255. * 2. - 1.
target_points = davis_points_dataset[video_name]['points']
target_occ = davis_points_dataset[video_name]['occluded']
target_points *= np.array(
[tapnet_model.TRAIN_SIZE[2], tapnet_model.TRAIN_SIZE[1]])
converted = sample_and_pad(target_occ, target_points, frames)
target_points *= np.array([
tapnet_model.TRAIN_SIZE[2],
tapnet_model.TRAIN_SIZE[1],
])

if query_mode == 'strided':
converted = sample_queries_strided(target_occ, target_points, frames)
elif query_mode == 'first':
converted = sample_queries_first(target_occ, target_points, frames)
else:
raise ValueError(f'Unknown query mode {query_mode}.')

yield {'davis': converted}


def create_rgb_stacking_dataset(
robotics_points_path: str) -> Iterable[DatasetElement]:
robotics_points_path: str,
query_mode: str = 'strided') -> Iterable[DatasetElement]:
"""Dataset for evaluating performance on robotics data."""
pickle_path = robotics_points_path

@@ -306,26 +352,25 @@ def create_rgb_stacking_dataset(
target_occ = example['occluded']
target_points *= np.array(
[tapnet_model.TRAIN_SIZE[2], tapnet_model.TRAIN_SIZE[1]])
# Take the query points from the first frame
query_points = target_points[:, 0, ::-1]
query_points = np.concatenate(
[np.zeros_like(query_points[:, 0:1]), query_points], axis=-1)

converted = {
'video': frames[np.newaxis, ...],
'query_points': query_points[np.newaxis, ...],
'target_points': target_points[np.newaxis, ...],
'occluded': target_occ[np.newaxis, ...],
}
if query_mode == 'strided':
converted = sample_queries_strided(target_occ, target_points, frames)
elif query_mode == 'first':
converted = sample_queries_first(target_occ, target_points, frames)
else:
raise ValueError(f'Unknown query mode {query_mode}.')

yield {'robotics': converted}


def create_kinetics_dataset(kinetics_path: str) -> Iterable[DatasetElement]:
def create_kinetics_dataset(
kinetics_path: str,
query_mode: str = 'strided') -> Iterable[DatasetElement]:
"""Kinetics point tracking dataset."""
csv_path = path.join(kinetics_path, 'tapvid_kinetics.csv')

point_tracks_all = dict()
with tf.io.gfile.Open(csv_path) as f:
with tf.io.gfile.GFile(csv_path) as f:
reader = csv.reader(f, delimiter=',')
for row in reader:
youtube_id = row[0]
@@ -343,23 +388,19 @@ def create_kinetics_dataset(kinetics_path: str) -> Iterable[DatasetElement]:

point_tracks = np.stack(point_tracks_all[video_id], axis=0)
point_tracks = point_tracks.astype(np.float32)
if frames.shape[0] < point_tracks.shape[1]:
logging.info('Warning: short video!')
point_tracks = point_tracks[:, :frames.shape[0]]
point_tracks, occluded = point_tracks[..., 0:2], point_tracks[..., 2]
occluded = occluded > 0
target_points = point_tracks * np.array(
[tapnet_model.TRAIN_SIZE[2], tapnet_model.TRAIN_SIZE[1]])

# Find the query points from the first time when a point is visible.
query_points = []
for i in range(point_tracks.shape[0]):
index = np.where(occluded[i] == 0)[0][0]
x, y = target_points[i, index, 0], target_points[i, index, 1]
query_points.append(np.array([index, y, x])) # [t, y, x]
query_points = np.stack(query_points, axis=0)

converted = {
'video': frames[np.newaxis, ...],
'query_points': query_points[np.newaxis, ...],
'target_points': target_points[np.newaxis, ...],
'occluded': occluded[np.newaxis, ...],
}
if query_mode == 'strided':
converted = sample_queries_strided(occluded, target_points, frames)
elif query_mode == 'first':
converted = sample_queries_first(occluded, target_points, frames)
else:
raise ValueError(f'Unknown query mode {query_mode}.')

yield {'kinetics': converted}
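
All three evaluation constructors (DAVIS, RGB-stacking, Kinetics) now take the same query_mode switch and raise a ValueError for any other value. A usage sketch; the pickle path is a placeholder, not a real location:

# Placeholder path: substitute the real TAP-Vid DAVIS pickle file.
dataset = create_davis_dataset('/data/tapvid_davis.pkl', query_mode='first')
for element in dataset:
  data = element['davis']
  video = data['video']                 # [1, n_frames, height, width, 3]
  query_points = data['query_points']   # [1, n_queries, 3], each [t, y, x]
  break  # one video is enough for this sketch
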
3 changes: 2 additions & 1 deletion experiment.py
@@ -38,7 +38,6 @@
import tensorflow_datasets as tfds

from tapnet import supervised_point_prediction

from tapnet import tapnet_model
from tapnet import task
from tapnet.utils import experiment_utils as exputils
@@ -295,6 +294,7 @@ def create_dataset_generator(
ds = dataset_constructors[dset_name](**dset_kwargs)
if color_augmentation:
ds = exputils.add_default_data_augmentation(ds)

for dim in batch_dims[::-1]:
ds = ds.batch(dim)
np_ds = tfds.as_numpy(ds)
@@ -417,6 +417,7 @@ def main(_):
# Keep TF off the GPU; otherwise it hogs all the memory and leaves none for
# JAX.
tf.config.experimental.set_visible_devices([], 'GPU')
tf.config.experimental.set_visible_devices([], 'TPU')
platform.main(
Experiment,
sys.argv[1:],
15 changes: 8 additions & 7 deletions models/tsm_resnet.py
@@ -21,7 +21,7 @@
https://arxiv.org/pdf/1811.08383.pdf.
"""

from typing import Optional
from typing import Optional, Sequence, Union
from absl import logging

import chex
@@ -267,7 +267,7 @@ def __init__(
normalize_fn: Optional[NormalizeFn] = None,
depth: int = 18,
num_frames: int = 16,
channel_shift_fraction: float = 0.125,
channel_shift_fraction: Union[float, Sequence[float]] = 0.125,
width_mult: int = 1,
name: str = 'TSMResNetV2',
):
@@ -287,9 +287,12 @@
"""
super().__init__(name=name)

if not 0. <= channel_shift_fraction <= 1.0:
if isinstance(channel_shift_fraction, float):
channel_shift_fraction = [channel_shift_fraction] * 4

if not all([0. <= x <= 1.0 for x in channel_shift_fraction]):
raise ValueError(f'channel_shift_fraction ({channel_shift_fraction})'
' has to be in [0, 1].')
' all have to be in [0, 1].')

self._num_frames = num_frames

@@ -309,9 +312,7 @@ def __init__(
self._num_blocks = num_blocks[depth]

self._width_mult = width_mult
self._channel_shift_fraction = [
channel_shift_fraction, channel_shift_fraction, 0, 0,
]
self._channel_shift_fraction = channel_shift_fraction
self._normalize_fn = normalize_fn
self._use_bottleneck = (depth >= 50)
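
With this change channel_shift_fraction accepts either a single float (now applied to all four ResNet stages) or one value per stage. A minimal construction sketch, assuming the module is built inside Haiku's transform, that the import path follows the file location, and that the video tensor is the only required call argument:

import haiku as hk

from tapnet.models import tsm_resnet

def forward(video):
  # [0.125, 0.125, 0., 0.] reproduces the previously hard-coded behaviour
  # (temporal shift in the first two stages only); a plain float would now
  # be applied to every stage instead.
  net = tsm_resnet.TSMResNetV2(
      depth=18,
      num_frames=16,
      channel_shift_fraction=[0.125, 0.125, 0., 0.])
  return net(video)  # assumes remaining call arguments keep their defaults

model = hk.transform_with_state(forward)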

