forked from google-deepmind/tapnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tapnet_config.py
125 lines (115 loc) · 4.7 KB
/
tapnet_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Default config to train the TapNet."""
from jaxline import base_config
from ml_collections import config_dict
# We define the experiment launch config in the same file as the experiment to
# keep things self-contained in a single file.
def get_config() -> config_dict.ConfigDict:
  """Returns the default jaxline config for training TapNet.

  Top-level fields (`training_steps`, `shared_module_names`, `dataset_names`,
  `eval_modes`, `checkpoint_dir`, `evaluate_every`) are forwarded into
  `experiment_kwargs.config` via `get_oneway_ref`, so overriding or sweeping
  the top-level value also updates the nested copy that the experiment reads.

  Returns:
    A locked `ml_collections.ConfigDict` built on top of jaxline's base
    config. Because it is locked, assigning a key that does not already exist
    raises an error instead of silently adding it.
  """
  config = base_config.get_base_config()

  # Experiment config.
  config.training_steps = 100000

  # NOTE: duplicates not allowed.
  config.shared_module_names = ('tapnet_model',)

  config.dataset_names = ('kubric',)
  # Note: eval modes must always start with 'eval_'.
  # NOTE(review): 'eval_kinetics_points' has no matching `*_path` field in
  # experiment_kwargs.config below (unlike DAVIS, JHMDB and robotics) —
  # confirm how the experiment locates that dataset.
  config.eval_modes = (
      'eval_davis_points',
      'eval_jhmdb',
      'eval_robotics_points',
      'eval_kinetics_points',
  )
  config.checkpoint_dir = '/tmp/tapnet_training/'
  config.evaluate_every = 10000

  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              sweep_name='default_sweep',
              save_final_checkpoint_as_npy=True,
              optimizer=dict(
                  base_lr=2e-3,
                  max_norm=-1,  # < 0 to turn off.
                  weight_decay=1e-2,
                  schedule_type='cosine',
                  cosine_decay_kwargs=dict(
                      init_value=0.0,
                      warmup_steps=5000,
                      end_value=0.0,
                  ),
                  optimizer='adam',
                  # Optimizer-specific kwargs.
                  adam_kwargs=dict(
                      b1=0.9,
                      b2=0.95,
                      eps=1e-8,
                  ),
              ),
              fast_variables=tuple(),
              shared_modules=dict(
                  shared_module_names=config.get_oneway_ref(
                      'shared_module_names',
                  ),
                  tapnet_model_kwargs=dict(),
              ),
              datasets=dict(
                  dataset_names=config.get_oneway_ref('dataset_names'),
                  kubric_kwargs=dict(
                      batch_dims=8,
                      shuffle_buffer_size=128,
                      train_size=(256, 256),
                  ),
              ),
              supervised_point_prediction_kwargs=dict(
                  prediction_algo='cost_volume_regressor',
              ),
              checkpoint_dir=config.get_oneway_ref('checkpoint_dir'),
              evaluate_every=config.get_oneway_ref('evaluate_every'),
              eval_modes=config.get_oneway_ref('eval_modes'),
              # Dataset locations for the matching eval_* modes. Left empty
              # here; presumably supplied via config overrides when running
              # evaluation — confirm against the experiment code.
              davis_points_path='',
              jhmdb_path='',
              robotics_points_path='',
              training=dict(
                  # Note: to sweep n_training_steps, DO NOT sweep these
                  # fields directly. Instead sweep config.training_steps.
                  # Otherwise, decay/stopping logic
                  # is not guaranteed to be consistent.
                  n_training_steps=config.get_oneway_ref('training_steps'),
              ),
              inference=dict(
                  input_video_path='',
                  output_video_path='',
                  resize_height=256,  # video height resized to before inference
                  resize_width=256,  # video width resized to before inference
                  num_points=20,  # number of random points to sample
              ),
          )
      )
  )

  # Checkpointing settings.
  config.train_checkpoint_all_hosts = False
  config.save_checkpoint_interval = 10
  # If true, run evaluate() on the experiment once before loading a
  # checkpoint. This is useful for getting initial values of metrics at
  # random weights, or when debugging locally without a running train job.
  config.eval_initial_weights = True

  # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
  config.lock()

  return config