From 9abcd9e3afa4760ab9c2e3845e82909e992fd995 Mon Sep 17 00:00:00 2001 From: Jesse Swanson Date: Mon, 4 Jan 2021 19:49:46 -0800 Subject: [PATCH] add test for export_model --- tests/proj/main/test_export_model.py | 33 ++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/proj/main/test_export_model.py diff --git a/tests/proj/main/test_export_model.py b/tests/proj/main/test_export_model.py new file mode 100644 index 000000000..5e8a87aae --- /dev/null +++ b/tests/proj/main/test_export_model.py @@ -0,0 +1,33 @@ +import os +import pytest +from transformers import BertPreTrainedModel, BertTokenizer, RobertaForMaskedLM, RobertaTokenizer + +import jiant.utils.python.io as py_io +from jiant.proj.main.export_model import export_model + + +@pytest.mark.parametrize( + "model_type, model_class, tokenizer_class, hf_model_name", + [ + ("bert-base-cased", BertPreTrainedModel, BertTokenizer, "bert-base-cased"), + ( + "roberta-med-small-1M-1", + RobertaForMaskedLM, + RobertaTokenizer, + "nyu-mll/roberta-med-small-1M-1", + ), + ], +) +def test_export_model(tmp_path, model_type, model_class, tokenizer_class, hf_model_name): + export_model( + model_type=model_type, + output_base_path=tmp_path, + model_class=model_class, + tokenizer_class=tokenizer_class, + hf_model_name=hf_model_name, + ) + read_config = py_io.read_json(os.path.join(tmp_path, f"config.json")) + assert read_config["model_type"] == model_type + assert read_config["model_path"] == os.path.join(tmp_path, "model", f"{model_type}.p") + assert read_config["model_config_path"] == os.path.join(tmp_path, "model", f"{model_type}.json") + assert read_config["model_tokenizer_path"] == os.path.join(tmp_path, "tokenizer")