diff --git a/guides/tasks/supported_tasks.md b/guides/tasks/supported_tasks.md
index ddfb39ae3..c7005e42f 100644
--- a/guides/tasks/supported_tasks.md
+++ b/guides/tasks/supported_tasks.md
@@ -4,6 +4,8 @@
 
 | Name | `task_name` | `jiant` | Downloader | `jiant_task_name` | Misc |
 |---|---|:---:|:---:|---|---|
+| MCTACO | mctaco | ✅ | | mctaco | |
+| MCTest | mctest160 or mctest500 | ✅ | | mctest | |
 | [Argument Reasoning Comprehension](https://arxiv.org/abs/1708.01425) | arct | ✅ | | arct | [Github](https://github.com/UKPLab/argument-reasoning-comprehension-task) |
 | Abductive NLI | abductive_nli | ✅ | ✅ | abductive_nli | |
 | SuperGLUE Winogender Diagnostic | superglue_axg | ✅ | ✅ | superglue_axg | SuperGLUE |
diff --git a/jiant/tasks/evaluate/core.py b/jiant/tasks/evaluate/core.py
index 8294c6dd3..68e418004 100644
--- a/jiant/tasks/evaluate/core.py
+++ b/jiant/tasks/evaluate/core.py
@@ -73,7 +73,7 @@ def update(self, batch_logits, batch_loss, batch, batch_metadata):
         self.logits_list.append(batch_logits)
         batch_guid = batch_metadata.get("guid")
         if batch_guid is not None:
-            self.guid_list.append(batch_guid)
+            self.guid_list.extend(batch_guid)
 
     def get_guids(self):
         if self.guid_list:
@@ -261,6 +261,48 @@ def compute_metrics_from_preds_and_labels(cls, preds, labels):
         return Metrics(major=acc, minor={"acc": acc})
 
 
+class MCTACOEvaluationScheme(BaseLogitsEvaluationScheme):
+    @classmethod
+    def get_preds_from_accumulator(self, task, accumulator):
+        logits = accumulator.get_accumulated()
+        pred = np.argmax(logits, axis=1)
+        guid = accumulator.guid_list
+        return guid, pred
+
+    @classmethod
+    def compute_metrics_from_accumulator(self, task, accumulator, tokenizer, labels) -> Metrics:
+        guid, pred = self.get_preds_from_accumulator(task=task, accumulator=accumulator)
+        em_ls = []
+        f1_ls = []
+        label_pred_by_question = {}
+
+        for one_guid, one_pred, one_label in zip(guid, pred, labels):
+            split, question_id, example_id = one_guid.split("-")
+            if question_id not in label_pred_by_question:
+                label_pred_by_question[question_id] = [], []
+            label_pred_by_question[question_id][0].append(one_label)
+            label_pred_by_question[question_id][1].append(one_pred)
+
+        em_ls = [
+            float(group_label == group_pred)
+            for group_label, group_pred in label_pred_by_question.values()
+        ]
+        f1_ls = [
+            f1_score(y_true=group_label, y_pred=group_pred)
+            for group_label, group_pred in label_pred_by_question.values()
+        ]
+
+        em = sum(em_ls) / len(em_ls)
+        f1 = sum(f1_ls) / len(f1_ls)
+        minor = {
+            "em": em,
+            "f1": f1,
+            "f1_em": (f1 + em) / 2,
+        }
+        metrics = Metrics(major=minor["f1_em"], minor=minor,)
+        return metrics
+
+
 class MultiLabelAccAndF1EvaluationScheme(BaseLogitsEvaluationScheme):
     def get_labels_from_cache_and_examples(self, task, cache, examples):
         return get_multi_label_ids_from_cache(cache=cache)
@@ -935,6 +977,8 @@ def get_evaluation_scheme_for_task(task) -> BaseEvaluationScheme:
         ),
     ):
         return SimpleAccuracyEvaluationScheme()
+    elif isinstance(task, tasks.MCTACOTask):
+        return MCTACOEvaluationScheme()
    elif isinstance(task, tasks.CCGTask):
         return CCGEvaluationScheme()
     elif isinstance(task, tasks.CommitmentBankTask):
@@ -953,6 +997,7 @@ def get_evaluation_scheme_for_task(task) -> BaseEvaluationScheme:
             tasks.MutualTask,
             tasks.MutualPlusTask,
             tasks.SocialIQATask,
+            tasks.MCTestTask,
         ),
     ):
         return MultipleChoiceAccuracyEvaluationScheme()
diff --git a/jiant/tasks/lib/mctaco.py b/jiant/tasks/lib/mctaco.py
new file mode 100644
index 000000000..1ede641af
--- /dev/null
+++ b/jiant/tasks/lib/mctaco.py
@@ -0,0 +1,116 @@
+import numpy as np
+import torch
+from dataclasses import dataclass
+from typing import List
+
+from jiant.tasks.core import (
+    BaseExample,
+    BaseTokenizedExample,
+    BaseDataRow,
+    BatchMixin,
+    Task,
+    TaskTypes,
+)
+from jiant.tasks.lib.templates.shared import double_sentence_featurize, labels_to_bimap
+from jiant.utils.python.io import read_file_lines
+
+
+@dataclass
+class Example(BaseExample):
+    guid: str
+    sentence_question: str
+    answer: str
+    label: str
+
+    def tokenize(self, tokenizer):
+        return TokenizedExample(
+            guid=self.guid,
+            sentence_question=tokenizer.tokenize(self.sentence_question),
+            answer=tokenizer.tokenize(self.answer),
+            label_id=MCTACOTask.LABEL_TO_ID[self.label],
+        )
+
+
+@dataclass
+class TokenizedExample(BaseTokenizedExample):
+    guid: str
+    sentence_question: List
+    answer: List
+    label_id: int
+
+    def featurize(self, tokenizer, feat_spec):
+        return double_sentence_featurize(
+            guid=self.guid,
+            input_tokens_a=self.sentence_question,
+            input_tokens_b=self.answer,
+            label_id=self.label_id,
+            tokenizer=tokenizer,
+            feat_spec=feat_spec,
+            data_row_class=DataRow,
+        )
+
+
+@dataclass
+class DataRow(BaseDataRow):
+    guid: str
+    input_ids: np.ndarray
+    input_mask: np.ndarray
+    segment_ids: np.ndarray
+    label_id: int
+    tokens: list
+
+
+@dataclass
+class Batch(BatchMixin):
+    input_ids: torch.LongTensor
+    input_mask: torch.LongTensor
+    segment_ids: torch.LongTensor
+    label_id: torch.LongTensor
+    tokens: list
+
+
+class MCTACOTask(Task):
+    Example = Example
+    TokenizedExample = TokenizedExample
+    DataRow = DataRow
+    Batch = Batch
+
+    TASK_TYPE = TaskTypes.CLASSIFICATION
+    LABELS = ["yes", "no"]
+    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)
+
+    def get_train_examples(self):
+        return self._create_examples(
+            lines=read_file_lines(self.train_path, strip_lines=True), set_type="train"
+        )
+
+    def get_val_examples(self):
+        return self._create_examples(
+            lines=read_file_lines(self.val_path, strip_lines=True), set_type="val"
+        )
+
+    def get_test_examples(self):
+        return self._create_examples(
+            lines=read_file_lines(self.test_path, strip_lines=True), set_type="test"
+        )
+
+    @classmethod
+    def _create_examples(cls, lines, set_type):
+        # noinspection DuplicatedCode
+        examples = []
+        last_question = ""
+        question_count = -1
+        for (i, line) in enumerate(lines):
+            sentence, question, answer, label, category = line.split("\t")
+            if last_question != question:
+                question_count += 1
+                last_question = question
+            examples.append(
+                Example(
+                    guid="%s-q%s-%s" % (set_type, question_count, i),
+                    sentence_question=sentence + question,
+                    answer=answer,
+                    label=label if set_type != "test" else cls.LABELS[-1],
+                )
+            )
+        return examples
diff --git a/jiant/tasks/lib/mctest.py b/jiant/tasks/lib/mctest.py
new file mode 100644
index 000000000..c6b05e978
--- /dev/null
+++ b/jiant/tasks/lib/mctest.py
@@ -0,0 +1,78 @@
+from dataclasses import dataclass
+
+from jiant.tasks.lib.templates.shared import labels_to_bimap
+from jiant.tasks.lib.templates import multiple_choice as mc_template
+from jiant.utils.python.io import read_file_lines
+
+
+@dataclass
+class Example(mc_template.Example):
+    @property
+    def task(self):
+        return MCTestTask
+
+
+@dataclass
+class TokenizedExample(mc_template.TokenizedExample):
+    pass
+
+
+@dataclass
+class DataRow(mc_template.DataRow):
+    pass
+
+
+@dataclass
+class Batch(mc_template.Batch):
+    pass
+
+
+class MCTestTask(mc_template.AbstractMultipleChoiceTask):
+    Example = Example
+    TokenizedExample = TokenizedExample
+    DataRow = DataRow
+    Batch = Batch
+
+    CHOICE_KEYS = ["A", "B", "C", "D"]
"C", "D"] + CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS) + NUM_CHOICES = len(CHOICE_KEYS) + + def get_train_examples(self): + return self._create_examples( + lines=read_file_lines(self.train_path, strip_lines=True), + ans_lines=read_file_lines(self.path_dict["train_ans"], strip_lines=True), + set_type="train", + ) + + def get_val_examples(self): + return self._create_examples( + lines=read_file_lines(self.val_path, strip_lines=True), + ans_lines=read_file_lines(self.path_dict["val_ans"], strip_lines=True), + set_type="val", + ) + + def get_test_examples(self): + return self._create_examples( + lines=read_file_lines(self.test_path, strip_lines=True), + ans_lines=None, + set_type="test", + ) + + @classmethod + def _create_examples(cls, lines, ans_lines, set_type): + examples = [] + if ans_lines is None: + ans_lines = ["\t".join([cls.CHOICE_KEYS[-1]] * 4) for line in lines] + for i, (line, ans) in enumerate(zip(lines, ans_lines)): + line = line.split("\t") + ans = ans.split("\t") + for j in range(4): + examples.append( + Example( + guid="%s-%s" % (set_type, i * 4 + j), + prompt=line[2].replace("\\newline", " ") + " " + line[3 + j * 5], + choice_list=line[4 + j * 5 : 8 + j * 5], + label=ans[j], + ) + ) + return examples diff --git a/jiant/tasks/retrieval.py b/jiant/tasks/retrieval.py index 3f42609aa..5b866f362 100644 --- a/jiant/tasks/retrieval.py +++ b/jiant/tasks/retrieval.py @@ -20,6 +20,8 @@ from jiant.tasks.lib.edge_probing.dpr import DprTask from jiant.tasks.lib.glue_diagnostics import GlueDiagnosticsTask from jiant.tasks.lib.hellaswag import HellaSwagTask +from jiant.tasks.lib.mctaco import MCTACOTask +from jiant.tasks.lib.mctest import MCTestTask from jiant.tasks.lib.mlm_simple import MLMSimpleTask from jiant.tasks.lib.mlm_premasked import MLMPremaskedTask from jiant.tasks.lib.mlm_pretokenized import MLMPretokenizedTask @@ -94,6 +96,8 @@ "dpr": DprTask, "glue_diagnostics": GlueDiagnosticsTask, "hellaswag": HellaSwagTask, + "mctaco": MCTACOTask, + "mctest": MCTestTask, "mlm_simple": MLMSimpleTask, "mlm_premasked": MLMPremaskedTask, "mlm_pretokenized": MLMPretokenizedTask,