Skip to content

Commit

Permalink
Simplify a little bit finding the language, train & dev files for the…
Browse files Browse the repository at this point in the history
… charlm
  • Loading branch information
AngledLuffa committed Jun 28, 2023
1 parent 30d5bb6 commit b71397b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
3 changes: 3 additions & 0 deletions stanza/models/charlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def parse_args(args=None):

if args.wandb_name:
args.wandb = True
if not args.lang and args.shorthand:
args.lang = args.shorthand.split("_", 1)[0]
logger.info("Language not specified, but should be %s based on the shorthand of %s", args.lang, args.shorthand)

args = vars(args)
return args
Expand Down
18 changes: 11 additions & 7 deletions stanza/utils/training/run_charlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,27 @@ def run_treebank(mode, paths, treebank, short_name,
'--lang', short_language,
'--shorthand', short_name]
if mode == Mode.TRAIN:
train_args = ['--train_dir', train_dir,
'--eval_file', dev_file,
'--mode', 'train']
train_args = ['--mode', 'train']
if '--train_dir' not in extra_args:
train_args += ['--train_dir', train_dir]
if '--eval_file' not in extra_args:
train_args += ['--eval_file', dev_file]
train_args = train_args + default_args + extra_args
logger.info("Running train step with args: %s", train_args)
charlm.main(train_args)

if mode == Mode.SCORE_DEV:
dev_args = ['--eval_file', dev_file,
'--mode', 'predict']
dev_args = ['--mode', 'predict']
if '--eval_file' not in extra_args:
dev_args += ['--eval_file', dev_file]
dev_args = dev_args + default_args + extra_args
logger.info("Running dev step with args: %s", dev_args)
charlm.main(dev_args)

if mode == Mode.SCORE_TEST:
test_args = ['--eval_file', test_file,
'--mode', 'predict']
test_args = ['--mode', 'predict']
if '--eval_file' not in extra_args:
test_args += ['--eval_file', test_file]
test_args = test_args + default_args + extra_args
logger.info("Running test step with args: %s", test_args)
charlm.main(test_args)
Expand Down

0 comments on commit b71397b

Please sign in to comment.