Merge pull request mravanelli#145 from mennetob/dev_master_limitForwardingSubprocesses

limit forwarding subprocesses
mravanelli authored Jul 24, 2019
2 parents 1192376 + 565d624 commit 1d95f12
Showing 5 changed files with 42 additions and 5 deletions.
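
Taken together, the diff adds two optional keys that are read from the [forward] section of the experiment cfg file. A hypothetical excerpt (the key names come from the code below; section layout and values are illustrative only):

    [forward]
    max_nr_of_parallel_forwarding_processes = 4
    shuffle_forwarding_data = False

With these settings at most four forwarding subprocesses run at once and the forwarding chunk list keeps its original order; omitting both keys preserves the previous behavior (no subprocess limit, shuffled chunk list).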
9 changes: 5 additions & 4 deletions data_io.py
@@ -305,9 +305,9 @@ def _update_data(data_set, labs, fea_dict, fea, fea_index, data_set_fea, labs_fe
         fea_index=fea_index+data_set_fea.shape[1]
         fea_dict[fea].append(fea_index)
         fea_dict[fea].append(fea_dict[fea][6]-fea_dict[fea][5])
-    elif cnt_fea==0 and (not cnt_fea==0):
+    elif cnt_fea==0 and (not cnt_lab==0):
         labs=np.column_stack((labs,labs_fea))
-    elif (not cnt_fea==0) and cnt_fea==0:
+    elif (not cnt_fea==0) and cnt_lab==0:
         data_set=np.column_stack((data_set,data_set_fea))
         fea_dict[fea].append(fea_index)
         fea_index=fea_index+data_set_fea.shape[1]
@@ -368,12 +368,13 @@ def _load_chunk_refac01(fea_scp,fea_opts,lab_folder,lab_opts,left,right,max_sequ
     for lab in lab_dict.keys():
         lab_folder, lab_opts = _get_lab_config_from_dict(lab_dict[lab], fea_only)
         data_name_fea, data_set_fea, data_set_lab, data_end_index_fea, data_end_index_lab = _load_chunk_refac01(fea_scp, fea_opts, lab_folder, lab_opts, cw_left, cw_right, max_seq_length, output_folder, fea_only)
-        labs_fea, data_set_fea, data_end_index_fea, data_end_index_lab = _compensate_for_different_context_windows(data_set_fea, data_set_lab, cw_left_max, cw_left, cw_right_max, cw_right, data_end_index_fea, data_end_index_lab)
+        if sum([abs(e) for e in [cw_left_max, cw_right_max, cw_left, cw_right]]) != 0:
+            data_set_lab, data_set_fea, data_end_index_fea, data_end_index_lab = _compensate_for_different_context_windows(data_set_fea, data_set_lab, cw_left_max, cw_left, cw_right_max, cw_right, data_end_index_fea, data_end_index_lab)
         if cnt_fea == 0 and cnt_lab == 0:
             data_end_index_fea_ini = data_end_index_fea
             data_end_index_lab_ini = data_end_index_lab
             data_name = data_name_fea
-        data_set, labs, fea_dict, fea_index = _update_data(data_set, labs, fea_dict, fea, fea_index, data_set_fea, labs_fea, cnt_fea, cnt_lab)
+        data_set, labs, fea_dict, fea_index = _update_data(data_set, labs, fea_dict, fea, fea_index, data_set_fea, data_set_lab, cnt_fea, cnt_lab)
         _check_consistency(data_name, data_name_fea, data_end_index_fea_ini, data_end_index_fea, data_end_index_lab_ini, data_end_index_lab)
         cnt_lab=cnt_lab+1
     cnt_fea=cnt_fea+1
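
Note that both replaced conditions in _update_data compared cnt_fea against itself: "cnt_fea==0 and (not cnt_fea==0)" is False for every value, so the label-only and feature-only stacking branches were unreachable before this fix. A two-line check of the old guard:

    for cnt_fea in (0, 1):
        assert not (cnt_fea == 0 and (not cnt_fea == 0))  # the old guard can never fire

The new sum of absolute context widths additionally skips _compensate_for_different_context_windows when all four context-window widths are zero, in which case there is nothing to compensate.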
17 changes: 17 additions & 0 deletions neural_networks.py
@@ -694,6 +694,23 @@ def _safe_log(inp, epsilon=1e-20):
         out = log_mel_spec
         return out
 
+class channel_averaging(nn.Module):
+    def __init__(self, options, inp_dim):
+        super(channel_averaging, self).__init__()
+        self._use_cuda = strtobool(options['use_cuda'])
+        channel_weights = [float(e) for e in options['chAvg_channelWeights'].split(',')]
+        self._nr_of_channels = len(channel_weights)
+        numpy_weights = np.asarray(channel_weights, dtype=np.float32) * 1.0 / np.sum(channel_weights)
+        self._weights = torch.from_numpy(numpy_weights)
+        if self._use_cuda:
+            self._weights = self._weights.cuda()
+        self.out_dim = 1
+
+    def forward(self, x):
+        assert self._nr_of_channels == x.shape[-1]
+        out = torch.einsum('tbc,c->tb', x, self._weights).unsqueeze(-1)
+        return out
+
 class liGRU(nn.Module):
 
     def __init__(self, options,inp_dim):
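
For reference, a minimal standalone sketch of the weighted average computed in forward above, assuming the (time, batch, channel) tensor layout implied by the 'tbc,c->tb' einsum signature; the shapes and weights here are made up for the demonstration:

    import numpy as np
    import torch

    weights = np.asarray([0.5, 0.5], dtype=np.float32)   # already normalized, as in __init__
    w = torch.from_numpy(weights)
    x = torch.randn(50, 4, 2)                            # 50 frames, batch of 4, 2 channels
    out = torch.einsum('tbc,c->tb', x, w).unsqueeze(-1)  # weighted channel average, shape (50, 4, 1)
    assert torch.allclose(out, (x * w).sum(-1, keepdim=True))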
4 changes: 4 additions & 0 deletions proto/channelAvg.proto
@@ -0,0 +1,4 @@
+[proto]
+chAvg_channelWeights=str
+
+
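
The proto file only declares the type of the new key. In an experiment cfg the module would receive a comma-separated weight list, e.g. (a hypothetical section; every name except chAvg_channelWeights and the proto path is illustrative):

    [architecture1]
    arch_name = channel_avg
    arch_proto = proto/channelAvg.proto
    chAvg_channelWeights = 1.0,1.0

Since channel_averaging.__init__ normalizes the weights by their sum, 1.0,1.0 yields a plain average of two channels.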
8 changes: 8 additions & 0 deletions run_exp.py
@@ -45,6 +45,11 @@ def _get_nr_of_valid_per_epoch_from_config(config):
         return True
     return False
 
+def _max_nr_of_parallel_forwarding_processes(config):
+    if 'max_nr_of_parallel_forwarding_processes' in config['forward']:
+        return int(config['forward']['max_nr_of_parallel_forwarding_processes'])
+    return -1
+
 # Reading global cfg file (first argument-mandatory file)
 cfg_file=sys.argv[1]
 if not(os.path.exists(cfg_file)):
@@ -332,6 +337,9 @@ def _get_nr_of_valid_per_epoch_from_config(config):
             data_end_index = {'fea': data_end_index_fea,'lab': data_end_index_lab}
             p = multiprocessing.Process(target=run_nn, kwargs={'data_name': data_name, 'data_set': data_set, 'data_end_index': data_end_index, 'fea_dict': fea_dict, 'lab_dict': lab_dict, 'arch_dict': arch_dict, 'cfg_file': config_chunk_file, 'processed_first': False, 'next_config_file': None})
             processes.append(p)
+            if _max_nr_of_parallel_forwarding_processes(config) != -1 and len(processes) > _max_nr_of_parallel_forwarding_processes(config):
+                processes[0].join()
+                del processes[0]
             p.start()
         else:
             [data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict]=run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,config_chunk_file,processed_first,next_config_file)
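
The three added lines cap the backlog by joining the oldest forwarding subprocess before a new one is started. The same pattern, isolated as a self-contained sketch (the worker and the limit are placeholders for run_nn and the cfg value):

    import multiprocessing

    def _forward_chunk(i):
        print('forwarding chunk', i)

    if __name__ == '__main__':
        max_procs = 2                     # stands in for _max_nr_of_parallel_forwarding_processes(config)
        processes = []
        for i in range(6):
            p = multiprocessing.Process(target=_forward_chunk, args=(i,))
            processes.append(p)
            if max_procs != -1 and len(processes) > max_procs:
                processes[0].join()       # block until the oldest subprocess finishes
                del processes[0]
            p.start()
        for p in processes:               # drain whatever is still running
            p.join()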
9 changes: 8 additions & 1 deletion utils.py
@@ -934,6 +934,12 @@ def _get_validation_data_for_chunks(fea_names, list_fea, N_chunks):
     random.shuffle(full_list_fea_conc)
     valid_chunks_fea=list(split_chunks(full_list_fea_conc,N_chunks))
     return valid_chunks_fea
+def _shuffle_forward_data(config):
+    if 'shuffle_forwarding_data' in config['forward']:
+        shuffle_on_forwarding = strtobool(config['forward']['shuffle_forwarding_data'])
+        if not shuffle_on_forwarding:
+            return False
+    return True
 
 # splitting data into chunks (see out_folder/additional_files)
 out_folder=config['exp']['out_folder']
@@ -1030,7 +1036,8 @@ def _get_validation_data_for_chunks(fea_names, list_fea, N_chunks):
 
 
     # randomize the list
-    random.shuffle(full_list_fea_conc)
+    if _shuffle_forward_data(config):
+        random.shuffle(full_list_fea_conc)
     forward_chunks_fea=list(split_chunks(full_list_fea_conc,N_chunks))
 
 
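
_shuffle_forward_data defaults to True when the key is absent, and parses the value with strtobool (presumably distutils.util.strtobool, which the code already uses elsewhere, e.g. in channel_averaging). A quick reminder of the strings it accepts:

    from distutils.util import strtobool

    assert strtobool('True') == 1 and strtobool('yes') == 1 and strtobool('1') == 1
    assert strtobool('False') == 0 and strtobool('no') == 0 and strtobool('0') == 0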
