# run_orca.py
# (Captured from a GitHub file view: 34 lines, 27 loc, 934 bytes.)
from baselines.orca import ORCA_Model
from utils.utils_common import define_callbacks, define_trainer
from utils.utils_config import load_config
from utils.utils_orca_datamodule import ORCADatamodule
def main():
    """Train (unless ``config.eval_only``) and then test the ORCA baseline.

    Loads the run configuration and logger via ``load_config``, builds the
    ORCA datamodule and model, optionally fits the model (using the test
    dataloader as the validation dataloader), and finally runs
    ``trainer.test`` on the test dataloader.

    Raises:
        ValueError: if ``config.model_name`` does not contain ``'orca'``.
    """
    # Load config
    config, logger = load_config()
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if 'orca' not in config.model_name:
        raise ValueError(
            f"expected an ORCA model, got model_name={config.model_name!r}"
        )

    # DataModule
    dm = ORCADatamodule(config)
    dm.setup()

    model = ORCA_Model(config)

    # Trainer & callbacks. 'train_loss' / 'min' is presumably the checkpoint
    # selection criterion -- NOTE(review): confirm against define_callbacks.
    callbacks = define_callbacks(config, monitor='train_loss', mode='min')
    trainer = define_trainer(config, logger, callbacks)

    if not config.eval_only:
        # The test dataloader doubles as the validation dataloader here.
        trainer.fit(model, dm.train_dataloader(), dm.test_dataloader())

    # Test.
    # ``model`` is passed explicitly: in eval-only mode ``fit`` never ran, so
    # the trainer has no LightningModule attached and ``test`` would fail.
    # When there is no checkpoint callback, or it never recorded a best
    # checkpoint (e.g. eval-only), its path is empty. Lightning rejects an
    # empty ``ckpt_path``, so use ``None``, which tests the current weights.
    checkpoint_cb = trainer.checkpoint_callback
    best_ckpt = checkpoint_cb.best_model_path if checkpoint_cb is not None else ''
    trainer.test(
        model=model,
        ckpt_path=best_ckpt or None,
        dataloaders=dm.test_dataloader(),
    )

    print(f'Run finished --> {config.run_name}')
def _run_script():
    """Run ``main`` bracketed by start/finish markers printed to stdout."""
    print('Starting...')
    main()
    print('Done')


if __name__ == '__main__':
    _run_script()