diff --git a/tapir_clustering.py b/tapir_clustering.py index 2a9dc97..c80f93e 100644 --- a/tapir_clustering.py +++ b/tapir_clustering.py @@ -39,10 +39,15 @@ class TrainingState(NamedTuple): step: jax.Array -def make_projection_matrix(pred_mat): +def make_projection_matrix(pred_mat, fourdof=True): """Convert predicted projection matrix parameters to a projection matrix.""" pred_mat = einshape('n(coi)->ncoi', pred_mat, o=3, i=4) + # This runs Gram-Schmidt to create an orthonormal matrix from the input 3x3 + # matrix that comes from a neural net. + # + # We run gradient clipping on the backward pass because the matrix might be + # badly conditioned. @jax.custom_vjp def f(x): return x @@ -55,12 +60,24 @@ def f_bwd(_, g): f.defvjp(f_fwd, f_bwd) pred_mat = f(pred_mat) - orth1 = jnp.ones_like(pred_mat[..., 0:1, :-1]) * jnp.array([0.0, 0.0, 1.0]) - orth2 = pred_mat[..., 1:2, :-1] * jnp.array([1.0, 1.0, 0.0]) + if fourdof: + orth1 = jnp.ones_like(pred_mat[..., 0:1, :-1]) * jnp.array([0.0, 0.0, 1.0]) + orth2 = pred_mat[..., 1:2, :-1] * jnp.array([1.0, 1.0, 0.0]) + else: + orth1 = pred_mat[..., 0:1,:-1] + orth1 = orth1 / jnp.sqrt( + jnp.maximum(jnp.sum(jnp.square(orth1), axis=-1, keepdims=True), 1e-12) + ) + orth2 = pred_mat[..., 1:2, :-1] + orth2 = orth2 - orth1 * jnp.sum(orth2 * orth1, axis=-1, keepdims=True) orth2 = orth2 / jnp.sqrt( jnp.maximum(jnp.sum(jnp.square(orth2), axis=-1, keepdims=True), 1e-12) ) - orth3 = pred_mat[..., 2:3, :-1] * jnp.array([1.0, 1.0, 0.0]) + orth3 = pred_mat[..., 2:3, :-1] + if fourdof: + orth3 *= jnp.array([1.0, 1.0, 0.0]) + else: + orth3 = orth3 - orth1 * jnp.sum(orth3 * orth1, axis=-1, keepdims=True) orth3 = orth3 - orth2 * jnp.sum(orth3 * orth2, axis=-1, keepdims=True) orth3 = orth3 / jnp.sqrt( jnp.maximum(jnp.sum(jnp.square(orth3), axis=-1, keepdims=True), 1e-12) @@ -74,7 +91,7 @@ def f_bwd(_, g): return pred_mat -def project(pred_mat, pos_pred): +def project(pred_mat, pos_pred, cam_focal_length): """Project 3D points to 2D, with penalties 
for depth out-of-range.""" pos_pred = jnp.concatenate( [pos_pred[..., :3], pos_pred[..., 0:1] * 0 + 1], axis=-1 @@ -84,7 +101,7 @@ def project(pred_mat, pos_pred): oob = jnp.maximum(pred_pos[..., 2:3] - 2.0, 0.0) + jnp.maximum( 0.5 - pred_pos[..., 2:3], 0.0 ) - all_pred = pred_pos[..., 0:2] / depth + all_pred = pred_pos[..., 0:2] * cam_focal_length / depth all_pred = ( all_pred + 0.1 * jax.random.normal(hk.next_rng_key(), shape=oob.shape) * oob @@ -100,6 +117,8 @@ def forward( num_cats=4, is_training=True, sequence_boundaries=tuple(), + fourdof=True, + cam_focal_length=1.0, ): """Model forward pass.""" @@ -213,12 +232,14 @@ def mul(mat): pred_mat_fork1 = state @ fork1 pred_mat_fork2 = state @ fork2 - pred_mat_base = make_projection_matrix(pred_mat_base)[fr_idx] - pred_mat_fork1 = make_projection_matrix(pred_mat_fork1)[fr_idx] - pred_mat_fork2 = make_projection_matrix(pred_mat_fork2)[fr_idx] + pred_mat_base = make_projection_matrix(pred_mat_base, fourdof)[fr_idx] + pred_mat_fork1 = make_projection_matrix(pred_mat_fork1, fourdof)[fr_idx] + pred_mat_fork2 = make_projection_matrix(pred_mat_fork2, fourdof)[fr_idx] if not is_training: - pred_pos_all, depth_all = project(pred_mat_base, pos_pred_base) + pred_pos_all, depth_all = project( + pred_mat_base, pos_pred_base, cam_focal_length + ) return pred_pos_all, depth_all else: return { @@ -239,6 +260,9 @@ def loss_fn( delete_mode=False, sequence_boundaries=tuple(), final_num_cats=28, + use_em=False, + fourdof=True, + cam_focal_length=1.0, ): """Computes the (scalar) LM loss on `data` w.r.t. 
params.""" pts, vis, _ = data @@ -252,10 +276,11 @@ def loss_fn( vis, num_cats=num_cats, sequence_boundaries=sequence_boundaries, + fourdof=fourdof, + cam_focal_length=cam_focal_length, ) pts = pts[pts_idx][:, fr_idx] - print(pts.shape) vis = vis[pts_idx][:, fr_idx] def do_fork(base, fork1, fork2, i, chunk=1): @@ -272,22 +297,35 @@ def do_delete(base, i, chunk=1): losses = [] sum_vis = jnp.sum(vis) + + # The following is the recursive cluster splitting and deleting algorithm: + # for every cluster, we 'split' it, creating 2 new clusters, or delete it. + # We optimize for every candidate cluster to split/delete, and choose + # the split/delete that minimizes the overall error. if delete_mode: - all_pred, _ = project(fwd['pred_mat_base'], fwd['pos_pred_base']) + all_pred, _ = project( + fwd['pred_mat_base'], fwd['pos_pred_base'], cam_focal_length + ) all_err = get_err(pts, vis, all_pred) for i in range(fwd['pred_mat_base'].shape[-3]): err_i = do_delete(all_err, i) - losses.append(loss_internal(err_i, sum_vis)) + losses.append(loss_internal(err_i, sum_vis, use_em=use_em)) else: - all_pred_base, _ = project(fwd['pred_mat_base'], fwd['pos_pred_base']) + all_pred_base, _ = project( + fwd['pred_mat_base'], fwd['pos_pred_base'], cam_focal_length + ) all_err_base = get_err(pts, vis, all_pred_base) - all_pred_fork1, _ = project(fwd['pred_mat_fork1'], fwd['pos_pred_fork1']) + all_pred_fork1, _ = project( + fwd['pred_mat_fork1'], fwd['pos_pred_fork1'], cam_focal_length + ) all_err_fork1 = get_err(pts, vis, all_pred_fork1) - all_pred_fork2, _ = project(fwd['pred_mat_fork2'], fwd['pos_pred_fork2']) + all_pred_fork2, _ = project( + fwd['pred_mat_fork2'], fwd['pos_pred_fork2'], cam_focal_length + ) all_err_fork2 = get_err(pts, vis, all_pred_fork2) for i in range(fwd['pred_mat_base'].shape[-3]): err_i = do_fork(all_err_base, all_err_fork1, all_err_fork2, i) - losses.append(loss_internal(err_i, sum_vis)) + losses.append(loss_internal(err_i, sum_vis, use_em=use_em)) if delete_mode: 
topk, _ = jax.lax.top_k(-jnp.array(losses), num_cats - final_num_cats + 3) accum_loss = jnp.mean(-topk) @@ -308,10 +346,41 @@ def get_err(pts, vis, all_pred): return jnp.sum(tmp, axis=1) -def loss_internal(err_summed, sum_vis): - min_loss = jnp.sum(jnp.min(err_summed, axis=1)) / sum_vis +def loss_internal(err_summed, sum_vis, use_em, em_variance=0.0001): + """Computes cluster assignments and loss given per-cluster error.""" + if use_em: + # In typical EM for gaussian mixture models, you keep the estimates of the + # prior probabilities for each mixture component (often called pi) across + # iterations. We could in principle do it that way for this code, but + # it's hard to say what we should do with these values for the + # 'splitting' and 'deleting' steps of the algorithm. Therefore, it's + # simpler to just estimate them on-the-fly based on the cluster + # membership probabilities. This needs to be done iteratively, + # but it converges extremely fast to something that's good enough. 
+ err_normalized = err_summed - jnp.min(err_summed, axis=1, keepdims=True) + err_exp = jnp.exp(-err_normalized / em_variance) + wts = jnp.zeros([1, err_exp.shape[1]]) + 1.0 / err_exp.shape[1] + for _ in range(3): + wts = err_exp * wts / jnp.sum(err_exp * wts, axis=1, keepdims=True) + wts = jnp.sum(wts, axis=0, keepdims=True) + wts = jnp.maximum(wts, 1e-8) + wts = wts / jnp.sum(wts) + + min_loss = ( + -jnp.sum( + jax.scipy.special.logsumexp( + -err_summed / em_variance, b=wts, axis=1 + ) + ) + / sum_vis + * em_variance + ) + + return min_loss + else: + min_loss = jnp.sum(jnp.min(err_summed, axis=1)) / sum_vis - return min_loss + return min_loss def loss_fn_wrapper(*args, **kwargs): @@ -328,6 +397,9 @@ def update( sequence_boundaries=tuple(), optimiser=None, final_num_cats=28, + use_em=False, + fourdof=True, + cam_focal_length=1.0, ): """Does an SGD step and returns metrics.""" rng, new_rng = jax.random.split(state.rng) @@ -338,6 +410,9 @@ def update( delete_mode=delete_mode, sequence_boundaries=sequence_boundaries, final_num_cats=final_num_cats, + use_em=use_em, + fourdof=fourdof, + cam_focal_length=cam_focal_length, ), has_aux=True, ) @@ -431,6 +506,9 @@ def compute_clusters( final_num_cats=15, max_num_cats=25, low_visibility_threshold=0.1, + use_em=False, + fourdof=True, + cam_focal_length=1.0, ): """Compute clustering. @@ -450,6 +528,20 @@ def compute_clusters( beginning to merge. low_visibility_threshold: throw out tracks with less than this fraction of visible frames. + use_em: if True, use an EM-style soft cluster assignment. Not used in + RoboTAP, but empirically it can prevent the optimization from getting + stuck in local minima. + fourdof: if True (default), restrict the 3D transformations between frames + to be four degrees of freedom (i.e. depth, 2D translation, in-plane + rotation). Otherwise allow for full 6-degree-of-freedom transformations + between frames for objects. 
Note that 6DOF is likely to result in + objects being merged, because the model can use 3D rotation to explain + different 2D translations. + cam_focal_length: Camera focal length. Camera projection matrix is assumed + to have the form diag([f, f, 1.0]) @ [R,t] where R and t are the learned + rotation matrix and translation vector and f is cam_focal_length. The + optimization is typically not very sensitive to this; we used 1.0 for + RoboTAP, which is not correct for our cameras. Returns: A dict, where low-visibility points have been removed. "classes" is @@ -641,6 +733,9 @@ def fork_dict(param_dict, noise=0.0, mul=1.0): sequence_boundaries=sequence_boundaries, optimiser=optimiser, final_num_cats=final_num_cats, + use_em=use_em, + fourdof=fourdof, + cam_focal_length=cam_focal_length, ) ) need_compile = False