Adding Masked Language Modelling (#1030)
* misc run scripts

* sbatch

* sweep scripts

* update

* qa

* update

* update

* update

* update

* update

* sb file

* moving update_metrics to outside scope of dataparallel

* fixing micro_avg calculation

* undo debugging

* Fixing tests, moving update_metrics out of other tasks

* remove extraneous change

* MLM task

* Added MLM task

* update

* fix multiple choice dataparallel forward

* update

* add _mask_id to transformers

* Update

* MLM update

* adding update_metrics abstraction

* delete update_metrics_ notation

* fixed wrong index problem

* removed unrelated files

* removed unrelated files

* removed unrelated files

* fix PEP8

* Fixed get_pretrained_lm_head for BERT and ALBERT

* spelling check

* black formatting

* fixing tests

* bug fix

* Adding batch_size constraints to multi-GPU setting

* adding documentation

* adding batch size test

* black correct version

* Fixing batch size assertion

* generalize batch size assertion for more than 2 GPU setting

* reducing label loops in code

* fixing span forward

* Fixing span prediction forward for multi-GPU

* fix commonsenseQA forward

* MLM

* adding function documentation

* resolving nits, fixing seq_gen forward

* remove nit

* fixing batch_size assert and SpanPrediction task

* Remove debugging

* Fix batch size mismatch multi-GPU test

* Fix order of assert checking for batch size mismatch

* mlm training

* update

* sbatch

* update

* data parallel

* update data parallel stuffs

* using sequencelabel, using 1 paragraph per example

* update label mapping

* adding examples-proportional-mixing

* changing dataloader to work with wikitext103

* weight sampling

* add early stopping on only one task

* commit

* Cleaning up code

* Removing unnecessarily tracked git folders

* Removing unnecessary changes

* revert README

* revert README.md again

* Making more general for Transformer-based embedders

* torch.uint8 -> torch.bool

* Fixing indexing issues

* get rid of unnecessary changes

* black cleanup

* update

* Prevent updating update_metrics twice in one step

* update

* update

* add base_roberta

* update

* reverting CCG edit added for debugging

* refactor defaults.conf

* black formatting

* merge

* removed SOP task and mlm_manual_scaling

* Fixing label namespace vocabulary creation, merging from master

* Deleting MLM weight

* black formatting

* Adding early_stopping_method to defaults.conf

* Fixing MLM with preprocessed wikitext103

* Deleting intermediate class hierarchy for MLM

* Correcting black

* LanguageModelingTask -> AutoregressiveModelingTask

* code style

* fixing MaskedLanguageModelTask

* Fixing typo

* Fixing label namespace

* extracting out masking portion

* Revert "extracting out masking portion"

This reverts commit f21165c.

* Code cleanup

* Adding tests for early_stopping_method

* Adding pretrain_stop_metric

* Reverting get_data_iter

* Reverting to get_data_iter

* Fixing get_pretrained_lm_head for all embedder types

* Extracting out MLM probability masking

* Move dynamic masking function to Task for easier testing

* Adding unit tests for MLM

* Adding change to MLM forward function to expose more intermediate steps for testing

* Fixing code style

* Adding more detailed instructions of how to generate Wikipedia data

* Adding rest of MLM data generation code

* Black style and remove comment

* black style

* updating repro code for MLM data

Co-authored-by: phu-pmh <[email protected]>
Co-authored-by: Haokun Liu <[email protected]>
Co-authored-by: pruksmhc <[email protected]>
Co-authored-by: DeepLearning VM <[email protected]>
5 people committed Apr 10, 2020
1 parent c975afa commit c87a86b
Showing 15 changed files with 545 additions and 39 deletions.
30 changes: 28 additions & 2 deletions jiant/__main__.py
@@ -536,6 +536,33 @@ def load_model_for_target_train_run(args, ckpt_path, model, strict, task, cuda_d
return to_train


def get_pretrain_stop_metric(early_stopping_method, pretrain_tasks):
"""
Get stop_metric, which is used for early stopping.
Parameters
-------------------
early_stopping_method: str,
pretrain_tasks: List[Task]
Returns
-------------------
stop_metric: str
"""
if early_stopping_method != "auto":
pretrain_names = [task.name for task in pretrain_tasks]
if early_stopping_method in pretrain_names:
index = pretrain_names.index(early_stopping_method)
stop_metric = pretrain_tasks[index].val_metric
else:
raise ValueError("args.early_stopping_method must be either 'auto' or a task name")

else:
stop_metric = pretrain_tasks[0].val_metric if len(pretrain_tasks) == 1 else "macro_avg"
return stop_metric


def main(cl_arguments):
""" Train a model for multitask-training."""
cl_args = handle_arguments(cl_arguments)
@@ -551,7 +578,6 @@ def main(cl_arguments):
tasks = sorted(set(pretrain_tasks + target_tasks), key=lambda x: x.name)
log.info("\tFinished loading tasks in %.3fs", time.time() - start_time)
log.info("\t Tasks: {}".format([task.name for task in tasks]))

# Build model
log.info("Building model...")
start_time = time.time()
@@ -567,7 +593,7 @@
if args.do_pretrain:
# Train on pretrain tasks
log.info("Training...")
stop_metric = pretrain_tasks[0].val_metric if len(pretrain_tasks) == 1 else "macro_avg"
stop_metric = get_pretrain_stop_metric(args.early_stopping_method, pretrain_tasks)
should_decrease = (
pretrain_tasks[0].val_metric_decreases if len(pretrain_tasks) == 1 else False
)
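For orientation, a minimal sketch of how the new helper resolves the stop metric. The stand-in tasks are hypothetical (SimpleNamespace objects exposing only the .name and .val_metric attributes the helper reads); the task names and metric strings are illustrative and assume jiant.__main__ is importable, none of it is taken from this PR's test suite.

# Illustrative usage of get_pretrain_stop_metric; task objects and metric
# names below are hypothetical stand-ins.
from types import SimpleNamespace

from jiant.__main__ import get_pretrain_stop_metric

mnli = SimpleNamespace(name="mnli", val_metric="mnli_accuracy")
mlm = SimpleNamespace(name="wikipedia_corpus_mlm", val_metric="wikipedia_corpus_mlm_perplexity")

# 'auto' with several pretraining tasks falls back to the macro average.
assert get_pretrain_stop_metric("auto", [mnli, mlm]) == "macro_avg"

# 'auto' with a single pretraining task uses that task's own validation metric.
assert get_pretrain_stop_metric("auto", [mnli]) == "mnli_accuracy"

# Naming a pretraining task early-stops on that task's metric alone.
assert get_pretrain_stop_metric("wikipedia_corpus_mlm", [mnli, mlm]) == "wikipedia_corpus_mlm_perplexity"

# Any other value raises: the option must be 'auto' or a pretraining task name.
try:
    get_pretrain_stop_metric("qqp", [mnli, mlm])
except ValueError:
    pass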
19 changes: 19 additions & 0 deletions jiant/config/base_mlm_roberta.conf
@@ -0,0 +1,19 @@
// Base config file for MLM experiments with RoBERTa
include "defaults.conf"

early_stopping_method=auto // Early stopping method. Options: task_name to only do early stopping based
// on a specific task, 'auto': use the macro_avg

// Multi-task Training
weighting_method = proportional // Weighting method for task sampling, relative to the number of
// training examples in each task:
// Options: uniform, power_<power>, softmax_<temp>
// proportional, proportional_log_batch, and
// proportional_log_example (plus the less-useful inverse,
// inverse_log_example, and inverse_log_batch).
// Additionally, we include the T5 method of examples-proportional-mixing.
// See relevant source code for details.
scaling_method = uniform // Method for scaling loss:
// Options: uniform, max_power_<power>, max_proportional,
// max_proportional_log, max_inverse, max_inverse_log
// max_epoch_<E1_E2_..._En>
4 changes: 4 additions & 0 deletions jiant/config/defaults.conf
@@ -178,6 +178,8 @@ max_epochs = -1 // If positive, maximum number of epochs (full pass over a task'
// especially if it's higher than one epoch's worth of steps, it's possible to
// significantly overshoot the intended number of epochs.

early_stopping_method=auto // Early stopping method. Options: task_name to only do early stopping based
// on a specific task, 'auto': use the macro_avg
patience = 5 // Patience in early stopping. Training will stop if performance does not improve at
// all in patience + 1 validations.
keep_all_checkpoints = 0 // If set, keep checkpoint files from every validation. Otherwise, keep
@@ -196,6 +198,8 @@ weighting_method = proportional // Weighting method for task sampling, relative
// proportional, proportional_log_batch, and
// proportional_log_example (plus the less-useful inverse,
// inverse_log_example, and inverse_log_batch).
// Additionally, we include the T5 method of examples_proportional_mixing.
// To use this, set weighting_method=examples_proportional_mixingK=104857
// See relevant source code for details.
scaling_method = uniform // Method for scaling loss:
// Options: uniform, max_power_<power>, max_proportional,
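The examples-proportional-mixing option referenced in the comment above follows the T5 recipe: each task is sampled in proportion to min(num_examples, K), so a single huge pretraining corpus cannot swamp the mixture. A sketch of that rule; the function name, task names, and example counts are illustrative, not jiant's internals.

# Sketch of T5-style examples-proportional mixing; names are illustrative,
# not the identifiers used inside jiant's weighting code.
def examples_proportional_weights(task_sizes, K):
    """task_sizes maps task name -> number of training examples; K caps each task."""
    capped = {name: min(n, K) for name, n in task_sizes.items()}
    total = sum(capped.values())
    return {name: c / total for name, c in capped.items()}

# With K = 104857, a 10M-example MLM corpus is capped to the same ceiling as any
# other large task instead of dominating sampling outright.
weights = examples_proportional_weights(
    {"wikipedia_corpus_mlm": 10_000_000, "commonsenseqa": 9_741}, K=104_857
)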
21 changes: 13 additions & 8 deletions jiant/huggingface_transformers_interface/modules.py
@@ -39,6 +39,7 @@ def __init__(self, args):
self._sep_id = None
self._pad_id = None
self._unk_id = None
self._mask_id = None

# If set, treat these special tokens as part of input segments other than A/B.
self._SEG_ID_CLS = None
@@ -270,6 +271,7 @@ def __init__(self, args):
self._cls_id = self.tokenizer.convert_tokens_to_ids("[CLS]")
self._pad_id = self.tokenizer.convert_tokens_to_ids("[PAD]")
self._unk_id = self.tokenizer.convert_tokens_to_ids("[UNK]")
self._mask_id = self.tokenizer.convert_tokens_to_ids("[MASK]")

self.parameter_setup(args)

@@ -305,7 +307,7 @@ def get_pretrained_lm_head(self):
)
lm_head = model_with_lm_head.cls
lm_head.predictions.decoder.weight = self.model.embeddings.word_embeddings.weight
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
return lm_head


class RobertaEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -327,6 +329,7 @@ def __init__(self, args):
self._cls_id = self.tokenizer.convert_tokens_to_ids("<s>")
self._pad_id = self.tokenizer.convert_tokens_to_ids("<pad>")
self._unk_id = self.tokenizer.convert_tokens_to_ids("<unk>")
self._mask_id = self.tokenizer.convert_tokens_to_ids("<mask>")

self.parameter_setup(args)

@@ -358,8 +361,8 @@ def get_pretrained_lm_head(self):
self.input_module, cache_dir=self.cache_dir
)
lm_head = model_with_lm_head.lm_head
lm_head.predictions.decoder.weight = self.model.embeddings.word_embeddings.weight
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
lm_head.decoder.weight = self.model.embeddings.word_embeddings.weight
return lm_head


class AlbertEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -381,6 +384,7 @@ def __init__(self, args):
self._cls_id = self.tokenizer.convert_tokens_to_ids("[CLS]")
self._pad_id = self.tokenizer.convert_tokens_to_ids("<pad>")
self._unk_id = self.tokenizer.convert_tokens_to_ids("<unk>")
self._mask_id = self.tokenizer.convert_tokens_to_ids("[MASK]")

self.parameter_setup(args)

@@ -416,7 +420,7 @@ def get_pretrained_lm_head(self):
)
lm_head = model_with_lm_head.predictions
lm_head.decoder.weight = self.model.embeddings.word_embeddings.weight
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
return lm_head


class XLNetEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -437,6 +441,7 @@ def __init__(self, args):
self._cls_id = self.tokenizer.convert_tokens_to_ids("<cls>")
self._pad_id = self.tokenizer.convert_tokens_to_ids("<pad>")
self._unk_id = self.tokenizer.convert_tokens_to_ids("<unk>")
self._mask_id = self.tokenizer.convert_tokens_to_ids("<mask>")

self.parameter_setup(args)

@@ -478,7 +483,7 @@ def get_pretrained_lm_head(self, args):
)
lm_head = model_with_lm_head.lm_loss
lm_head.weight = self.model.word_embedding.weight
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
return lm_head


class OpenAIGPTEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -541,7 +546,7 @@ def get_pretrained_lm_head(self, args):
)
lm_head = model_with_lm_head.lm_head
lm_head.weight = self.model.tokens_embed.weight[: lm_head.weight.size()[0]]
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
return lm_head


class GPT2EmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -603,7 +608,7 @@ def get_pretrained_lm_head(self):
)
lm_head = model_with_lm_head.lm_head
lm_head.weight = self.model.wte.weight[: lm_head.weight.size()[0]]
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
return lm_head


class TransfoXLEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -724,4 +729,4 @@ def get_pretrained_lm_head(self):
)
lm_head = model_with_lm_head.pred_layer
lm_head.proj.weight = self.model.embeddings.weight
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
return lm_head
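Note the recurring change above: every get_pretrained_lm_head now returns the bare head instead of wrapping it in nn.LogSoftmax. That matches the loss used by the new MLM forward pass, because F.cross_entropy applies log-softmax internally; a minimal sketch (shapes and the vocabulary size are illustrative, not taken from the PR):

# Why the heads now return raw logits: F.cross_entropy already applies
# log_softmax, so wrapping the head in nn.LogSoftmax would normalize twice.
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 8, 50265
logits = torch.randn(batch_size, seq_len, vocab_size)          # stands in for lm_head(sent_embs)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))   # stands in for MLM targets

loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))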
70 changes: 56 additions & 14 deletions jiant/models.py
@@ -45,7 +45,7 @@
from jiant.modules.span_modules import SpanClassifierModule
from jiant.huggingface_transformers_interface import input_module_uses_transformers
from jiant.tasks.edge_probing import EdgeProbingTask
from jiant.tasks.lm import LanguageModelingTask
from jiant.tasks.lm import AutoregressiveLanguageModelingTask, MaskedLanguageModelingTask
from jiant.tasks.lm_parsing import LanguageModelingParsingTask
from jiant.tasks.qa import MultiRCTask, ReCoRDTask
from jiant.tasks.seq2seq import Seq2SeqTask
@@ -76,6 +76,7 @@
format_output,
uses_cuda,
)
from jiant.utils.data_loaders import get_tokenizer

# Elmo stuff
# Look in $ELMO_SRC_DIR (e.g. /usr/share/jsalt/elmo) or download from web
@@ -158,18 +159,22 @@ def build_sent_encoder(args, vocab, d_emb, tasks, embedder, cove_layer):
)
d_sent = args.d_word
log.info("Using PRPN sentence encoder!")
elif any(isinstance(task, LanguageModelingTask) for task in tasks) or args.sent_enc == "bilm":
elif (
any(isinstance(task, AutoregressiveLanguageModelingTask) for task in tasks)
or args.sent_enc == "bilm"
):
assert_for_log(args.sent_enc in ["rnn", "bilm"], "Only RNNLM supported!")
assert_for_log(
not (
args.input_module == "elmo"
or args.input_module.startswith("bert")
or args.input_module.startswith("xlnet")
),
f"Using input_module = {args.input_module} for language modeling is probably not a "
"good idea, since it allows the language model to use information from the right-hand "
"context.",
)
if any(isinstance(task, AutoregressiveLanguageModelingTask) for task in tasks):
assert_for_log(
not (
args.input_module == "elmo"
or args.input_module.startswith("bert")
or args.input_module.startswith("xlnet")
),
f"Using input_module = {args.input_module} for language modeling is probably not a "
"good idea, since it allows the language model to use information from the right-hand "
"context.",
)
bilm = BiLMEncoder(d_emb, args.d_hid, args.d_hid, args.n_layers_enc)
sent_encoder = SentenceEncoder(
vocab,
@@ -549,7 +554,10 @@ def build_task_specific_modules(task, model, d_sent, d_emb, vocab, embedder, arg
hid2voc = build_lm(task, d_sent, args)
setattr(model, "%s_hid2voc" % task.name, hid2voc)
setattr(model, "%s_mdl" % task.name, hid2voc)
elif isinstance(task, LanguageModelingTask):
elif isinstance(task, MaskedLanguageModelingTask):
module = build_mlm(model.sent_encoder._text_field_embedder)
setattr(model, "%s_mdl" % task.name, module)
elif isinstance(task, AutoregressiveLanguageModelingTask):
assert not input_module_uses_transformers(args.input_module), (
"our LM Task does not support transformers, if you need them, try to update",
"corresponding parts of the code. You may find get_pretrained_lm_head and",
@@ -746,6 +754,12 @@ def build_lm(task, d_inp, args):
return hid2voc


def build_mlm(embedder):
" Build MLM components "
lm_head = embedder.get_pretrained_lm_head()
return lm_head


def build_span_classifier(task, d_sent, task_params):
module = SpanClassifierModule(task, d_sent, task_params, num_spans=task.num_spans)
return module
@@ -853,7 +867,9 @@ def forward(self, task, batch, predict=False):
task, (PairClassificationTask, PairRegressionTask, PairOrdinalRegressionTask)
):
out = self._pair_sentence_forward(batch, task, predict)
elif isinstance(task, LanguageModelingTask):
elif isinstance(task, MaskedLanguageModelingTask):
out = self._masked_lm_forward(batch, task, predict)
elif isinstance(task, AutoregressiveLanguageModelingTask):
if isinstance(self.sent_encoder._phrase_layer, ONLSTMStack) or isinstance(
self.sent_encoder._phrase_layer, PRPN
):
@@ -1160,6 +1176,32 @@ def _lm_forward(self, batch, task, predict):
pass
return out

def _masked_lm_forward(self, batch, task, predict):
"""
We currently only support RoBERTa-style dynamic masking, with the exact
setup and parameters as RoBERTa.
"""
out = {}
tokenizer_name = self.sent_encoder._text_field_embedder.input_module
text_embedder = self.sent_encoder._text_field_embedder
vocab_size = text_embedder.model.embeddings.word_embeddings.num_embeddings
input_key = text_embedder.tokenizer_required
mask_idx = text_embedder._mask_id
b_size, seq_len = batch["targs"].size()
inputs = batch["input"][input_key]
labels = batch["targs"]
inputs, labels, _, _, _, _ = task.mlm_dynamic_masking(
inputs, labels, mask_idx, tokenizer_name, self.sent_encoder
)
batch["input"][input_key] = inputs
sent_embs, sent_mask = self.sent_encoder(batch["input"], task)
module = getattr(self, "%s_mdl" % task.name)
logits = module.forward(sent_embs)
out["logits"] = logits
out["loss"] = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
out["n_exs"] = format_output(b_size, self._cuda_device)
return out

def _mc_forward(self, batch, task, predict):
""" Forward for a multiple choice question answering task """
out = {}
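The masking itself happens in task.mlm_dynamic_masking, which is not part of the hunks shown here. For readers following along, a hedged sketch of the standard RoBERTa-style dynamic masking such a helper implements: 15% of positions are selected fresh each step; of those, 80% become the mask token, 10% a random token, and 10% stay unchanged. The function name, return signature, and the -100 ignore-index convention are this sketch's assumptions, not necessarily jiant's.

# Illustrative sketch of RoBERTa-style dynamic masking, not the jiant helper.
# Assumes the standard 15% / 80-10-10 parameters; unselected positions get
# label -100 so a loss computed with ignore_index=-100 skips them.
import torch

def dynamic_mask(inputs, mask_idx, vocab_size, mlm_probability=0.15):
    inputs = inputs.clone()
    labels = inputs.clone()

    # Choose which positions participate in the MLM objective this step.
    # (A full implementation would also exclude padding and other special tokens.)
    selected = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
    labels[~selected] = -100

    # 80% of the selected positions are replaced by the mask token.
    masked = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & selected
    inputs[masked] = mask_idx

    # 10% are replaced by a random vocabulary token.
    random_tok = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & selected & ~masked
    inputs[random_tok] = torch.randint(vocab_size, labels.shape)[random_tok]

    # The remaining 10% keep their original token but still contribute to the loss.
    return inputs, labels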
6 changes: 2 additions & 4 deletions jiant/preprocess.py
@@ -54,6 +54,7 @@
from jiant.tasks import REGISTRY as TASKS_REGISTRY
from jiant.tasks.seq2seq import Seq2SeqTask
from jiant.tasks.tasks import SequenceGenerationTask, Task
from jiant.tasks.lm import MaskedLanguageModelingTask
from jiant.utils import config, serialize, utils, options
from jiant.utils.options import parse_task_list_arg

@@ -261,6 +262,7 @@ def _build_vocab(args: config.Params, tasks: List[Task], vocab_path: str):
for task in tasks: # add custom label namespaces
# TODO: surface more docs for add_task_label_vocab:
add_task_label_vocab(vocab, task)

if args.force_include_wsj_vocabulary:
# Add WSJ full vocabulary for PTB F1 parsing tasks.
add_wsj_vocab(vocab, args.data_dir)
@@ -661,10 +663,6 @@ def add_task_label_vocab(vocab, task):
return
log.info("\tTask '%s': adding vocab namespace '%s'", task.name, namespace)

if isinstance(task, SequenceGenerationTask):
for special in SPECIALS:
vocab.add_token_to_namespace(special, namespace)

for label in task.get_all_labels():
vocab.add_token_to_namespace(label, namespace)
