update conversion scripts and __main__
thomwolf committed Jul 16, 2019
1 parent 352e3ff commit 1b35d05
Showing 11 changed files with 53 additions and 20 deletions.
28 changes: 21 additions & 7 deletions pytorch_transformers/__main__.py
@@ -1,14 +1,15 @@
# coding: utf8
def main():
import sys
if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet"]:
if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]:
print(
"Should be used as one of: \n"
">> `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
">> `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
">> `pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
">> `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]` or \n"
">> `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n"
">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n"
">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n"
">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n"
">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n"
">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT")
else:
if sys.argv[1] == "bert":
try:
@@ -86,7 +87,7 @@ def main():
else:
TF_CONFIG = ""
convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
else:
elif sys.argv[1] == "xlnet":
try:
from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
except ImportError:
@@ -104,11 +105,24 @@ def main():
PYTORCH_DUMP_OUTPUT = sys.argv[4]
if len(sys.argv) == 6:
FINETUNING_TASK = sys.argv[5]
else:
FINETUNING_TASK = None

convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
TF_CONFIG,
PYTORCH_DUMP_OUTPUT,
FINETUNING_TASK)
elif sys.argv[1] == "xlm":
from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch

if len(sys.argv) != 4:
# pylint: disable=line-too-long
print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`")
else:
XLM_CHECKPOINT_PATH = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3]

convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT)

if __name__ == '__main__':
main()
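
With "xlm" added to the accepted sub-commands, the entry point now dispatches six converters. Below is a minimal sketch of driving the updated dispatch from Python; the checkpoint and output paths are hypothetical placeholders, not files shipped with the library.

import sys
from pytorch_transformers.__main__ import main

# Simulate: pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT
sys.argv = [
    "pytorch_transformers",      # program name slot (only the argument count is validated)
    "xlm",                       # sub-command added in this commit
    "/tmp/xlm_checkpoint.pth",   # XLM_CHECKPOINT_PATH (hypothetical)
    "/tmp/xlm_dump",             # PYTORCH_DUMP_OUTPUT (hypothetical)
]
main()
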
5 changes: 4 additions & 1 deletion pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py
@@ -26,6 +26,9 @@
GPT2Model,
load_tf_weights_in_gpt2)

import logging
logging.basicConfig(level=logging.INFO)


def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
# Construct model
@@ -36,7 +39,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path):
model = GPT2Model(config)

# Load weights from numpy
load_tf_weights_in_gpt2(model, gpt2_checkpoint_path)
load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
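
Since load_tf_weights_in_gpt2 now receives the config as its second argument, external callers of the old two-argument form need the same update. A minimal sketch of the new call, assuming hypothetical local paths for the TF checkpoint and the GPT-2 config file:

from pytorch_transformers.modeling_gpt2 import GPT2Config, GPT2Model, load_tf_weights_in_gpt2

config = GPT2Config.from_json_file("/tmp/gpt2/hparams.json")     # hypothetical GPT-2 config file
model = GPT2Model(config)
# New signature: model, config, then the TF checkpoint path.
load_tf_weights_in_gpt2(model, config, "/tmp/gpt2/model.ckpt")   # hypothetical TF checkpoint
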
5 changes: 4 additions & 1 deletion pytorch_transformers/convert_openai_checkpoint_to_pytorch.py
@@ -26,6 +26,9 @@
OpenAIGPTModel,
load_tf_weights_in_openai_gpt)

import logging
logging.basicConfig(level=logging.INFO)


def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
# Construct model
@@ -36,7 +39,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
model = OpenAIGPTModel(config)

# Load weights from numpy
load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)
load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path)

# Save pytorch-model
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
9 changes: 4 additions & 5 deletions pytorch_transformers/convert_tf_checkpoint_to_pytorch.py
@@ -18,23 +18,22 @@
from __future__ import division
from __future__ import print_function

import os
import re
import argparse
import tensorflow as tf
import torch
import numpy as np

from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert

import logging
logging.basicConfig(level=logging.INFO)

def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
# Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config)

# Load weights from tf checkpoint
load_tf_weights_in_bert(model, tf_checkpoint_path)
load_tf_weights_in_bert(model, config, tf_checkpoint_path)

# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
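
The BERT script follows the same pattern and threads the config into load_tf_weights_in_bert. A minimal sketch of an end-to-end call, with hypothetical paths:

from pytorch_transformers.convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="/tmp/bert/bert_model.ckpt",   # hypothetical TF checkpoint prefix
    bert_config_file="/tmp/bert/bert_config.json",    # hypothetical BERT config
    pytorch_dump_path="/tmp/bert/pytorch_model.bin")  # output file for the converted weights
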
3 changes: 3 additions & 0 deletions pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py
@@ -36,6 +36,9 @@
else:
import pickle

import logging
logging.basicConfig(level=logging.INFO)

# We do this to be able to load python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
data_utils.Vocab = data_utils.TransfoXLTokenizer
3 changes: 2 additions & 1 deletion pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py
@@ -24,9 +24,10 @@
import numpy

from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME
from pytorch_transformers.modeling_xlm import (XLMConfig, XLMModel)
from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES

import logging
logging.basicConfig(level=logging.INFO)

def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path):
# Load checkpoint
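
The new "xlm" branch in __main__ forwards to convert_xlm_checkpoint_to_pytorch, which (per its imports) writes the PyTorch weights, a config file and a vocab file into the output folder. A minimal sketch that runs the converter directly; the checkpoint path is a hypothetical placeholder:

import os
from pytorch_transformers.convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch
from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME

dump_dir = "/tmp/xlm_dump"                  # hypothetical output folder
os.makedirs(dump_dir, exist_ok=True)
convert_xlm_checkpoint_to_pytorch("/tmp/xlm_checkpoint.pth", dump_dir)   # hypothetical XLM checkpoint
print("expected outputs:", os.path.join(dump_dir, WEIGHTS_NAME), os.path.join(dump_dir, CONFIG_NAME))
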
9 changes: 7 additions & 2 deletions pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py
@@ -40,6 +40,8 @@
"wnli": 2,
}

import logging
logging.basicConfig(level=logging.INFO)

def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
# Initialise PyTorch model
@@ -48,14 +50,17 @@ def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None):
finetuning_task = finetuning_task.lower() if finetuning_task is not None else ""
if finetuning_task in GLUE_TASKS_NUM_LABELS:
print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config)))
model = XLNetForSequenceClassification(config, num_labels=GLUE_TASKS_NUM_LABELS[finetuning_task])
config.finetuning_task = finetuning_task
config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task]
model = XLNetForSequenceClassification(config)
elif 'squad' in finetuning_task:
config.finetuning_task = finetuning_task
model = XLNetForQuestionAnswering(config)
else:
model = XLNetLMHeadModel(config)

# Load weights from tf checkpoint
load_tf_weights_in_xlnet(model, config, tf_checkpoint_path, finetuning_task)
load_tf_weights_in_xlnet(model, config, tf_checkpoint_path)

# Save pytorch-model
pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
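
Note the new pattern: instead of passing num_labels to the constructor, the converter records the task on the config and lets the model head read it from there. A minimal sketch of the same pattern in isolation, using the "wnli" entry shown in GLUE_TASKS_NUM_LABELS above and a hypothetical config file:

from pytorch_transformers.modeling_xlnet import XLNetConfig, XLNetForSequenceClassification

config = XLNetConfig.from_json_file("/tmp/xlnet/config.json")   # hypothetical XLNet config
config.finetuning_task = "wnli"   # lower-cased GLUE task name
config.num_labels = 2             # GLUE_TASKS_NUM_LABELS["wnli"]
model = XLNetForSequenceClassification(config)
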
2 changes: 2 additions & 0 deletions pytorch_transformers/modeling_xlnet.py
@@ -37,9 +37,11 @@
logger = logging.getLogger(__name__)

XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-pytorch_model.bin",
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin",
}
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
}

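
With xlnet-large-cased registered in both archive maps, the large checkpoint can now be resolved by name. A minimal sketch (the weights are downloaded from the S3 URL above on first use):

from pytorch_transformers.modeling_xlnet import XLNetModel

model = XLNetModel.from_pretrained("xlnet-large-cased")
print(sum(p.numel() for p in model.parameters()), "parameters loaded")
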
2 changes: 1 addition & 1 deletion pytorch_transformers/tokenization_transfo_xl.py
@@ -50,7 +50,7 @@
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'transfo-xl-wt103': 512,
'transfo-xl-wt103': None,
}

PRETRAINED_CORPUS_ARCHIVE_MAP = {
3 changes: 2 additions & 1 deletion pytorch_transformers/tokenization_utils.py
@@ -208,7 +208,8 @@ def _from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
# if we're using a pretrained model, ensure the tokenizer
# wont index sequences longer than the number of positional embeddings
max_len = cls.max_model_input_sizes[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
if max_len is not None and isinstance(max_len, (int, float)):
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)

# Merge resolved_vocab_files arguments in kwargs.
added_tokens_file = resolved_vocab_files.pop('added_tokens_file', None)
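
This guard is what makes None entries in the PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES maps workable: models such as Transformer-XL and XLNet use relative position encodings and have no hard input-length cap, so their tokenizers should not clamp max_len. A minimal sketch of the same check in isolation, with illustrative values:

def resolve_max_len(pretrained_size, kwargs):
    # Mirrors the new check: only clamp when the size map provides a number.
    if pretrained_size is not None and isinstance(pretrained_size, (int, float)):
        kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), pretrained_size)
    return kwargs

print(resolve_max_len(512, {}))    # {'max_len': 512} -> clamped to the model limit
print(resolve_max_len(None, {}))   # {}               -> left at the tokenizer default
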
4 changes: 3 additions & 1 deletion pytorch_transformers/tokenization_xlnet.py
@@ -32,12 +32,14 @@
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
}
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlnet-large-cased': 512,
'xlnet-base-cased': None,
'xlnet-large-cased': None,
}

SPIECE_UNDERLINE = u'▁'
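
The XLNet tokenizer mirrors this: xlnet-large-cased gains a SentencePiece vocab entry, and both size-map entries are now None, so neither XLNet tokenizer clamps max_len. A minimal sketch (the vocab file is downloaded from the URL above on first use):

from pytorch_transformers.tokenization_xlnet import XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-large-cased")
print(tokenizer.tokenize("Hello world"), tokenizer.max_len)   # max_len keeps its large default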
