adding scheduler for max_seq_len_tr
mravanelli committed Feb 22, 2019
1 parent 484e343 commit 5d78591
Showing 5 changed files with 14 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -693,7 +693,7 @@ python run_exp.py cfg/TIMIT_baselines/TIMIT_MLP_fbank_prod.cfg --dataset4,fea,0,
This command will internally alter the configuration file with your specified paths and run it with your defined features! Note that passing long arguments to the run_exp.py script requires a specific notation: *--dataset4* specifies the name of the created section, *fea* is the name of the higher-level field, and *fea_lst* or *lab_graph* is the name of the lowest-level field you want to change. The *0* indicates which occurrence of the lowest-level field you want to alter, since some configuration files may contain multiple *lab_graph* entries per dataset: *0* selects the first occurrence, *1* the second, and so on. Paths MUST be enclosed in " " to be interpreted as full strings! Note that you need to alter the *data_name* and *forward_with* fields if you don't want the transcriptions of different .wav files to overwrite each other (decoding files are stored according to the field *data_name*): ``` --dataset4,data_name=MyNewName --data_use,forward_with=MyNewName ```.

## Batch size, learning rate, and dropout scheduler
In order to give users more flexibility, the latest version of PyTorch-Kaldi supports scheduling of the batch size, learning rate, and dropout factor.
In order to give users more flexibility, the latest version of PyTorch-Kaldi supports scheduling of the batch size, max_seq_length_train, learning rate, and dropout factor.
This means that it is now possible to change these values during training. To support this feature, we implemented the following formalisms within the config files:
```
batch_size_train = 128*12 | 64*10 | 32*2
```
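
For reference, here is a minimal stand-alone sketch of how such an epoch-wise schedule string can be expanded into one value per epoch. PyTorch-Kaldi does this with its own helper (`expand_str_ep` in utils.py, visible in the diff below); the function here is only an illustration of the `value*epochs | value*epochs` format shown above, not the project's implementation.

```python
# Illustration only: expand "128*12 | 64*10 | 32*2" into one value per epoch.
# PyTorch-Kaldi uses its own helper (expand_str_ep in utils.py); this is a
# simplified stand-alone approximation, not the project's code.
def expand_schedule(schedule_str, n_epochs):
    values = []
    for block in schedule_str.split('|'):
        value, n_ep = block.strip().split('*')
        values.extend([int(value)] * int(n_ep))
    if len(values) != n_epochs:
        raise ValueError('schedule covers %d epochs, expected %d'
                         % (len(values), n_epochs))
    return values

# 12 epochs at 128, 10 at 64, 2 at 32 -> a 24-element list
print(expand_schedule('128*12 | 64*10 | 32*2', 24))
```
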
2 changes: 1 addition & 1 deletion cfg/TIMIT_baselines/TIMIT_MLP_mfcc_basic_flex.cfg
@@ -76,7 +76,7 @@ forward_with = TIMIT_test

[batches]
batch_size_train = 128*12 | 64*10 | 32*2
max_seq_length_train = 1000
max_seq_length_train = 1000*18 | 500*6
increase_seq_length_train = False
start_seq_len_train = 100
multply_factor_seq_len_train = 2
2 changes: 1 addition & 1 deletion proto/global.proto
@@ -26,7 +26,7 @@ forward_with=list_str

[batches]
batch_size_train=list_str
max_seq_length_train=int(20,inf)
max_seq_length_train=list_str
increase_seq_length_train=Bool
start_seq_len_train=int(20,inf)
multply_factor_seq_len_train=int(0,inf)
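
The proto change relaxes the field's type from a bounded integer to a string list, which is what permits the epoch-wise schedule already used in the cfg file above, for example:

```
[batches]
max_seq_length_train = 1000*18 | 500*6
```
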
2 changes: 1 addition & 1 deletion run_exp.py
@@ -62,7 +62,7 @@
tr_data_lst=config['data_use']['train_with'].split(',')
valid_data_lst=config['data_use']['valid_with'].split(',')
forward_data_lst=config['data_use']['forward_with'].split(',')
max_seq_length_train=int(config['batches']['max_seq_length_train'])
max_seq_length_train=config['batches']['max_seq_length_train']
forward_save_files=list(map(strtobool,config['forward']['save_out_file'].split(',')))


17 changes: 10 additions & 7 deletions utils.py
@@ -680,7 +680,7 @@ def create_configs(config):
N_ep_str_format='0'+str(max(math.ceil(np.log10(N_ep)),1))+'d'
tr_data_lst=config['data_use']['train_with'].split(',')
valid_data_lst=config['data_use']['valid_with'].split(',')
max_seq_length_train=int(config['batches']['max_seq_length_train'])
max_seq_length_train=config['batches']['max_seq_length_train']
forward_data_lst=config['data_use']['forward_with'].split(',')


@@ -693,9 +693,10 @@ def create_configs(config):
# Read the batch size string
batch_size_tr_str=config['batches']['batch_size_train']
batch_size_tr_arr=expand_str_ep(batch_size_tr_str,'int',N_ep,'|','*')



# Read the max_seq_length_train
max_seq_length_tr_arr=expand_str_ep(max_seq_length_train,'int',N_ep,'|','*')


cfg_file_proto=config['cfg_proto']['cfg_proto']
[config,name_data,name_arch]=check_cfg(cfg_file,config,cfg_file_proto)
@@ -734,8 +735,7 @@ def create_configs(config):

if strtobool(config['batches']['increase_seq_length_train']):
max_seq_length_train_curr=int(config['batches']['start_seq_len_train'])
else:
max_seq_length_train_curr=max_seq_length_train


for ep in range(N_ep):

@@ -767,6 +767,9 @@
config_chunk_file=out_folder+'/exp_files/train_'+tr_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'.cfg'
lst_chunk_file.write(config_chunk_file+'\n')

if strtobool(config['batches']['increase_seq_length_train'])==False:
max_seq_length_train_curr=int(max_seq_length_tr_arr[ep])

# Write chunk-specific cfg file
write_cfg_chunk(cfg_file,config_chunk_file,cfg_file_proto_chunk,pt_files,lst_file,info_file,'train',tr_data,lr,max_seq_length_train_curr,name_data,ep,ck,batch_size_tr_arr[ep],drop_rates)

@@ -795,8 +798,8 @@ def create_configs(config):
# if needed, update sentence_length
if strtobool(config['batches']['increase_seq_length_train']):
max_seq_length_train_curr=max_seq_length_train_curr*int(config['batches']['multply_factor_seq_len_train'])
if max_seq_length_train_curr>max_seq_length_train:
max_seq_length_train_curr=max_seq_length_train
if max_seq_length_train_curr>int(max_seq_length_tr_arr[ep]):
max_seq_length_train_curr=int(max_seq_length_tr_arr[ep])


for forward_data in forward_data_lst:
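
Taken together, the utils.py changes read max_seq_length_train as a string, expand it into a per-epoch array with expand_str_ep, and then either follow that schedule directly (when increase_seq_length_train is False) or keep growing the current length from start_seq_len_train by multply_factor_seq_len_train while capping it at the scheduled maximum for the epoch. Below is a minimal sketch of that per-epoch behaviour, a stand-alone approximation rather than the project's code.

```python
# Illustration only: per-epoch evolution of max_seq_length_train, approximating
# the logic shown in the utils.py diff above. Not the project's code.
def seq_len_per_epoch(max_seq_schedule, increase, start_len, mult_factor):
    """max_seq_schedule holds one scheduled maximum per epoch, e.g. [1000]*18 + [500]*6."""
    lengths = []
    curr = start_len
    for ep_max in max_seq_schedule:
        if not increase:
            curr = ep_max                            # follow the schedule directly
        lengths.append(curr)
        if increase:
            curr = min(curr * mult_factor, ep_max)   # grow, but never past this epoch's max
    return lengths

# Start at 100 frames and double each epoch, capped by a constant schedule of 1000:
print(seq_len_per_epoch([1000] * 5, increase=True, start_len=100, mult_factor=2))
# -> [100, 200, 400, 800, 1000]
```
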
