forked from hitachinsk/FGT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
executable file
·70 lines (57 loc) · 2.34 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from utils.dist import *
from parse import *
from utils.util import find_free_port
import torch.multiprocessing as mp
import torch.distributed
from importlib import import_module
from flow_inputs import args_parser
def main_worker(rank, opt):
    """Set up one training process and run the configured network's trainer.

    Args:
        rank: Process index. ``mp.spawn`` passes it as the first argument;
            ``main`` passes ``-1`` on its non-localhost path, where
            ``opt['local_rank']`` and ``opt['global_rank']`` are already set.
            The value is forwarded unchanged to ``Network(opt, rank)``.
        opt: Option dict. Reads ``distributed`` and ``network``, plus
            ``init_method`` and ``world_size`` when distributed. It is updated
            in place: ``local_rank``/``global_rank`` (only if ``local_rank`` is
            absent) and ``device``.
    """
    # Spawned processes arrive without ranks in opt; they use their spawn
    # index as both the local and the global rank.
    if 'local_rank' not in opt:
        opt['local_rank'] = rank
        opt['global_rank'] = rank

    if opt['distributed']:
        gpu_index = int(opt['local_rank'])
        torch.cuda.set_device(gpu_index)
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=opt['init_method'],
            world_size=opt['world_size'],
            rank=opt['global_rank'],
            group_name='mtorch',
        )
        print('using GPU {}-{} for training'.format(
            int(opt['global_rank']), gpu_index))

    # Use this process's GPU when CUDA is present, otherwise fall back to CPU.
    if not torch.cuda.is_available():
        opt['device'] = 'cpu'
    else:
        opt['device'] = torch.device("cuda:{}".format(opt['local_rank']))

    # The trainer class is looked up by name in the `networks` package.
    network_module = import_module('networks.{}'.format(opt['network']))
    network_module.Network(opt, rank).train()
def main(args_obj):
    """Build the training options from CLI arguments and launch training.

    Args:
        args_obj: Dict of command-line arguments (``vars()`` of the namespace
            returned by ``args_parser``). It is handed to ``parse`` and its
            ``finetune`` entry sets ``opt['finetune']``.

    The options returned by ``parse`` are extended with ``world_size``,
    ``init_method`` (``tcp://<master_ip>:<free_port>``), ``distributed``
    (``world_size > 1``) and ``finetune``. Non-empty ``gen_state`` /
    ``opt_state`` values are copied into ``opt['path']``.

    If the master IP is ``127.0.0.1``, ``world_size`` local processes are
    started with ``mp.spawn``, each running ``main_worker``. Otherwise the
    ranks come from ``get_local_rank`` / ``get_global_rank`` and
    ``main_worker`` runs in this process. Per the original comment, that path
    expects the processes to have been started by openmpi.
    """
    opt = parse(args_obj)
    opt['world_size'] = get_world_size()
    free_port = find_free_port()
    master_ip = get_master_ip()
    opt['init_method'] = "tcp://{}:{}".format(master_ip, free_port)
    opt['distributed'] = opt['world_size'] > 1
    print(f'World size is: {opt["world_size"]}, and init_method is: {opt["init_method"]}')
    print('Import network module: ', opt['network'])

    # Checkpoint paths given on the command line override the config's paths.
    if opt['gen_state'] != '':
        opt['path']['gen_state'] = opt['gen_state']
    if opt['opt_state'] != '':
        opt['path']['opt_state'] = opt['opt_state']

    # Read the flag from the argument we were given. The module-level `args`
    # global only exists when this file is executed as a script.
    opt['finetune'] = args_obj['finetune'] == 1
    print(f'model is: {opt["model"]}')

    # Reuse the IP that built init_method so the launch mode and the
    # rendezvous address always agree.
    if master_ip == "127.0.0.1":
        # localhost: one spawned process per rank
        mp.spawn(main_worker, nprocs=opt['world_size'], args=(opt,))
    else:
        # multiple processes should be launched by openmpi
        opt['local_rank'] = get_local_rank()
        opt['global_rank'] = get_global_rank()
        main_worker(-1, opt)
if __name__ == '__main__':
    # Script entry point: parse the CLI, convert the namespace to a dict and
    # hand it to main().
    # NOTE(review): `args` is deliberately kept as a module-level name here;
    # older revisions of main() read `args.finetune` from this global.
    args = args_parser()
    args_obj = vars(args)
    main(args_obj)