Skip to content

Commit

Permalink
Add more options to tapir-based clustering algorithm.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 585919590
Change-Id: I0c394ea87d5423ba619fa77f470372319a5d935a
  • Loading branch information
cdoersch committed Nov 28, 2023
1 parent 668610a commit 4cb3543
Showing 1 changed file with 115 additions and 20 deletions.
135 changes: 115 additions & 20 deletions tapir_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,15 @@ class TrainingState(NamedTuple):
step: jax.Array


def make_projection_matrix(pred_mat):
def make_projection_matrix(pred_mat, fourdof=True):
"""Convert predicted projection matrix parameters to a projection matrix."""
pred_mat = einshape('n(coi)->ncoi', pred_mat, o=3, i=4)

# This runs Gram-Schmidt to create an orthonormal matrix from the input 3x3
# matrix that comes from a neural net.
#
# We run gradient clipping on the backward pass because the matrix might be
# badly conditioned.
@jax.custom_vjp
def f(x):
return x
Expand All @@ -55,12 +60,24 @@ def f_bwd(_, g):

f.defvjp(f_fwd, f_bwd)
pred_mat = f(pred_mat)
orth1 = jnp.ones_like(pred_mat[..., 0:1, :-1]) * jnp.array([0.0, 0.0, 1.0])
orth2 = pred_mat[..., 1:2, :-1] * jnp.array([1.0, 1.0, 0.0])
if fourdof:
orth1 = jnp.ones_like(pred_mat[..., 0:1, :-1]) * jnp.array([0.0, 0.0, 1.0])
orth2 = pred_mat[..., 1:2, :-1] * jnp.array([1.0, 1.0, 0.0])
else:
orth1 = pred_mat[..., 0:1,:-1]
orth1 = orth1 / jnp.sqrt(
jnp.maximum(jnp.sum(jnp.square(orth1), axis=-1, keepdims=True), 1e-12)
)
orth2 = pred_mat[..., 1:2, :-1]
orth2 = orth2 - orth1 * jnp.sum(orth2 * orth1, axis=-1, keepdims=True)
orth2 = orth2 / jnp.sqrt(
jnp.maximum(jnp.sum(jnp.square(orth2), axis=-1, keepdims=True), 1e-12)
)
orth3 = pred_mat[..., 2:3, :-1] * jnp.array([1.0, 1.0, 0.0])
orth3 = pred_mat[..., 2:3, :-1]
if fourdof:
orth3 *= jnp.array([1.0, 1.0, 0.0])
else:
orth3 = orth3 - orth1 * jnp.sum(orth3 * orth1, axis=-1, keepdims=True)
orth3 = orth3 - orth2 * jnp.sum(orth3 * orth2, axis=-1, keepdims=True)
orth3 = orth3 / jnp.sqrt(
jnp.maximum(jnp.sum(jnp.square(orth3), axis=-1, keepdims=True), 1e-12)
Expand All @@ -74,7 +91,7 @@ def f_bwd(_, g):
return pred_mat


def project(pred_mat, pos_pred):
def project(pred_mat, pos_pred, cam_focal_length):
"""Project 3D points to 2D, with penalties for depth out-of-range."""
pos_pred = jnp.concatenate(
[pos_pred[..., :3], pos_pred[..., 0:1] * 0 + 1], axis=-1
Expand All @@ -84,7 +101,7 @@ def project(pred_mat, pos_pred):
oob = jnp.maximum(pred_pos[..., 2:3] - 2.0, 0.0) + jnp.maximum(
0.5 - pred_pos[..., 2:3], 0.0
)
all_pred = pred_pos[..., 0:2] / depth
all_pred = pred_pos[..., 0:2] * cam_focal_length / depth
all_pred = (
all_pred
+ 0.1 * jax.random.normal(hk.next_rng_key(), shape=oob.shape) * oob
Expand All @@ -100,6 +117,8 @@ def forward(
num_cats=4,
is_training=True,
sequence_boundaries=tuple(),
fourdof=True,
cam_focal_length=1.0,
):
"""Model forward pass."""

Expand Down Expand Up @@ -213,12 +232,14 @@ def mul(mat):
pred_mat_fork1 = state @ fork1
pred_mat_fork2 = state @ fork2

pred_mat_base = make_projection_matrix(pred_mat_base)[fr_idx]
pred_mat_fork1 = make_projection_matrix(pred_mat_fork1)[fr_idx]
pred_mat_fork2 = make_projection_matrix(pred_mat_fork2)[fr_idx]
pred_mat_base = make_projection_matrix(pred_mat_base, fourdof)[fr_idx]
pred_mat_fork1 = make_projection_matrix(pred_mat_fork1, fourdof)[fr_idx]
pred_mat_fork2 = make_projection_matrix(pred_mat_fork2, fourdof)[fr_idx]

if not is_training:
pred_pos_all, depth_all = project(pred_mat_base, pos_pred_base)
pred_pos_all, depth_all = project(
pred_mat_base, pos_pred_base, cam_focal_length
)
return pred_pos_all, depth_all
else:
return {
Expand All @@ -239,6 +260,9 @@ def loss_fn(
delete_mode=False,
sequence_boundaries=tuple(),
final_num_cats=28,
use_em=False,
fourdof=True,
cam_focal_length=1.0,
):
"""Computes the (scalar) LM loss on `data` w.r.t. params."""
pts, vis, _ = data
Expand All @@ -252,10 +276,11 @@ def loss_fn(
vis,
num_cats=num_cats,
sequence_boundaries=sequence_boundaries,
fourdof=fourdof,
cam_focal_length=cam_focal_length,
)

pts = pts[pts_idx][:, fr_idx]
print(pts.shape)
vis = vis[pts_idx][:, fr_idx]

def do_fork(base, fork1, fork2, i, chunk=1):
Expand All @@ -272,22 +297,35 @@ def do_delete(base, i, chunk=1):

losses = []
sum_vis = jnp.sum(vis)

# The following is the recursive cluster splitting and deleting algorithm:
# for every cluster, we 'split' it, creating 2 new clusters, or delete it.
# We optimize for every candidate cluster to split/delete, and choose
# the split/delete that minimizes the overall error.
if delete_mode:
all_pred, _ = project(fwd['pred_mat_base'], fwd['pos_pred_base'])
all_pred, _ = project(
fwd['pred_mat_base'], fwd['pos_pred_base'], cam_focal_length
)
all_err = get_err(pts, vis, all_pred)
for i in range(fwd['pred_mat_base'].shape[-3]):
err_i = do_delete(all_err, i)
losses.append(loss_internal(err_i, sum_vis))
losses.append(loss_internal(err_i, sum_vis, use_em=use_em))
else:
all_pred_base, _ = project(fwd['pred_mat_base'], fwd['pos_pred_base'])
all_pred_base, _ = project(
fwd['pred_mat_base'], fwd['pos_pred_base'], cam_focal_length
)
all_err_base = get_err(pts, vis, all_pred_base)
all_pred_fork1, _ = project(fwd['pred_mat_fork1'], fwd['pos_pred_fork1'])
all_pred_fork1, _ = project(
fwd['pred_mat_fork1'], fwd['pos_pred_fork1'], cam_focal_length
)
all_err_fork1 = get_err(pts, vis, all_pred_fork1)
all_pred_fork2, _ = project(fwd['pred_mat_fork2'], fwd['pos_pred_fork2'])
all_pred_fork2, _ = project(
fwd['pred_mat_fork2'], fwd['pos_pred_fork2'], cam_focal_length
)
all_err_fork2 = get_err(pts, vis, all_pred_fork2)
for i in range(fwd['pred_mat_base'].shape[-3]):
err_i = do_fork(all_err_base, all_err_fork1, all_err_fork2, i)
losses.append(loss_internal(err_i, sum_vis))
losses.append(loss_internal(err_i, sum_vis, use_em=use_em))
if delete_mode:
topk, _ = jax.lax.top_k(-jnp.array(losses), num_cats - final_num_cats + 3)
accum_loss = jnp.mean(-topk)
Expand All @@ -308,10 +346,41 @@ def get_err(pts, vis, all_pred):
return jnp.sum(tmp, axis=1)


def loss_internal(err_summed, sum_vis):
min_loss = jnp.sum(jnp.min(err_summed, axis=1)) / sum_vis
def loss_internal(err_summed, sum_vis, use_em, em_variance=0.0001):
"""Computes cluster assignments and loss given per-cluster error."""
if use_em:
# In typical EM for gaussian mixture models, you keep the estimates of the
# prior probabilities for each mixture component (often called pi) across
# iterations. We could in principle do it that way for this code, but
# it's hard to say what we should do with these values for the
# 'splitting' and 'deleting' steps of the algorithm. Therefore, it's
# simpler to just estimate them on-the-fly based on the cluster
# membership probabilities. This needs to be done iteratively,
# but it converges extremely fast to something that's good enough.
err_normalized = err_summed - jnp.min(err_summed, axis=1, keepdims=True)
err_exp = jnp.exp(-err_normalized / em_variance)
wts = jnp.zeros([1, err_exp.shape[1]]) + 1.0 / err_exp.shape[1]
for _ in range(3):
wts = err_exp * wts / jnp.sum(err_exp * wts, axis=1, keepdims=True)
wts = jnp.sum(wts, axis=0, keepdims=True)
wts = jnp.maximum(wts, 1e-8)
wts = wts / jnp.sum(wts)

min_loss = (
-jnp.sum(
jax.scipy.special.logsumexp(
-err_summed / em_variance, b=wts, axis=1
)
)
/ sum_vis
* em_variance
)

return min_loss
else:
min_loss = jnp.sum(jnp.min(err_summed, axis=1)) / sum_vis

return min_loss
return min_loss


def loss_fn_wrapper(*args, **kwargs):
Expand All @@ -328,6 +397,9 @@ def update(
sequence_boundaries=tuple(),
optimiser=None,
final_num_cats=28,
use_em=False,
fourdof=True,
cam_focal_length=1.0,
):
"""Does an SGD step and returns metrics."""
rng, new_rng = jax.random.split(state.rng)
Expand All @@ -338,6 +410,9 @@ def update(
delete_mode=delete_mode,
sequence_boundaries=sequence_boundaries,
final_num_cats=final_num_cats,
use_em=use_em,
fourdof=fourdof,
cam_focal_length=cam_focal_length,
),
has_aux=True,
)
Expand Down Expand Up @@ -431,6 +506,9 @@ def compute_clusters(
final_num_cats=15,
max_num_cats=25,
low_visibility_threshold=0.1,
use_em=False,
fourdof=True,
cam_focal_length=1.0,
):
"""Compute clustering.
Expand All @@ -450,6 +528,20 @@ def compute_clusters(
beginning to merge.
low_visibility_threshold: throw out tracks with less than this fraction of
visible frames.
use_em: if True, use an EM-style soft cluster assignment. Not used in
RoboTAP, but empirically it can prevent the optimization from getting
stuck in local minima.
fourdof: if True (default), restrict the 3D transformations between frames
to be four degrees of freedom (i.e. depth, 2D translation, in-plane
rotation). Otherwise allow for full 6-degree-of-freedom transformations
between frames for objects. Note that 6DOF is likely to result in
objects being merged, because the model can use 3D rotation to explain
different 2D translations.
cam_focal_length: Camera focal length. Camera projection matrix is assumed
to have the form diag([f, f, 1.0]) @ [R,t] where R and t are the learned
rotation matrix and translation vector and f is camera_focal_length. The
optimization is typically not very sensitive to this; we used 1.0 for
RoboTAP, which is not correct for our cameras.
Returns:
A dict, where low-visibility points have been removed. "classes" is
Expand Down Expand Up @@ -641,6 +733,9 @@ def fork_dict(param_dict, noise=0.0, mul=1.0):
sequence_boundaries=sequence_boundaries,
optimiser=optimiser,
final_num_cats=final_num_cats,
use_em=use_em,
fourdof=fourdof,
cam_focal_length=cam_focal_length,
)
)
need_compile = False
Expand Down

0 comments on commit 4cb3543

Please sign in to comment.