Tidying TAP-Net for open-source release.
PiperOrigin-RevId: 477124237
Change-Id: Ie7ce7a0edb97b2a6ab1153980f17c78288a97903
cdoersch authored and copybara-github committed Sep 27, 2022
1 parent 9cc8e14 commit 6de0c22
Showing 7 changed files with 91 additions and 115 deletions.
3 changes: 1 addition & 2 deletions configs/tapnet_config.py
@@ -14,7 +14,6 @@
# ==============================================================================

"""Default config to train the TapNet."""
-# import jax
from jaxline import base_config
from ml_collections import config_dict

@@ -23,7 +22,7 @@

# We define the experiment launch config in the same file as the experiment to
# keep things self-contained in a single file.
-def get_config() -> config_dict.ConfigDict():  # pytype: disable=invalid-annotation
+def get_config() -> config_dict.ConfigDict:
"""Return config object for training."""
config = base_config.get_base_config()

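Aside on the fix above: `config_dict.ConfigDict()` in a return annotation constructs an instance at import time, which is why the old line carried a pytype suppression; the class itself is what pytype expects. A minimal sketch of the corrected pattern (the `random_seed` assignment is illustrative only, not part of the real config):

import jaxline.base_config as base_config
from ml_collections import config_dict


def get_config() -> config_dict.ConfigDict:
  """Return a config object for training."""
  # Annotating with the class (not ConfigDict()) is valid to pytype,
  # so the `# pytype: disable=invalid-annotation` suppression can go.
  config = base_config.get_base_config()
  config.random_seed = 42  # illustrative field only
  return config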
40 changes: 14 additions & 26 deletions evaluation_datasets.py
@@ -22,7 +22,6 @@
import random
from typing import Iterable, Optional, Mapping, Union

-from absl import flags
from absl import logging
import chex
import jax
@@ -36,12 +35,6 @@

from tapnet import tapnet_model

-from kubric.challenges.point_tracking import dataset
-
-# These three are questionable
-
-FLAGS = flags.FLAGS


DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]

@@ -82,7 +75,6 @@ def sample_and_pad(
pad_extra_frames: the number of pad frames that were added to reach
num_frames.
"""
-query_stride = 5
tracks = []
occs = []
queries = []
@@ -111,7 +103,7 @@ def frame_pad(x):
pads[1] = (0, num_frames - x.shape[1]) # pylint: disable=cell-var-from-loop
return jnp.pad(x, pads, mode='edge')

-converted = {
+return {
'video':
frame_pad(frames[jnp.newaxis, ...]),
'query_points':
@@ -125,12 +117,11 @@ def frame_pad(x):
'pad_extra_frames':
num_frames - target_occluded.shape[1],
}
-return converted


-def create_jhmdb_dataset() -> Iterable[DatasetElement]:
+def create_jhmdb_dataset(jhmdb_path: str) -> Iterable[DatasetElement]:
"""JHMDB dataset, including fields required for PCK evaluation."""
-gt_dir = FLAGS.config.jhmdb_path
+gt_dir = jhmdb_path
videos = []
for file in tf.io.gfile.listdir(path.join(gt_dir, 'splits')):
# JHMDB file containing the first split, which is standard for this type of
@@ -173,9 +164,7 @@ def read_frame(f):

frames = [read_frame(x) for x in framefil]
frames = np.stack(frames)
-num_frames = frames.shape[0]
-height = frames.shape[1]
-width = frames.shape[2]
+num_frames, height, width, _ = frames.shape
invalid_x = np.logical_or(
gt_pose[:, 0:1, 0] < 0,
gt_pose[:, 0:1, 0] >= width,
@@ -188,16 +177,14 @@ def read_frame(f):
invalid = np.tile(invalid, [1, gt_pose.shape[1]])
invalid = invalid[:, :, jnp.newaxis].astype(np.float32)
gt_pose_orig = gt_pose
-gt_pose = gt_pose * (1.0 - invalid) - invalid
+# Set invalid poses to -1 (outside the frame)
+gt_pose = (1. - invalid) * gt_pose + invalid * (-1.)

frames = np.array(
jax.jit(
functools.partial(
jax.image.resize,
-shape=[
-    num_frames, tapnet_model.TRAIN_SIZE[1],
-    tapnet_model.TRAIN_SIZE[2], 3
-],
+shape=[num_frames, *tapnet_model.TRAIN_SIZE[1:4]],
method='bilinear',
))(frames))
frames = frames / (255. / 2.) - 1.
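The rewritten masking line keeps the old arithmetic (`x * (1 - m) - m` equals `(1 - m) * x + m * (-1)` for a 0/1 mask) but states the intent: invalid joints are overwritten with the sentinel -1, which lies outside any frame. A small self-contained sketch with made-up values:

import numpy as np

gt_pose = np.array([[10., 20.], [300., 40.]])  # (x, y) per joint, made-up values
invalid = np.array([[0.], [1.]])               # 1 marks an out-of-frame joint

# Convex-combination form from the diff: valid joints pass through,
# invalid joints collapse to the sentinel -1.
masked = (1. - invalid) * gt_pose + invalid * (-1.)
print(masked)  # [[10. 20.] [-1. -1.]]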
@@ -240,12 +227,12 @@ def create_kubric_eval_train_dataset(
vflip='vflip' in mode,
random_crop=False)

-num_returned = -1
+num_returned = 0

for data in res[0]():
-num_returned += 1
if num_returned >= max_dataset_size:
break
+num_returned += 1
yield {'kubric': data}


@@ -265,9 +252,9 @@ def create_kubric_eval_dataset(mode: str) -> Iterable[DatasetElement]:
yield {'kubric': data}


-def create_davis_dataset() -> Iterable[DatasetElement]:
+def create_davis_dataset(davis_points_path: str) -> Iterable[DatasetElement]:
"""Dataset for evaluating performance on DAVIS data."""
-pickle_path = FLAGS.config.davis_points_path
+pickle_path = davis_points_path

with tf.io.gfile.GFile(pickle_path, 'rb') as f:
davis_points_dataset = pickle.load(f)
@@ -296,9 +283,10 @@ def create_davis_dataset() -> Iterable[DatasetElement]:
yield {'davis': converted}


-def create_rgb_stacking_dataset() -> Iterable[DatasetElement]:
+def create_rgb_stacking_dataset(
+    robotics_points_path: str) -> Iterable[DatasetElement]:
"""Dataset for evaluating performance on robotics data."""
-pickle_path = FLAGS.config.robotics_points_path
+pickle_path = robotics_points_path

with tf.io.gfile.GFile(pickle_path, 'rb') as f:
robotics_points_dataset = pickle.load(f)
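The common thread in this file: the evaluation dataset constructors no longer reach into `FLAGS.config` and instead take the paths they need as explicit arguments, so they can be imported and called without jaxline's flag machinery. A hedged usage sketch (paths are placeholders; the `'video'` key follows the DatasetElement mapping built by sample_and_pad):

from tapnet import evaluation_datasets

davis = evaluation_datasets.create_davis_dataset(
    davis_points_path='/path/to/davis_points.pkl')  # placeholder path
jhmdb = evaluation_datasets.create_jhmdb_dataset(
    jhmdb_path='/path/to/jhmdb')                    # placeholder path

for element in davis:
  video = element['davis']['video']  # padded frames, shape [1, n_frames, h, w, 3]
  break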
26 changes: 14 additions & 12 deletions experiment.py
@@ -29,6 +29,7 @@
from jaxline import experiment
from jaxline import platform
from jaxline import utils
+from kubric.challenges.point_tracking import dataset
from ml_collections import config_dict

import numpy as np
@@ -37,16 +38,12 @@
import tensorflow_datasets as tfds

from tapnet import supervised_point_prediction
-from kubric.challenges.point_tracking import dataset

from tapnet import tapnet_model
from tapnet import task
from tapnet.utils import experiment_utils as exputils


-FLAGS = flags.FLAGS


class Experiment(experiment.AbstractExperiment):
"""TAPNet experiment.
@@ -111,17 +108,24 @@ def __init__(
self._state = None
self._opt_state = None

+self._optimizer = None

# Input pipelines.
self._train_input = None
self._eval_input = None

self.point_prediction = supervised_point_prediction.SupervisedPointPrediction(
config,
**config.supervised_point_prediction_kwargs)

-def forward(*args, **kwargs):
+def forward(*args, is_training=True, **kwargs):
shared_modules = self._construct_shared_modules()
return self.point_prediction.forward_fn(
-*args, shared_modules=shared_modules, **kwargs)
+*args,
+shared_modules=shared_modules,
+is_training=is_training,
+**kwargs,
+)

self._transform = hk.transform_with_state(forward)

@@ -184,8 +188,8 @@ def step(

scalars = utils.get_first(scalars)

-if (global_step % FLAGS.config.evaluate_every) == 0:
-for mode in FLAGS.config.eval_modes:
+if (global_step % self.config.evaluate_every) == 0:
+for mode in self.config.eval_modes:
eval_scalars = self.evaluate(global_step, rng=rng, mode=mode)
scalars.update(eval_scalars)

@@ -274,7 +278,7 @@ def _build_train_input(

def create_dataset_generator(
self,
-dataset_constructors: Mapping[str, Callable[..., tf.Dataset]],
+dataset_constructors: Mapping[str, Callable[..., tf.data.Dataset]],
dset_name: str,
) -> Iterator[Mapping[str, np.ndarray]]:
# Batch data on available devices.
@@ -398,12 +402,10 @@ def evaluate(
wrapped_forward_fn=forward_fn,
mode=mode,
)
-eval_scalars = {
+return {
f'eval/{mode}/{key}': value for key, value in eval_scalars.items()
}

-return eval_scalars


def main(_):
flags.mark_flag_as_required('config')
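The `forward` change above makes `is_training` an explicit keyword rather than something buried in `**kwargs`, so the training/eval distinction survives `hk.transform_with_state`. A minimal sketch of the pattern, using a toy network in place of the TAP-Net shared modules:

import haiku as hk
import jax
import jax.numpy as jnp


def forward(x, is_training=True):
  # Toy stand-in for _construct_shared_modules + point_prediction.forward_fn.
  x = hk.BatchNorm(True, True, 0.9)(x, is_training=is_training)
  return hk.nets.MLP([32, 4])(x)

transform = hk.transform_with_state(forward)
rng = jax.random.PRNGKey(42)
x = jnp.ones([8, 16])
params, state = transform.init(rng, x)  # init defaults to is_training=True
out, state = transform.apply(params, state, rng, x, is_training=False)  # eval mode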
27 changes: 11 additions & 16 deletions models/tsm_resnet.py
@@ -77,7 +77,8 @@ def __init__(
name: The name of the module.
"""
super().__init__(name=name)
-self._output_channels = output_channels if use_bottleneck else output_channels // 4
+self._output_channels = (
+    output_channels if use_bottleneck else output_channels // 4)
self._bottleneck_channels = output_channels // 4
self._stride = stride
self._rate = rate
@@ -118,8 +119,7 @@ def __call__(
with_bias=False,
padding='SAME',
name='shortcut_conv',
-)(
-    preact)
+)(preact)
else:
shortcut = inputs

@@ -139,8 +139,7 @@ def __call__(
with_bias=False,
padding='SAME',
name='conv_0',
-)(
-    preact)
+)(preact)

if self._use_bottleneck:
# Second convolution.
@@ -155,8 +154,7 @@ def __call__(
with_bias=False,
padding='SAME',
name='conv_1',
-)(
-    residual)
+)(residual)

# Third convolution.
if self._normalize_fn is not None:
@@ -169,8 +167,7 @@ def __call__(
with_bias=False,
padding='SAME',
name='conv_2',
-)(
-    residual)
+)(residual)

# NOTE: we do not use block multiplier.
output = shortcut + residual
@@ -239,9 +236,9 @@ def __call__(
for idx_block in range(self._num_blocks):
net = TSMResNetBlock(
self._output_channels,
-stride=self._stride if idx_block == 0 else 1,
-rate=max(self._rate // 2, 1) if idx_block == 0 else self._rate,
-use_projection=idx_block == 0,
+stride=(self._stride if idx_block == 0 else 1),
+rate=(max(self._rate // 2, 1) if idx_block == 0 else self._rate),
+use_projection=(idx_block == 0),
normalize_fn=self._normalize_fn,
tsm_mode=self._tsm_mode,
channel_shift_fraction=self._channel_shift_fraction,
@@ -389,14 +386,12 @@ def __call__(
with_bias=False,
name=end_point,
padding='SAME',
-)(
-    inputs)
+)(inputs)
net = hk.MaxPool(
window_shape=(1, 3, 3, 1),
strides=(1, 2, 2, 1),
padding='SAME',
-)(
-    net)
+)(net)
if self._final_endpoint == end_point:
net = tsmu.prepare_outputs(net, tsm_mode, num_frames, reduce_mean=False)
return net
25 changes: 12 additions & 13 deletions optimizers.py
@@ -14,7 +14,7 @@
# ==============================================================================

"""Optimizer utils."""
-from typing import Callable, List, NamedTuple, Optional, Text
+from typing import Callable, Sequence, NamedTuple, Optional, Text

import haiku as hk
import jax
@@ -25,7 +25,7 @@


def _weight_decay_exclude(
-exclude_names: Optional[List[Text]] = None
+exclude_names: Optional[Sequence[Text]] = None
) -> Callable[[str, str, jnp.ndarray], bool]:
"""Logic for deciding which parameters to include for weight decay..
@@ -34,24 +34,22 @@ def _weight_decay_exclude(
by default.
Returns:
-A predicate that returns True for params that need to be excluded from
+A predicate that returns False for params that need to be excluded from
weight_decay.
"""
# By default weight_decay the weights but not the biases.
-if not exclude_names:
+if exclude_names is None:
exclude_names = ["b"]

-def exclude(module_name: Text, name: Text, value: jnp.array):
+def include(module_name: Text, name: Text, value: jnp.array):
del value
# Do not weight decay the parameters of normalization blocks.
if any([norm_name in module_name for norm_name in NORM_NAMES]):
return False
-elif name not in exclude_names:
-return True
-else:
-return False
+return name not in exclude_names

-return exclude
+return include


class AddWeightDecayState(NamedTuple):
@@ -60,7 +58,8 @@ class AddWeightDecayState(NamedTuple):

def add_weight_decay(
weight_decay: float,
-exclude_names: Optional[List[Text]] = None) -> optax.GradientTransformation:
+exclude_names: Optional[Sequence[Text]] = None
+) -> optax.GradientTransformation:
"""Add parameter scaled by `weight_decay` to the `updates`.
Same as optax.additive_weight_decay but can exclude some parameters.
@@ -78,10 +77,10 @@ def init_fn(_):
return AddWeightDecayState()

def update_fn(updates, state, params):
-exclude = _weight_decay_exclude(exclude_names=exclude_names)
+include = _weight_decay_exclude(exclude_names=exclude_names)

-u_in, u_ex = hk.data_structures.partition(exclude, updates)
-p_in, _ = hk.data_structures.partition(exclude, params)
+u_in, u_ex = hk.data_structures.partition(include, updates)
+p_in, _ = hk.data_structures.partition(include, params)
u_in = jax.tree_map(lambda g, p: g + weight_decay * p, u_in, p_in)
updates = hk.data_structures.merge(u_ex, u_in)
return updates, state
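The exclude-to-include rename matches what the predicate feeds to `hk.data_structures.partition`: the first structure returned holds the entries where the predicate is True, and those are exactly the parameters that receive weight decay. A small self-contained sketch of the mechanics (parameter names are made up):

import haiku as hk
import jax
import jax.numpy as jnp

params = {
    'mlp/linear': {'w': jnp.ones([2, 2]), 'b': jnp.zeros([2])},
    'layer_norm': {'scale': jnp.ones([2]), 'offset': jnp.zeros([2])},
}

def include(module_name, name, value):
  del value
  if 'norm' in module_name:  # normalization params never decay, as with NORM_NAMES
    return False
  return name != 'b'         # biases are excluded by default

decayed, rest = hk.data_structures.partition(include, params)
# Only 'mlp/linear'/'w' lands in `decayed`; apply decay there and merge back.
decayed = jax.tree_map(lambda p: p + 1e-4 * p, decayed)
updates = hk.data_structures.merge(rest, decayed)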
