Skip to content

Commit

Permalink
Enable option to have an optimizer per architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
menne committed Jul 16, 2019
1 parent 7f00285 commit 0cf520d
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ def _load_model_and_optimizer(fea_dict,model,config,arch_dict,use_cuda,multi_gpu
else:
checkpoint_load = torch.load(pt_file_arch, map_location='cpu')
nns[net].load_state_dict(checkpoint_load['model_par'])
optimizers[net].load_state_dict(checkpoint_load['optimizer_par'])
optimizers[net].param_groups[0]['lr']=float(config[arch_dict[net][0]]['arch_lr']) # loading lr of the cfg file for pt
if net in optimizers:
optimizers[net].load_state_dict(checkpoint_load['optimizer_par'])
optimizers[net].param_groups[0]['lr']=float(config[arch_dict[net][0]]['arch_lr']) # loading lr of the cfg file for pt
if multi_gpu:
nns[net] = torch.nn.DataParallel(nns[net])
return nns, costs, optimizers, inp_out_dict
Expand Down Expand Up @@ -186,7 +187,10 @@ def _save_model(to_do, nns, multi_gpu, optimizers, info_file, arch_dict):
checkpoint['model_par']=nns[net].module.state_dict()
else:
checkpoint['model_par']=nns[net].state_dict()
checkpoint['optimizer_par']=optimizers[net].state_dict()
if net in optimizers:
checkpoint['optimizer_par']=optimizers[net].state_dict()
else:
checkpoint['optimizer_par']=dict()
out_file=info_file.replace('.info','_'+arch_dict[net][0]+'.pkl')
torch.save(checkpoint, out_file)
def _get_dim_from_data_set(data_set_inp, data_set_ref):
Expand Down

0 comments on commit 0cf520d

Please sign in to comment.