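This repository is a from-scratch PyTorch implementation of the BERT model, together with examples of related downstream tasks. It also includes detailed explanations of how BERT and each downstream task work.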
BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
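For more on the Transformer, see the article This post is all you need (层层剥开Transformer): about 40,000 words, 50 figures, and 3 hands-on examples (translation, classification, couplet generation) that cover the Transformer end to end.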
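- 1. BERT principles, with NSP and MLM
- 2. Implementing the BERT network model from scratch (code)
- 3. Chinese text classification with a pretrained BERT model (code)
- 4. English textual entailment (MNLI) with a pretrained BERT model (code)
- 5. English multiple choice (SWAG) with a pretrained BERT model (code)
- 6. English question answering (SQuAD) with a pretrained BERT model (code)
- 7. Training BERT from scratch with the NSP and MLM tasks (code)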
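The MLM task named in the list above follows the masking rule from the BERT paper: about 15% of token positions are selected for prediction, and of those 80% are replaced with `[MASK]`, 10% with a random token, and 10% are left unchanged. The sketch below illustrates that rule only; it is not taken from this repository. The function name `mask_tokens`, the per-token coin flip, and the toy vocabulary are assumptions made for illustration.

```python
import random

MASK_TOKEN = "[MASK]"
SPECIAL_TOKENS = ("[CLS]", "[SEP]")


def mask_tokens(tokens, vocab, mask_prob=0.15, seed=None):
    """Return (masked_tokens, labels) following the 80/10/10 MLM rule.

    labels[i] holds the original token at positions selected for
    prediction and None everywhere else.
    """
    rng = random.Random(seed)
    masked = list(tokens)
    labels = [None] * len(tokens)
    for i, token in enumerate(tokens):
        if token in SPECIAL_TOKENS:       # never mask special tokens
            continue
        if rng.random() >= mask_prob:     # position not selected
            continue
        labels[i] = token
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK_TOKEN        # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return masked, labels


if __name__ == "__main__":
    sentence = ["[CLS]", "the", "cat", "sat", "on", "the", "mat", "[SEP]"]
    toy_vocab = ["dog", "ran", "under", "a", "rug"]  # hypothetical vocabulary
    print(mask_tokens(sentence, toy_vocab, seed=0))
```

Selecting each position independently is a simplification; an implementation may instead draw a fixed number of positions per sequence.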
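- `bert_base_chinese` contains the BERT base Chinese pretrained model and its configuration files.
- `bert_base_uncased_english` contains the BERT base English pretrained model and its configuration files. Model download: https://huggingface.co/bert-base-uncased/tree/main
  Note: the parameter `"pooler_type": "first_token_transform"` must be added to `config.json`.
- `data` contains the datasets used by the downstream tasks.
  - `SingleSentenceClassification` is the Toutiao (今日头条) 15-class Chinese dataset;
  - `PairSentenceClassification` is the MNLI (The Multi-Genre Natural Language Inference Corpus) dataset;
  - `MultipeChoice` is the SWAG multiple-choice dataset;
  - `SQuAD` is version 1.1 of the question answering dataset open-sourced by Stanford University;
  - `WikiText` is an English Wikipedia corpus used for model pretraining;
  - `SongCi` is a Song ci (宋词) poetry corpus used for Chinese model pretraining.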
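- `model` contains the implementation of each module.
  - `BasicBert` contains the basic BERT model implementation.
    - `MyTransformer.py` implements the self-attention mechanism;
    - `BertEmbedding.py` implements the Input Embedding;
    - `BertConfig.py` loads the open-source `config.json` configuration file;
    - `Bert.py` implements the BERT model.
  - `DownstreamTasks` contains the modules for the downstream tasks.
    - `BertForSentenceClassification.py` implements single-label sentence classification;
    - `BertForMultipleChoice.py` implements the multiple-choice model;
    - `BertForQuestionAnswering.py` implements the question answering (text span) model;
    - `BertForNSPAndMLM.py` implements BERT's two pretraining tasks.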
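- `Tasks` contains the training and inference code for each downstream task.
  - `TaskForSingleSentenceClassification.py` trains and runs inference for single-label, single-text classification; it can be used for ordinary text classification;
  - `TaskForPairSentenceClassification.py` trains and runs inference for text-pair classification; it can be used for entailment tasks (for example, the MNLI dataset);
  - `TaskForMultipleChoice.py` trains and runs inference for multiple choice; it can be used for answer-selection tasks (for example, the SWAG dataset);
  - `TaskForSQuADQuestionAnswering.py` trains and runs inference for question answering; it can be used for span-extraction tasks (for example, the SQuAD dataset);
  - `TaskForPretraining.py` implements BERT's two pretraining tasks, MLM and NSP; it can be used to pretrain a BERT model.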
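- `test` contains test cases for each module.
- `utils` contains the utility classes.
  - `data_helpers.py` handles data preprocessing and dataset construction for each downstream task;
  - `log_helper.py` is the logging module;
  - `creat_pretraining_data.py` builds the dataset for the BERT pretraining tasks.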
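The `config.json` note for `bert_base_uncased_english` can be applied by hand or with a short script. The following is a minimal sketch using only the standard library; the helper name `add_pooler_type` is hypothetical, and the path in the usage note assumes the script is run from the project root.

```python
import json
from pathlib import Path


def add_pooler_type(config_path, pooler_type="first_token_transform"):
    """Add the "pooler_type" key to a BERT config.json if it is missing."""
    path = Path(config_path)
    config = json.loads(path.read_text(encoding="utf-8"))
    # setdefault leaves an existing value untouched.
    config.setdefault("pooler_type", pooler_type)
    path.write_text(
        json.dumps(config, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    return config
```

For example, `add_pooler_type("bert_base_uncased_english/config.json")` rewrites the downloaded file in place and keeps every other key as it was.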
The Python version is 3.6; the versions of the other required packages are:
torch==1.5.0
torchtext==0.6.0
torchvision==0.6.0
transformers==4.5.1
numpy==1.19.5
pandas==1.1.5
scikit-learn==0.24.0
tqdm==4.61.0
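To confirm that an environment matches these pins before training, something like the sketch below may help. The helper names (`parse_pins`, `find_mismatches`, `installed_versions`) are hypothetical, and the lookup relies on `pkg_resources` from setuptools, which is assumed to be available in a Python 3.6 environment.

```python
PINNED = """
torch==1.5.0
torchtext==0.6.0
torchvision==0.6.0
transformers==4.5.1
numpy==1.19.5
pandas==1.1.5
scikit-learn==0.24.0
tqdm==4.61.0
"""


def parse_pins(text):
    """Parse 'name==version' lines into a {name: version} dict."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, _, version = line.partition("==")
        pins[name] = version
    return pins


def find_mismatches(pins, installed):
    """Return {package: (expected, found)} for every unsatisfied pin."""
    return {
        name: (want, installed.get(name))
        for name, want in pins.items()
        if installed.get(name) != want
    }


def installed_versions(names):
    """Look up installed versions; missing packages map to None."""
    import pkg_resources  # assumption: setuptools is installed

    found = {}
    for name in names:
        try:
            found[name] = pkg_resources.get_distribution(name).version
        except pkg_resources.DistributionNotFound:
            found[name] = None
    return found
```

Calling `find_mismatches(parse_pins(PINNED), installed_versions(parse_pins(PINNED)))` returns an empty dict when everything matches. The comparison is an exact string match, so a build with a local suffix such as `1.5.0+cu101` is reported as a mismatch even though it may work.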
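Download each dataset and the corresponding BERT pretrained model (if its directory is empty) and place them in the corresponding directories. For details, see the `README.md` file in each dataset directory under `data`.
Enter the `Tasks` directory and run the relevant model.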
python TaskForSingleSentenceClassification.py
Output:
-- INFO: Epoch: 0, Batch[0/4186], Train loss :2.862, Train acc: 0.125
-- INFO: Epoch: 0, Batch[10/4186], Train loss :2.084, Train acc: 0.562
-- INFO: Epoch: 0, Batch[20/4186], Train loss :1.136, Train acc: 0.812
-- INFO: Epoch: 0, Batch[30/4186], Train loss :1.000, Train acc: 0.734
...
-- INFO: Epoch: 0, Batch[4180/4186], Train loss :0.418, Train acc: 0.875
-- INFO: Epoch: 0, Train loss: 0.481, Epoch time = 1123.244s
...
-- INFO: Epoch: 9, Batch[4180/4186], Train loss :0.102, Train acc: 0.984
-- INFO: Epoch: 9, Train loss: 0.100, Epoch time = 1130.071s
-- INFO: Accurcay on val 0.884
-- INFO: Accurcay on val 0.888
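Per-batch lines like the ones above can be turned into records for plotting a loss or accuracy curve. This is a minimal sketch that assumes the log format shown in this README; the regular expression and the name `parse_batch_lines` are illustrative and not part of the repository.

```python
import re

# Matches e.g. "Epoch: 0, Batch[10/4186], Train loss :2.084, Train acc: 0.562"
BATCH_LINE = re.compile(
    r"Epoch:\s*(\d+), Batch\[(\d+)/(\d+)\],?\s*"
    r"Train loss\s*:\s*(\d+(?:\.\d+)?), Train acc:\s*(\d+(?:\.\d+)?)"
)


def parse_batch_lines(log_text):
    """Extract one dict per batch line found in the log text."""
    records = []
    for match in BATCH_LINE.finditer(log_text):
        epoch, batch, total, loss, acc = match.groups()
        records.append(
            {
                "epoch": int(epoch),
                "batch": int(batch),
                "total": int(total),
                "loss": float(loss),
                "acc": float(acc),
            }
        )
    return records


if __name__ == "__main__":
    sample = (
        "-- INFO: Epoch: 0, Batch[0/4186], Train loss :2.862, Train acc: 0.125\n"
        "-- INFO: Epoch: 0, Batch[10/4186], Train loss :2.084, Train acc: 0.562\n"
    )
    for record in parse_batch_lines(sample):
        print(record)
```

Epoch summary lines and validation lines do not match the pattern and are skipped.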
python TaskForPairSentenceClassification.py
Output:
-- INFO: Epoch: 0, Batch[0/17181], Train loss :1.082, Train acc: 0.438
-- INFO: Epoch: 0, Batch[10/17181], Train loss :1.104, Train acc: 0.438
-- INFO: Epoch: 0, Batch[20/17181], Train loss :1.129, Train acc: 0.250
-- INFO: Epoch: 0, Batch[30/17181], Train loss :1.063, Train acc: 0.375
...
-- INFO: Epoch: 0, Batch[17180/17181], Train loss :0.367, Train acc: 0.909
-- INFO: Epoch: 0, Train loss: 0.589, Epoch time = 2610.604s
...
-- INFO: Epoch: 9, Batch[0/17181], Train loss :0.064, Train acc: 1.000
-- INFO: Epoch: 9, Train loss: 0.142, Epoch time = 2542.781s
-- INFO: Accurcay on val 0.797
-- INFO: Accurcay on val 0.810
python TaskForMultipleChoice.py
Output:
[2021-11-11 21:32:50] - INFO: Epoch: 0, Batch[0/4597], Train loss :1.433, Train acc: 0.250
[2021-11-11 21:32:58] - INFO: Epoch: 0, Batch[10/4597], Train loss :1.277, Train acc: 0.438
[2021-11-11 21:33:01] - INFO: Epoch: 0, Batch[20/4597], Train loss :1.249, Train acc: 0.438
......
[2021-11-11 21:58:34] - INFO: Epoch: 0, Batch[4590/4597], Train loss :0.489, Train acc: 0.875
[2021-11-11 21:58:36] - INFO: Epoch: 0, Batch loss :0.786, Epoch time = 1546.173s
[2021-11-11 21:28:55] - INFO: Epoch: 0, Batch[0/4597], Train loss :1.433, Train acc: 0.250
[2021-11-11 21:30:52] - INFO: He is throwing darts at a wall. A woman, squats alongside flies side to side with his gun. ## False
[2021-11-11 21:30:52] - INFO: He is throwing darts at a wall. A woman, throws a dart at a dartboard. ## False
[2021-11-11 21:30:52] - INFO: He is throwing darts at a wall. A woman, collapses and falls to the floor. ## False
[2021-11-11 21:30:52] - INFO: He is throwing darts at a wall. A woman, is standing next to him. ## True
[2021-11-11 21:30:52] - INFO: Accuracy on val 0.794
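The four logged lines share one context and differ only in the ending, which is the usual way a multiple-choice example is presented to BERT: each candidate is paired with the context, each pair gets a score, and the highest score wins. The sketch below shows that arrangement under stated assumptions: the function names and the scores are made up for illustration, and the repository's own encoding may differ.

```python
def build_choice_pairs(context, stem, endings):
    """Pair the context with each candidate continuation (stem + ending)."""
    return [(context, f"{stem} {ending}") for ending in endings]


def pick_answer(scores):
    """Return the index of the highest-scoring candidate."""
    return max(range(len(scores)), key=scores.__getitem__)


if __name__ == "__main__":
    pairs = build_choice_pairs(
        "He is throwing darts at a wall.",
        "A woman,",
        [
            "squats alongside flies side to side with his gun.",
            "throws a dart at a dartboard.",
            "collapses and falls to the floor.",
            "is standing next to him.",
        ],
    )
    hypothetical_scores = [0.10, 0.20, 0.05, 0.65]  # made-up model outputs
    best = pick_answer(hypothetical_scores)
    print(pairs[best])
```

In a real model each pair would be tokenized as one two-segment input and the scores would come from a classification head over the pooled output.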
python TaskForSQuADQuestionAnswering.py
Output:
[2022-01-02 14:42:17]Cache file ~/BertWithPretrained/data/SQuAD/dev-v1_128_384_64.pt does not exist; reprocessing and caching!
[2022-01-02 14:42:17] - DEBUG: <<<<<<<< Entering a new example >>>>>>>>>
[2022-01-02 14:42:17] - DEBUG: ## Preprocessing data utils.data_helpers is_training = False
[2022-01-02 14:42:17] - DEBUG: ## Question id: 56be5333acb8001400a5030d
[2022-01-02 14:42:17] - DEBUG: ## Original question text: Which performers joined the headliner during the Super Bowl 50 halftime show?
[2022-01-02 14:42:17] - DEBUG: ## Original context text: CBS broadcast Super Bowl 50 in the U.S., and charged an average of $5 million for a ....
[2022-01-02 14:42:17]- DEBUG: ## Context length: 87, remaining length rest_len: 367
[2022-01-02 14:42:17] - DEBUG: ## input_tokens: ['[CLS]', 'which', 'performers', 'joined', 'the', 'headline', '##r', 'during', 'the', ...]
[2022-01-02 14:42:17] - DEBUG: ## input_ids:[101, 2029, 9567, 2587, 1996, 17653, 2099, 2076, 1996, 3565, 4605, 2753, 22589, 2265, 1029, 102, 6568, ....]
[2022-01-02 14:42:17] - DEBUG: ## segment ids:[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]
[2022-01-02 14:42:17] - DEBUG: ## orig_map:{16: 0, 17: 1, 18: 2, 19: 3, 20: 4, 21: 5, 22: 6, 23: 7, 24: 7, 25: 7, 26: 7, 27: 7, 28: 8, 29: 9, 30: 10,....}
[2022-01-02 14:42:17] - DEBUG: ======================
....
[2022-01-02 15:13:50] - INFO: Epoch:0, Batch[810/7387] Train loss: 0.998, Train acc: 0.708
[2022-01-02 15:13:55] - INFO: Epoch:0, Batch[820/7387] Train loss: 1.130, Train acc: 0.708
[2022-01-02 15:13:59] - INFO: Epoch:0, Batch[830/7387] Train loss: 1.960, Train acc: 0.375
[2022-01-02 15:14:04] - INFO: Epoch:0, Batch[840/7387] Train loss: 1.933, Train acc: 0.542
......
[2022-01-02 15:15:27] - INFO: ### Quesiotn: [CLS] when was the first university in switzerland founded..
[2022-01-02 15:15:27] - INFO: ## Predicted answer: 1460
[2022-01-02 15:15:27] - INFO: ## True answer: 1460
[2022-01-02 15:15:27] - INFO: ## True answer idx: (tensor(46, tensor(47))
[2022-01-02 15:15:27] - INFO: ### Quesiotn: [CLS] how many wards in plymouth elect two councillors?
[2022-01-02 15:15:27] - INFO: ## Predicted answer: 17 of which elect three .....
[2022-01-02 15:15:27] - INFO: ## True answer: three
[2022-01-02 15:15:27] - INFO: ## True answer idx: (tensor(25, tensor(25))
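The DEBUG lines in this log show how one SQuAD example is laid out: `[CLS]`, the question tokens, `[SEP]`, then the context tokens, with segment id 0 for the question part and 1 for the context part, and an `orig_map` from each context token position back to the index of the word it came from. The sketch below reproduces that layout under stated assumptions. `toy_tokenize` is a hypothetical stand-in for a real WordPiece tokenizer, and the budget of 384 is inferred from the logged values (a 14-token question, three special tokens, and `rest_len` of 367), not read from the repository.

```python
def toy_tokenize(word):
    """Hypothetical stand-in for WordPiece: split long words in two."""
    word = word.lower()
    return [word] if len(word) <= 6 else [word[:6], "##" + word[6:]]


def build_inputs(question_tokens, context_words, tokenize=toy_tokenize, max_len=384):
    """Return (tokens, segment_ids, orig_map, rest_len) for one example."""
    tokens = ["[CLS]"] + list(question_tokens) + ["[SEP]"]
    segment_ids = [0] * len(tokens)
    # Room left for context tokens: the budget minus the question tokens,
    # [CLS], and the two [SEP] markers (assumed formula).
    rest_len = max_len - len(question_tokens) - 3
    orig_map = {}
    for word_idx, word in enumerate(context_words):
        for piece in tokenize(word):
            orig_map[len(tokens)] = word_idx  # token position -> word index
            tokens.append(piece)
            segment_ids.append(1)
    tokens.append("[SEP]")
    segment_ids.append(1)
    return tokens, segment_ids, orig_map, rest_len


if __name__ == "__main__":
    result = build_inputs(["who", "won"], ["Denver", "Broncos", "won"])
    for part in result:
        print(part)
```

Several positions mapping to the same word index, as with 23 through 27 in the log, means that word was split into several sub-word pieces. The sketch does not truncate or window contexts longer than `rest_len`.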
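When the run finishes, a prediction file named `best_result.json` is generated in the `data/SQuAD` directory. Switch to that directory and run the following command to get the evaluation result on `dev-v1.1.json`: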
python evaluate-v1.1.py dev-v1.1.json best_result.json
"exact_match" : 80.879848628193, "f1": 88.338575234135