Distributed eval: SequentialDistributedSampler + gather all results (#4243)

* Distributed eval: SequentialDistributedSampler + gather all results

* For consistency only write to disk from world_master

Close #4272

* Working distributed eval

* Hook into scripts

* Fix #3721 again

* TPU.mesh_reduce: stay in tensor space

Thanks @jysohn23

* Just a small comment

* whitespace

* torch.hub: pip install packaging

* Add test scenarios
julien-c committed May 19, 2020
1 parent 4c06893 commit 5e7fe8b
Showing 7 changed files with 283 additions and 86 deletions.
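
For context on the commit title: the approach gives each rank one contiguous, ordered slice of the eval dataset (padded so all ranks hold the same number of samples), runs prediction locally, then all-gathers the per-rank tensors and concatenates them back in dataset order. Below is a minimal sketch of that idea, assuming torch.distributed is initialized; the names follow the commit title, but this code is illustrative rather than the actual diff (the trainer file did not load in this view).

import torch
import torch.distributed as dist
from torch.utils.data import Sampler


class SequentialDistributedSampler(Sampler):
    """Hand each rank one contiguous, ordered slice of the dataset,
    padded so every rank draws the same number of samples."""

    def __init__(self, dataset, num_replicas: int, rank: int):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = -(-len(dataset) // num_replicas)  # ceil division
        self.total_size = self.num_samples * num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        indices += indices[: self.total_size - len(indices)]  # pad by wrapping around
        # Contiguous slice for this rank: [rank * n, (rank + 1) * n).
        return iter(indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples])

    def __len__(self) -> int:
        return self.num_samples


def distributed_concat(tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
    """All-gather per-rank result tensors, concatenate in rank order,
    then drop the padding the sampler added."""
    output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
    dist.all_gather(output_tensors, tensor)
    return torch.cat(output_tensors, dim=0)[:num_total_examples]

Because the padding guarantees equal per-rank counts, all_gather sees same-shaped tensors, and truncating to the true dataset length drops the wrapped-around duplicates.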
2 changes: 1 addition & 1 deletion .github/workflows/github-torch-hub.yml
@@ -21,7 +21,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install torch
-          pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses
+          pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses packaging
       - name: Torch hub list
         run: |
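
As an aside on this step: torch.hub.list() clones the repo and imports its hubconf.py, which pulls in transformers' import chain; that is presumably why packaging joins the dependency list here. A minimal local equivalent of the check, assuming network access:

import torch

# Enumerate the hub entry points exposed by the repo's hubconf.py;
# importing hubconf is what exercises the dependencies installed above.
print(torch.hub.list("huggingface/transformers"))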
13 changes: 7 additions & 6 deletions examples/language-modeling/run_language_modeling.py
@@ -251,7 +251,7 @@ def main():
 
     # Evaluation
     results = {}
-    if training_args.do_eval and training_args.local_rank in [-1, 0]:
+    if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
         eval_output = trainer.evaluate()
@@ -260,11 +260,12 @@ def main():
         result = {"perplexity": perplexity}
 
         output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
-        with open(output_eval_file, "w") as writer:
-            logger.info("***** Eval results *****")
-            for key in sorted(result.keys()):
-                logger.info("  %s = %s", key, str(result[key]))
-                writer.write("%s = %s\n" % (key, str(result[key])))
+        if trainer.is_world_master():
+            with open(output_eval_file, "w") as writer:
+                logger.info("***** Eval results *****")
+                for key in sorted(result.keys()):
+                    logger.info("  %s = %s", key, str(result[key]))
+                    writer.write("%s = %s\n" % (key, str(result[key])))
 
         results.update(result)
 
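The same edit repeats across the example scripts below: evaluation now runs on every process, so the distributed sampler and gather can do their work, while disk writes are gated behind trainer.is_world_master(). A minimal sketch of such a guard, assuming torch.distributed (the real Trainer method also covers the TPU case):

import torch.distributed as dist


def is_world_master(local_rank: int) -> bool:
    """True only on the process that should touch the filesystem: either
    the run is not distributed at all (local_rank == -1) or this process
    is global rank 0 across all nodes."""
    return local_rank == -1 or dist.get_rank() == 0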
15 changes: 8 additions & 7 deletions examples/multiple-choice/run_multiple_choice.py
@@ -202,19 +202,20 @@ def compute_metrics(p: EvalPrediction) -> Dict:
 
     # Evaluation
     results = {}
-    if training_args.do_eval and training_args.local_rank in [-1, 0]:
+    if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
         result = trainer.evaluate()
 
         output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
-        with open(output_eval_file, "w") as writer:
-            logger.info("***** Eval results *****")
-            for key, value in result.items():
-                logger.info("  %s = %s", key, value)
-                writer.write("%s = %s\n" % (key, value))
+        if trainer.is_world_master():
+            with open(output_eval_file, "w") as writer:
+                logger.info("***** Eval results *****")
+                for key, value in result.items():
+                    logger.info("  %s = %s", key, value)
+                    writer.write("%s = %s\n" % (key, value))
 
-        results.update(result)
+            results.update(result)
 
     return results
 
13 changes: 7 additions & 6 deletions examples/text-classification/run_glue.py
@@ -166,7 +166,7 @@ def compute_metrics(p: EvalPrediction) -> Dict:
 
     # Evaluation
     results = {}
-    if training_args.do_eval and training_args.local_rank in [-1, 0]:
+    if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
         # Loop to handle MNLI double evaluation (matched, mis-matched)
@@ -181,11 +181,12 @@ def compute_metrics(p: EvalPrediction) -> Dict:
             output_eval_file = os.path.join(
                 training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
             )
-            with open(output_eval_file, "w") as writer:
-                logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
-                for key, value in result.items():
-                    logger.info("  %s = %s", key, value)
-                    writer.write("%s = %s\n" % (key, value))
+            if trainer.is_world_master():
+                with open(output_eval_file, "w") as writer:
+                    logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
+                    for key, value in result.items():
+                        logger.info("  %s = %s", key, value)
+                        writer.write("%s = %s\n" % (key, value))
 
             results.update(result)
 
53 changes: 29 additions & 24 deletions examples/token-classification/run_ner.py
@@ -235,22 +235,23 @@ def compute_metrics(p: EvalPrediction) -> Dict:
 
     # Evaluation
    results = {}
-    if training_args.do_eval and training_args.local_rank in [-1, 0]:
+    if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
         result = trainer.evaluate()
 
         output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
-        with open(output_eval_file, "w") as writer:
-            logger.info("***** Eval results *****")
-            for key, value in result.items():
-                logger.info("  %s = %s", key, value)
-                writer.write("%s = %s\n" % (key, value))
+        if trainer.is_world_master():
+            with open(output_eval_file, "w") as writer:
+                logger.info("***** Eval results *****")
+                for key, value in result.items():
+                    logger.info("  %s = %s", key, value)
+                    writer.write("%s = %s\n" % (key, value))
 
         results.update(result)
 
     # Predict
-    if training_args.do_predict and training_args.local_rank in [-1, 0]:
+    if training_args.do_predict:
         test_dataset = NerDataset(
             data_dir=data_args.data_dir,
             tokenizer=tokenizer,
@@ -265,26 +266,30 @@ def compute_metrics(p: EvalPrediction) -> Dict:
         preds_list, _ = align_predictions(predictions, label_ids)
 
         output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
-        with open(output_test_results_file, "w") as writer:
-            for key, value in metrics.items():
-                logger.info("  %s = %s", key, value)
-                writer.write("%s = %s\n" % (key, value))
+        if trainer.is_world_master():
+            with open(output_test_results_file, "w") as writer:
+                for key, value in metrics.items():
+                    logger.info("  %s = %s", key, value)
+                    writer.write("%s = %s\n" % (key, value))
 
         # Save predictions
         output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
-        with open(output_test_predictions_file, "w") as writer:
-            with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
-                example_id = 0
-                for line in f:
-                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
-                        writer.write(line)
-                        if not preds_list[example_id]:
-                            example_id += 1
-                    elif preds_list[example_id]:
-                        output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
-                        writer.write(output_line)
-                    else:
-                        logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
+        if trainer.is_world_master():
+            with open(output_test_predictions_file, "w") as writer:
+                with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
+                    example_id = 0
+                    for line in f:
+                        if line.startswith("-DOCSTART-") or line == "" or line == "\n":
+                            writer.write(line)
+                            if not preds_list[example_id]:
+                                example_id += 1
+                        elif preds_list[example_id]:
+                            output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
+                            writer.write(output_line)
+                        else:
+                            logger.warning(
+                                "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
+                            )
 
     return results
 
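On the "TPU.mesh_reduce: stay in tensor space" bullet in the commit message: the gather on TPU can reduce the per-core tensors directly instead of round-tripping through numpy. A hedged sketch, assuming a TPU runtime with torch_xla available; the tag string and wrapper name here are illustrative:

import torch
import torch_xla.core.xla_model as xm  # requires a TPU runtime


def tpu_gather(preds: torch.Tensor) -> torch.Tensor:
    """Gather per-core prediction tensors across the TPU mesh while
    staying in tensor space: mesh_reduce hands torch.cat the list of
    per-core tensors, so no numpy conversion happens on the cores."""
    return xm.mesh_reduce("eval_preds", preds, torch.cat)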