Adding Masked Language Modelling #1030

Merged
merged 154 commits on Apr 10, 2020
Changes from 136 commits

Commits (154 commits)
430f942
misc run scripts
phu-pmh Oct 30, 2019
39603c3
sbatch
phu-pmh Oct 31, 2019
9b324f9
sweep scripts
phu-pmh Nov 4, 2019
d3cc769
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 5, 2019
00bc40c
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 9, 2019
4e297b1
update
phu-pmh Nov 9, 2019
b75d0f5
qa
phu-pmh Nov 10, 2019
1aadf48
update
phu-pmh Nov 10, 2019
8993b9e
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 10, 2019
a3f10e2
update
phu-pmh Nov 13, 2019
aa0d8b4
update
phu-pmh Nov 13, 2019
275d7a3
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 13, 2019
4b6b939
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 16, 2019
7252ea5
update
phu-pmh Nov 16, 2019
f0d9c56
update
phu-pmh Nov 20, 2019
00223c6
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Nov 27, 2019
b0a8ec3
sb file
phu-pmh Dec 12, 2019
c4d2601
moving update_metrics to outside scope of dataparallel
Jan 14, 2020
acb9d24
fixing micro_avg calculation
Jan 16, 2020
8bdec95
undo debugging
Jan 16, 2020
0d879b1
Merge branch 'master' of https://github.com/nyu-mll/jiant
phu-pmh Jan 17, 2020
4f0a169
Merge branch 'master' into fix_dataparallel_metric_calculation
Jan 17, 2020
5bb8389
Fixing tests, moving update_metrics out of other tasks
Jan 17, 2020
fb59ecc
Merge branch 'master' of https://github.com/nyu-mll/jiant into fix_da…
Jan 17, 2020
04dbbda
Merge branch 'fix_dataparallel_metric_calculation' of https://github.…
Jan 17, 2020
3ddf564
remove extraneous change
Jan 17, 2020
e588909
MLM task
phu-pmh Jan 21, 2020
dfa9fd9
Added MLM task
phu-pmh Jan 21, 2020
46182a9
update
phu-pmh Jan 24, 2020
607bcd2
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
phu-pmh Jan 24, 2020
d1daf23
fix multiple choice dataparallel forward
Jan 25, 2020
9539302
Merge branch 'master' into fix_dataparallel_metric_calculation
Jan 25, 2020
fc5f026
update
phu-pmh Jan 27, 2020
ce7f5c2
add _mask_id to transformers
HaokunLiu Jan 28, 2020
ffc7354
Update
phu-pmh Jan 30, 2020
c50d75b
Merge branch 'master' of https://github.com/nyu-mll/jiant into MLM
phu-pmh Jan 30, 2020
9649224
Merge branch 'master' into fix_dataparallel_metric_calculation
Jan 30, 2020
69a9364
MLM update
phu-pmh Jan 30, 2020
697d62c
Merge branch 'add-_mask_id-to-transformers' into MLM
HaokunLiu Jan 30, 2020
a4666da
adding update_metrics abstraction
Jan 30, 2020
fa13f6f
delete update_metrics_ notation
Jan 30, 2020
6b61e8b
fixed wrong index problem
phu-pmh Jan 30, 2020
3e10e3b
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
phu-pmh Jan 30, 2020
afc0938
removed unrelated files
phu-pmh Jan 31, 2020
dcff7e7
removed unrelated files
phu-pmh Jan 31, 2020
1c1e6fb
removed unrelated files
phu-pmh Jan 31, 2020
f25ee99
fix PEP8
phu-pmh Jan 31, 2020
3f35212
Fixed get_pretained_lm_head for BERT and ALBERT
phu-pmh Jan 31, 2020
fc85270
spelling check
Feb 1, 2020
321bda8
black formatting
Feb 1, 2020
ae92b78
fixing tests
Feb 2, 2020
4f36878
bug fix
phu-pmh Feb 3, 2020
0467871
Adding batch_size constraints to multi-GPU setting
Feb 5, 2020
e3c5c79
adding documentation
Feb 5, 2020
6e96fd0
adding batch size test
Feb 5, 2020
845bf4f
Merge branch 'master' of https://github.com/nyu-mll/jiant into fix_da…
Feb 5, 2020
b41c268
black correct version
Feb 5, 2020
6f82412
Fixing batch size assertion
Feb 5, 2020
c749ea7
generalize batch size assertion for more than 2 GPU setting
Feb 5, 2020
73222a5
reducing label loops in code
Feb 6, 2020
fe39525
fixing span forward
Feb 8, 2020
745836d
Fixing span prediction forward for multi-GPU
invalid-email-address Feb 8, 2020
14caaab
fix commonsenseQA forward
invalid-email-address Feb 8, 2020
4271a7a
Merge branch 'master' of https://github.com/nyu-mll/jiant into MLM
phu-pmh Feb 10, 2020
918c0df
MLM
phu-pmh Feb 10, 2020
5ed0691
adding function documentation
Feb 11, 2020
ffac8bf
Merge branch 'master' into fix_dataparallel_metric_calculation
Feb 11, 2020
fe86d96
resolving nits, fixing seq_gen forward
Feb 11, 2020
eee439f
Merge branch 'fix_dataparallel_metric_calculation' of https://github.…
Feb 11, 2020
b61fa7c
remove nit
Feb 11, 2020
55312e8
fixing batch_size assert and SpanPrediction task
Feb 12, 2020
7d165cf
Remove debugging
Feb 12, 2020
52f66c7
Fix batch size mismatch multi-GPU test
Feb 12, 2020
a0220f8
Fix order of assert checking for batch size mismatch
Feb 12, 2020
fe89674
mlm training
phu-pmh Feb 12, 2020
2218e5b
update
phu-pmh Feb 14, 2020
cd75715
Merge branch 'fix_dataparallel_metric_calculation' of https://github.…
phu-pmh Feb 14, 2020
58b2914
sbatch
phu-pmh Feb 16, 2020
052b1c0
update
phu-pmh Feb 17, 2020
b26927a
data parallel
phu-pmh Feb 17, 2020
cd4b5a6
update data parallel stuffs
phu-pmh Feb 19, 2020
0d6d691
update MLM
phu-pmh Feb 20, 2020
b3617fa
using sequencelabel, using 1 paragraph per example
Feb 23, 2020
0af6476
update label mapping
phu-pmh Feb 24, 2020
e9f863c
adding exmaples-porportion-mixing
Feb 24, 2020
89e44c5
changing dataloader to work with wikitext103
Feb 24, 2020
0752771
weight sampling
Feb 24, 2020
5482ac2
add early stopping only onb one task
Mar 5, 2020
6d85b27
commit
phu-pmh Mar 6, 2020
d67e195
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
phu-pmh Mar 6, 2020
921e717
Merge branch 'master' of https://github.com/nyu-mll/jiant into MLM
Mar 8, 2020
05d5750
Cleaning up code
Mar 8, 2020
ddcd357
Removing unecessarily tracked git folders
Mar 8, 2020
9e4e3a7
Removing unnecesary changes
Mar 8, 2020
b9b5f57
revert README
Mar 8, 2020
6b4c9d5
revert README.md again
Mar 8, 2020
35130ca
Making more general for Transformer-based embedders
Mar 8, 2020
20de779
torch.uint8 -> torch.bool
Mar 8, 2020
4020c81
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
Mar 8, 2020
09f5903
Fixing indexing issues
Mar 8, 2020
4f45826
get rid of unecessary changes
Mar 8, 2020
8ac8c70
black cleanup
Mar 8, 2020
6cee66e
update
phu-pmh Mar 8, 2020
3709696
Prevent updating update_metrics twice in one step
Mar 10, 2020
3fb4e3e
ALBERT SOP update
phu-pmh Mar 16, 2020
a56b7c7
update
phu-pmh Mar 18, 2020
b84da1d
update
phu-pmh Mar 18, 2020
2a19c2c
update
phu-pmh Mar 18, 2020
e7acb76
add base_roberta
Mar 20, 2020
b1ac702
update
phu-pmh Mar 20, 2020
c5fddf0
reverting CCG edit added for debugging
phu-pmh Mar 20, 2020
9774b61
refactor defaults.conf
phu-pmh Mar 20, 2020
194c2d4
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
Mar 22, 2020
4be35b3
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
Mar 22, 2020
429be9a
black formatting
Mar 22, 2020
a9555b1
merge
Mar 22, 2020
13002f6
removed SOP task and mlm_manual_scaling
phu-pmh Mar 22, 2020
9e6bc5d
Fixing label namespace vocabulary creation, mergeing from master
Mar 22, 2020
a0aad25
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
Mar 22, 2020
85db63e
Deleting MLM weight
Mar 22, 2020
4536433
Merge branch 'master' into MLM
Mar 22, 2020
85b081b
black formatting
Mar 22, 2020
eabe292
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
Mar 22, 2020
09caf0f
Adding early_stopping_method to defaults.conf
Mar 22, 2020
a7f8f16
Fixing MLM with preprocessed wikitext103
Mar 24, 2020
94c32ae
Deleting intermediate class hierarchy for MLM
Mar 24, 2020
74d474b
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
Mar 24, 2020
1d20684
Correcting black
Mar 24, 2020
12c0da1
LanguageModelingTask -> AutoregressiveModelingTask
Mar 24, 2020
960cf63
code style
Mar 24, 2020
f0d3b6d
fixing MaskedLanguageModelTask
Mar 25, 2020
cd2042e
Fixing typo
Mar 25, 2020
cf7612a
Fixing label namespace
Mar 27, 2020
f21165c
extracting out masking portion
Mar 28, 2020
1f25078
Revert "extracting out masking portion"
Apr 2, 2020
fec67cb
Code cleanup
Apr 2, 2020
c766706
Adding tests for early_stpping_method
Apr 2, 2020
1a5c06b
Merge branch 'master' of https://github.com/nyu-mll/jiant into MLM
Apr 2, 2020
5c3ff7b
Adding pretrain_stop_metric
Apr 3, 2020
8ca1eba
Reverting get_data_iter
Apr 3, 2020
9b377ab
Reverting to get_data_iter
Apr 6, 2020
bf841e9
Fixing get_pretrained_lm_head for all embedder types
Apr 6, 2020
2349464
Extracting out MLM probability masking
Apr 7, 2020
cf223a4
Merge branch 'MLM' of https://github.com/nyu-mll/jiant into MLM
Apr 7, 2020
a3465c1
Move dynamic masking function to Task for easier testing
Apr 8, 2020
0f5b849
Adding unit tests for MLM
Apr 8, 2020
fb9ce83
Adding change to MLM forward function to expose more intermediate ste…
Apr 8, 2020
a59c762
Fixing code style
Apr 9, 2020
e9eb5f0
Adding more detailed instructions of how to generate Wikipedia data
Apr 10, 2020
1a76df0
Adding rest of MLM data generation code
Apr 10, 2020
34c924b
Black style and remove comment
Apr 10, 2020
da5fe19
black style
Apr 10, 2020
9446cb7
updating repro code for MLM data
phu-pmh Apr 10, 2020
3f6eb92
updating repro code for MLM data
phu-pmh Apr 10, 2020
12 changes: 10 additions & 2 deletions jiant/__main__.py
@@ -551,7 +551,6 @@ def main(cl_arguments):
tasks = sorted(set(pretrain_tasks + target_tasks), key=lambda x: x.name)
log.info("\tFinished loading tasks in %.3fs", time.time() - start_time)
log.info("\t Tasks: {}".format([task.name for task in tasks]))

# Build model
log.info("Building model...")
start_time = time.time()
@@ -567,7 +566,16 @@
if args.do_pretrain:
# Train on pretrain tasks
log.info("Training...")
stop_metric = pretrain_tasks[0].val_metric if len(pretrain_tasks) == 1 else "macro_avg"
if args.early_stopping_method != "auto":
[Review comment — Contributor]: How hard would it be to test this logic? Could we run this logic with a variety of task lists and early stopping config options, and make sure we always get the expected method back?
Errors here would be pretty hard to catch/debug. Seems worth doing even if it means that we'd have to pull this out as a method.
[Reply — Contributor Author]: It'll be reasonable to test. Adding that to to-do list.

pretrain_names = [task.name for task in pretrain_tasks]
if args.early_stopping_method in pretrain_names:
index = pretrain_names.index(args.early_stopping_method)
stop_metric = pretrain_tasks[index].val_metric
else:
raise ValueError("args.early_stopping_method must be either 'auto' or a task name")

else:
stop_metric = pretrain_tasks[0].val_metric if len(pretrain_tasks) == 1 else "macro_avg"
should_decrease = (
pretrain_tasks[0].val_metric_decreases if len(pretrain_tasks) == 1 else False
)
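A reviewer asks above whether this selection logic could be unit-tested, and the PR does add such tests in later commits. Purely as an illustration, here is a minimal sketch of what such a test could look like if the branch were pulled out into a helper; the helper name, task names, and metric names below are hypothetical, not the PR's actual API.

```python
# Sketch only: mirrors the stop-metric selection above, with hypothetical names.
from types import SimpleNamespace

import pytest


def get_pretrain_stop_metric(early_stopping_method, pretrain_tasks):
    # A task name selects that task's val_metric; "auto" falls back to the single
    # task's metric, or macro_avg when there are several pretraining tasks.
    if early_stopping_method != "auto":
        names = [task.name for task in pretrain_tasks]
        if early_stopping_method not in names:
            raise ValueError("args.early_stopping_method must be either 'auto' or a task name")
        return pretrain_tasks[names.index(early_stopping_method)].val_metric
    return pretrain_tasks[0].val_metric if len(pretrain_tasks) == 1 else "macro_avg"


def make_task(name):
    return SimpleNamespace(name=name, val_metric=name + "_accuracy")


def test_stop_metric_selection():
    tasks = [make_task("sst"), make_task("wikipedia_mlm")]
    assert get_pretrain_stop_metric("auto", tasks) == "macro_avg"
    assert get_pretrain_stop_metric("auto", tasks[:1]) == "sst_accuracy"
    assert get_pretrain_stop_metric("wikipedia_mlm", tasks) == "wikipedia_mlm_accuracy"
    with pytest.raises(ValueError):
        get_pretrain_stop_metric("not_a_task", tasks)
```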
19 changes: 19 additions & 0 deletions jiant/config/base_mlm_roberta.conf
@@ -0,0 +1,19 @@
// Base config file for MLM experiments with RoBERTa
include "defaults.conf"

early_stopping_method=auto // Early stopping method. Options: task_name to only do early stopping based
// on a specific task, 'auto': use the macro_avg

// Multi-task Training
weighting_method = proportional // Weighting method for task sampling, relative to the number of
// training examples in each task:
// Options: uniform, power_<power>, softmax_<temp>
// proportional, proportional_log_batch, and
// proportional_log_example (plus the less-useful inverse,
// inverse_log_example, and inverse_log_batch).
// Additionally, we include the T5 method of examples-proportional-mixing.
// See relevant source code for details.
scaling_method = uniform // Method for scaling loss:
// Options: uniform, max_power_<power>, max_proportional,
// max_proportional_log, max_inverse, max_inverse_log
// max_epoch_<E1_E2_..._En>
4 changes: 4 additions & 0 deletions jiant/config/defaults.conf
@@ -178,6 +178,8 @@ max_epochs = -1 // If positive, maximum number of epochs (full pass over a task'
// especially if it's higher than one epoch's worth of steps, it's possible to
// significantly overshoot the intended number of epochs.

early_stopping_method=auto // Early stopping method. Options: task_name to only do early stopping based
// on a specific task, 'auto': use the macro_avg
patience = 5 // Patience in early stopping. Training will stop if performance does not improve at
// all in patience + 1 validations.
keep_all_checkpoints = 0 // If set, keep checkpoint files from every validation. Otherwise, keep
@@ -196,6 +198,8 @@ weighting_method = proportional // Weighting method for task sampling, relative
// proportional, proportional_log_batch, and
// proportional_log_example (plus the less-useful inverse,
// inverse_log_example, and inverse_log_batch).
// Additionally, we include the T5 method of examples_proportional_mixing.
// To use this, set weighting_method=examples_proportional_mixingK=104857
// See relevant source code for details.
scaling_method = uniform // Method for scaling loss:
// Options: uniform, max_power_<power>, max_proportional,
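The examples-proportional-mixing method referenced in the comments above follows the T5 mixing rule: each task is sampled in proportion to its number of training examples, clipped at an artificial limit K, so a very large corpus (such as an MLM corpus) cannot completely drown out smaller tasks. A standalone sketch of the weight computation, not jiant's actual implementation (the function name and example counts are illustrative):

```python
# T5-style examples-proportional mixing: weight_m = min(e_m, K) / sum_n min(e_n, K),
# where e_m is task m's number of training examples and K caps very large tasks.
def examples_proportional_mixing(example_counts, K):
    clipped = [min(count, K) for count in example_counts]
    total = sum(clipped)
    return [c / total for c in clipped]


# e.g. a 10k-example task next to a 5M-example MLM corpus with K = 104857:
weights = examples_proportional_mixing([10_000, 5_000_000], K=104857)
# -> roughly [0.087, 0.913]; without the cap the small task would get ~0.002
```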
15 changes: 10 additions & 5 deletions jiant/huggingface_transformers_interface/modules.py
@@ -39,6 +39,7 @@ def __init__(self, args):
self._sep_id = None
self._pad_id = None
self._unk_id = None
self._mask_id = None

# If set, treat these special tokens as part of input segments other than A/B.
self._SEG_ID_CLS = None
@@ -270,6 +271,7 @@ def __init__(self, args):
self._cls_id = self.tokenizer.convert_tokens_to_ids("[CLS]")
self._pad_id = self.tokenizer.convert_tokens_to_ids("[PAD]")
self._unk_id = self.tokenizer.convert_tokens_to_ids("[UNK]")
self._mask_id = self.tokenizer.convert_tokens_to_ids("[MASK]")

self.parameter_setup(args)

@@ -327,6 +329,7 @@ def __init__(self, args):
self._cls_id = self.tokenizer.convert_tokens_to_ids("<s>")
self._pad_id = self.tokenizer.convert_tokens_to_ids("<pad>")
self._unk_id = self.tokenizer.convert_tokens_to_ids("<unk>")
self._mask_id = self.tokenizer.convert_tokens_to_ids("<mask>")

self.parameter_setup(args)

@@ -358,8 +361,8 @@ def get_pretrained_lm_head(self):
self.input_module, cache_dir=self.cache_dir
)
lm_head = model_with_lm_head.lm_head
lm_head.predictions.decoder.weight = self.model.embeddings.word_embeddings.weight
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
lm_head.decoder.weight = self.model.embeddings.word_embeddings.weight
return lm_head


class AlbertEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -381,6 +384,7 @@ def __init__(self, args):
self._cls_id = self.tokenizer.convert_tokens_to_ids("[CLS]")
self._pad_id = self.tokenizer.convert_tokens_to_ids("<pad>")
self._unk_id = self.tokenizer.convert_tokens_to_ids("<unk>")
self._mask_id = self.tokenizer.convert_tokens_to_ids("[MASK]")

self.parameter_setup(args)

@@ -437,6 +441,7 @@ def __init__(self, args):
self._cls_id = self.tokenizer.convert_tokens_to_ids("<cls>")
self._pad_id = self.tokenizer.convert_tokens_to_ids("<pad>")
self._unk_id = self.tokenizer.convert_tokens_to_ids("<unk>")
self._mask_id = self.tokenizer.convert_tokens_to_ids("<mask>")

self.parameter_setup(args)

@@ -478,7 +483,7 @@ def get_pretrained_lm_head(self, args):
)
lm_head = model_with_lm_head.lm_loss
lm_head.weight = self.model.word_embedding.weight
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
return lm_head


class OpenAIGPTEmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -541,7 +546,7 @@ def get_pretrained_lm_head(self, args):
)
lm_head = model_with_lm_head.lm_head
lm_head.weight = self.model.tokens_embed.weight[: lm_head.weight.size()[0]]
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
return lm_head


class GPT2EmbedderModule(HuggingfaceTransformersEmbedderModule):
@@ -603,7 +608,7 @@ def get_pretrained_lm_head(self):
)
lm_head = model_with_lm_head.lm_head
lm_head.weight = self.model.wte.weight[: lm_head.weight.size()[0]]
return nn.Sequential(lm_head, nn.LogSoftmax(dim=-1))
[Review comment — Contributor]: It's worth mentioning these cuts in the PR description.
return lm_head


class TransfoXLEmbedderModule(HuggingfaceTransformersEmbedderModule):
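Several get_pretrained_lm_head implementations above drop their trailing nn.LogSoftmax and return the bare head (the "cuts" the reviewer mentions). A plausible reading, not stated in the diff, is that the new MLM forward pass computes its loss with F.cross_entropy, which applies log-softmax internally, so a LogSoftmax baked into the head would be applied twice. A small self-contained check of that equivalence:

```python
# cross_entropy(logits) == nll_loss(log_softmax(logits)); adding another LogSoftmax
# in front of cross_entropy would therefore double-apply it.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # 4 token positions, vocabulary of 10
labels = torch.tensor([1, 3, 5, 7])

loss_from_logits = F.cross_entropy(logits, labels)
loss_from_log_probs = F.nll_loss(F.log_softmax(logits, dim=-1), labels)
assert torch.allclose(loss_from_logits, loss_from_log_probs)
```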
105 changes: 91 additions & 14 deletions jiant/models.py
@@ -45,7 +45,7 @@
from jiant.modules.span_modules import SpanClassifierModule
from jiant.huggingface_transformers_interface import input_module_uses_transformers
from jiant.tasks.edge_probing import EdgeProbingTask
from jiant.tasks.lm import LanguageModelingTask
from jiant.tasks.lm import AutoregressiveLanguageModelingTask, MaskedLanguageModelingTask
from jiant.tasks.lm_parsing import LanguageModelingParsingTask
from jiant.tasks.qa import MultiRCTask, ReCoRDTask
from jiant.tasks.seq2seq import Seq2SeqTask
@@ -76,6 +76,7 @@
format_output,
uses_cuda,
)
from jiant.utils.data_loaders import get_tokenizer

# Elmo stuff
# Look in $ELMO_SRC_DIR (e.g. /usr/share/jsalt/elmo) or download from web
@@ -158,18 +159,22 @@ def build_sent_encoder(args, vocab, d_emb, tasks, embedder, cove_layer):
)
d_sent = args.d_word
log.info("Using PRPN sentence encoder!")
elif any(isinstance(task, LanguageModelingTask) for task in tasks) or args.sent_enc == "bilm":
elif (
any(isinstance(task, AutoregressiveLanguageModelingTask) for task in tasks)
or args.sent_enc == "bilm"
):
assert_for_log(args.sent_enc in ["rnn", "bilm"], "Only RNNLM supported!")
assert_for_log(
not (
args.input_module == "elmo"
or args.input_module.startswith("bert")
or args.input_module.startswith("xlnet")
),
f"Using input_module = {args.input_module} for language modeling is probably not a "
"good idea, since it allows the language model to use information from the right-hand "
"context.",
)
if any(isinstance(task, AutoregressiveLanguageModelingTask) for task in tasks):
assert_for_log(
not (
args.input_module == "elmo"
or args.input_module.startswith("bert")
or args.input_module.startswith("xlnet")
),
f"Using input_module = {args.input_module} for language modeling is probably not a "
"good idea, since it allows the language model to use information from the right-hand "
"context.",
)
bilm = BiLMEncoder(d_emb, args.d_hid, args.d_hid, args.n_layers_enc)
sent_encoder = SentenceEncoder(
vocab,
@@ -549,7 +554,10 @@ def build_task_specific_modules(task, model, d_sent, d_emb, vocab, embedder, arg
hid2voc = build_lm(task, d_sent, args)
setattr(model, "%s_hid2voc" % task.name, hid2voc)
setattr(model, "%s_mdl" % task.name, hid2voc)
elif isinstance(task, LanguageModelingTask):
elif isinstance(task, MaskedLanguageModelingTask):
module = build_mlm(model.sent_encoder._text_field_embedder)
setattr(model, "%s_mdl" % task.name, module)
elif isinstance(task, AutoregressiveLanguageModelingTask):
assert not input_module_uses_transformers(args.input_module), (
"our LM Task does not support transformers, if you need them, try to update",
"corresponding parts of the code. You may find get_pretrained_lm_head and",
@@ -746,6 +754,12 @@ def build_lm(task, d_inp, args):
return hid2voc


def build_mlm(embedder):
" Build MLM components "
lm_head = embedder.get_pretrained_lm_head()
return lm_head


def build_span_classifier(task, d_sent, task_params):
module = SpanClassifierModule(task, d_sent, task_params, num_spans=task.num_spans)
return module
@@ -853,7 +867,9 @@ def forward(self, task, batch, predict=False):
task, (PairClassificationTask, PairRegressionTask, PairOrdinalRegressionTask)
):
out = self._pair_sentence_forward(batch, task, predict)
elif isinstance(task, LanguageModelingTask):
elif isinstance(task, MaskedLanguageModelingTask):
out = self._masked_lm_forward(batch, task, predict)
elif isinstance(task, AutoregressiveLanguageModelingTask):
if isinstance(self.sent_encoder._phrase_layer, ONLSTMStack) or isinstance(
self.sent_encoder._phrase_layer, PRPN
):
@@ -1160,6 +1176,67 @@ def _lm_forward(self, batch, task, predict):
pass
return out

def _masked_lm_forward(self, batch, task, predict):
"""
We currently only support RoBERTa-style dynamic masking, with the exact
setup and parameters as RoBERTa.
"""
mlm_probability = 0.15
out = {}
sent_encoder = self.sent_encoder
tokenizer_name = self.sent_encoder._text_field_embedder.input_module
vocab_size = (
self.sent_encoder._text_field_embedder.model.embeddings.word_embeddings.num_embeddings
)
tokenizer = get_tokenizer(tokenizer_name)
input_key = self.sent_encoder._text_field_embedder.tokenizer_required
mask_idx = self.sent_encoder._text_field_embedder._mask_id
b_size, seq_len = batch["targs"].size()
inputs = batch["input"][input_key]
labels = batch["targs"]
# Masking code from https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py
probability_matrix = torch.full(labels.shape, mlm_probability, device=inputs.device)
padding_mask = labels.eq(0)
probability_matrix.masked_fill_(padding_mask, value=0.0)

masked_indices = torch.bernoulli(probability_matrix).to(
device=inputs.device, dtype=torch.uint8
)
tokenizer_name = self.sent_encoder._text_field_embedder.tokenizer_required
labels, _ = self.sent_encoder._text_field_embedder.correct_sent_indexing(
[Review comment — Contributor @pyeres, Apr 7, 2020]: It looks like the labels will be modified (as a result of correct_sent_indexing()). It looks like the inputs aren't getting the same adjustment here. Is that intentional/correct?
{tokenizer_name: labels}
)
# We only compute loss on masked tokens
# nn.CrossEntropy ignores the indices with value = -100 by default.
# Therefore, we replace non-masked indices with -100 so that they get ignored
# in loss computation.
labels[~masked_indices] = -100

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
bernoulli_mask = torch.bernoulli(torch.full(labels.shape, 0.8)).to(
device=inputs.device, dtype=torch.uint8
)
indices_replaced = bernoulli_mask & masked_indices
inputs[indices_replaced] = mask_idx

# 10% of the time, we replace masked input tokens with random word
bernoulli_mask = torch.bernoulli(torch.full(labels.shape, 0.5)).to(
device=inputs.device, dtype=torch.uint8
)
indices_random = bernoulli_mask & masked_indices & ~indices_replaced
random_words = torch.randint(
len(tokenizer), labels.shape, dtype=torch.long, device=inputs.device
)
inputs[indices_random] = random_words[indices_random]
batch["input"][input_key] = inputs
sent_embs, sent_mask = self.sent_encoder(batch["input"], task)
module = getattr(self, "%s_mdl" % task.name)
logits = module.forward(sent_embs)
out["logits"] = logits
out["loss"] = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
out["n_exs"] = format_output(b_size, self._cuda_device)
return out

def _mc_forward(self, batch, task, predict):
""" Forward for a multiple choice question answering task """
out = {}
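_masked_lm_forward above follows the standard BERT/RoBERTa dynamic-masking recipe: 15% of non-padding positions are sampled as prediction targets, and of those, 80% are replaced with the mask token, 10% with a random token, and 10% left unchanged; only the sampled positions contribute to the loss (via cross_entropy's ignore_index of -100). A self-contained sketch of that recipe, decoupled from jiant's batch format (tensor names and the pad/mask ids are illustrative):

```python
# Standalone sketch of BERT/RoBERTa-style dynamic masking (80/10/10).
import torch


def dynamic_mask(inputs, pad_id, mask_id, vocab_size, mlm_probability=0.15):
    labels = inputs.clone()
    # Sample 15% of non-padding positions as prediction targets.
    prob = torch.full(labels.shape, mlm_probability)
    prob.masked_fill_(inputs.eq(pad_id), 0.0)
    masked = torch.bernoulli(prob).bool()
    labels[~masked] = -100  # ignored by F.cross_entropy

    # 80% of the targets are replaced with the mask token.
    replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
    inputs[replaced] = mask_id

    # Half of the remaining targets (10% overall) get a random token; the rest stay.
    randomized = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~replaced
    inputs[randomized] = torch.randint(vocab_size, labels.shape)[randomized]
    return inputs, labels  # note: inputs is modified in place, as in the PR's forward


# Example batch of token ids (pad_id=0, mask_id=4, vocabulary of 100).
ids = torch.randint(5, 100, (2, 8))
masked_ids, targets = dynamic_mask(ids, pad_id=0, mask_id=4, vocab_size=100)
```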
6 changes: 2 additions & 4 deletions jiant/preprocess.py
@@ -54,6 +54,7 @@
from jiant.tasks import REGISTRY as TASKS_REGISTRY
from jiant.tasks.seq2seq import Seq2SeqTask
from jiant.tasks.tasks import SequenceGenerationTask, Task
from jiant.tasks.lm import MaskedLanguageModelingTask
from jiant.utils import config, serialize, utils, options
from jiant.utils.options import parse_task_list_arg

@@ -261,6 +262,7 @@ def _build_vocab(args: config.Params, tasks: List[Task], vocab_path: str):
for task in tasks: # add custom label namespaces
# TODO: surface more docs for add_task_label_vocab:
add_task_label_vocab(vocab, task)

if args.force_include_wsj_vocabulary:
# Add WSJ full vocabulary for PTB F1 parsing tasks.
add_wsj_vocab(vocab, args.data_dir)
@@ -661,10 +663,6 @@ def add_task_label_vocab(vocab, task):
return
log.info("\tTask '%s': adding vocab namespace '%s'", task.name, namespace)

if isinstance(task, SequenceGenerationTask):
for special in SPECIALS:
vocab.add_token_to_namespace(special, namespace)

for label in task.get_all_labels():
vocab.add_token_to_namespace(label, namespace)
