Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dialogue dataset #6654

Merged
merged 32 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
38fa1b6
chatbot interface
yidong72 Apr 4, 2023
0069f3a
latest gradio
yidong72 Apr 4, 2023
87815b4
default greedy
yidong72 Apr 4, 2023
254007f
better chatbot
yidong72 Apr 5, 2023
9599032
Merge branch 'main' into chatbot_ui
yidong72 Apr 6, 2023
37f1905
handle preamble
yidong72 Apr 10, 2023
4770dac
Merge branch 'main' into chatbot_ui
yidong72 Apr 11, 2023
91993ef
added chatbot training capablity
yidong72 Apr 12, 2023
3f7cf33
added chatbot ui
yidong72 Apr 15, 2023
864faec
remove debug code
yidong72 Apr 15, 2023
757122e
default human
yidong72 Apr 16, 2023
be4b7eb
use special token for roles
yidong72 Apr 18, 2023
2ddbaa1
special tokens
yidong72 Apr 19, 2023
1c9260f
fix name
yidong72 Apr 25, 2023
6361295
new chat dataset
yidong72 May 4, 2023
93ecc33
fix the system token
yidong72 May 5, 2023
67ff1c8
upgrade gradio
yidong72 May 6, 2023
8020920
save the chat history
yidong72 May 8, 2023
57a97d6
Merge branch 'chatbot_ds' of github.com:NVIDIA/NeMo into chatbot_ds
yidong72 May 8, 2023
d3f91ee
update ui
May 9, 2023
48ef830
update chat interface
yidong72 May 9, 2023
21d476b
handles canonical form
yidong72 May 11, 2023
c550faf
Merge branch 'chatbot_ds' of github.com:NVIDIA/NeMo into chatbot_ds
yidong72 May 11, 2023
9bb735e
new sft chatbot
yidong72 May 15, 2023
6e47a60
Merge branch 'main' into chatbot_ds
yidong72 May 15, 2023
61b5094
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2023
2850396
change format
yidong72 May 15, 2023
91a968c
Merge branch 'chatbot_ds' of github.com:NVIDIA/NeMo into chatbot_ds
yidong72 May 15, 2023
33d1767
check extra_id in the tokenizer
yidong72 May 15, 2023
693b4a4
added vocab property check
yidong72 May 15, 2023
da64c7b
added missing file
yidong72 May 15, 2023
eac1407
Merge branch 'main' into chatbot_ds
MaximumEntropy May 16, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ web_server: False # whether launch the web inference server
share: False # whether create a public URL
username: test # user name for web client
password: test2 # password for web client
web_port: 9889 # the port number of the web server
web_port: 9889 # the port number of the web server
chat: False # use the chat interface
8 changes: 6 additions & 2 deletions examples/nlp/language_modeling/megatron_gpt_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.modules.common.megatron_web_server import get_demo
from nemo.collections.nlp.modules.common.megatron_web_server import get_chatbot_demo, get_demo
from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer
from nemo.collections.nlp.modules.common.text_generation_utils import generate
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
Expand Down Expand Up @@ -277,9 +277,13 @@ def main(cfg) -> None:
if cfg.server:
if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
if cfg.web_server:
if cfg.chat:
web_ui = get_chatbot_demo
else:
web_ui = get_demo
loop = asyncio.new_event_loop()
thread = threading.Thread(
target=get_demo,
target=web_ui,
daemon=True,
args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ model:
ffn_dropout: 0.0

data:
chat: False # whether use chatbot data or not
train_ds:
# Example of how to specify paths to multiple datasets
# file_names:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import torch

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
from nemo.utils import logging

__all__ = ['GPTSFTChatDataset']

IGNORE_INDEX = -100
END_SIGNAL = "\n"
END_NAME_SIGNAL = "\n"

SYSTEM_TOKEN = "<extra_id_0>System\n"
yidong72 marked this conversation as resolved.
Show resolved Hide resolved
TURN_TOKEN = "<extra_id_1>"

GUARD_RAIL_INSTRUCTION = {
"TEXT_TO_CANONICAL_FORM": "Given a dialogue, for each turn you need to generate a short summary called a canonical form. Generate the canonical form for the last turn in the dialogue.",
"CANONICAL_FORM_TO_TEXT": "Given a dialogue, for each turn we also have a short summary called a canonical form. Generate the canonical form given the last turn message and canonical form. Then generate the message.",
}


def _mask_targets(target, tokenized_lens, speakers, header_len, s_ids, tokenizer, mask_role):
cur_idx = header_len
tgt_len = target.shape[0]
for i, (tokenized_len, speaker, s_id) in enumerate(zip(tokenized_lens, speakers, s_ids)):
# note, sentence piece will add extra empty token in front. s_id has that extra token too
skip_name_len = len(tokenizer.text_to_ids(TURN_TOKEN + speaker + END_NAME_SIGNAL))
if cur_idx >= tgt_len:
break
elif cur_idx + tokenized_len < tgt_len:
# Check whether the mask is applied to the correct position, the first token is turn token: <extra_id_1>
# s_id[2:] skips the artifact empty token and the turn token
# target[cur_idx + 1:cur_idx + tokenized_len] skip the turn token
if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[2:]):
logging.warning("a sentence mismatches the corresponding piece " "in the conversation")
if i == 0:
# mask the first turn completely to provide at least one turn as context
target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX
elif speaker == mask_role:
# leave the first human tag unmasked
target[cur_idx + 1 : cur_idx + tokenized_len] = IGNORE_INDEX
else:
# mask up to the name end, need to remove one as skip name has an extra artifact empty token
target[cur_idx : cur_idx + skip_name_len - 1] = IGNORE_INDEX
cur_idx += tokenized_len


def cannonical_form_formater(cannoical_form):
return f'<extra_id_2>{cannoical_form}\n'


def _add_speaker_and_signal(header, source, mask_role, gtype):
"""Add speaker and start/end signal on each round."""
BEGIN_SIGNAL = ""
conversation = header
for i, sentence in enumerate(source):
sentence_from = sentence["from"]
role_token = TURN_TOKEN
if gtype is None:
sentence["value"] = (
BEGIN_SIGNAL + role_token + sentence_from + END_NAME_SIGNAL + sentence["value"] + END_SIGNAL
)
elif gtype == "TEXT_TO_CANONICAL_FORM":
sentence["value"] = (
BEGIN_SIGNAL
+ role_token
+ sentence_from
+ END_NAME_SIGNAL
+ sentence["value"]
+ END_SIGNAL
+ cannonical_form_formater(sentence['canonical_form'])
)
elif gtype == "CANONICAL_FORM_TO_TEXT":
sentence["value"] = (
BEGIN_SIGNAL
+ role_token
+ sentence_from
+ END_NAME_SIGNAL
+ cannonical_form_formater(sentence['canonical_form'])
+ sentence["value"]
+ END_SIGNAL
)
else:
raise ValueError(f"source type {gtype} not supported")
conversation += sentence["value"]
# if the last turn is not masked, add next token start token to the end, which will be included for loss calculation
if sentence_from != mask_role and i == len(source) - 1:
conversation += TURN_TOKEN
return conversation


def preprocess(
source: dict, tokenizer: TokenizerSpec,
):
"""
Given a conversation list. This transform:
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
"""
canonical_type = None
if 'type' in source:
canonical_type = source['type']
assert canonical_type in GUARD_RAIL_INSTRUCTION, f"source type {canonical_type} not supported"
# add end signal and concatenate together
conversation = source['system']
if canonical_type is not None:
conversation = conversation + '\n' + GUARD_RAIL_INSTRUCTION[canonical_type]
mask_role = source.get('mask', 'User')
header = f"{SYSTEM_TOKEN}{conversation}\n\n"
conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, canonical_type)
# tokenize conversations
input_ids = tokenizer.text_to_ids(conversation)
target = copy.deepcopy(input_ids)
header_len = len(tokenizer.text_to_ids(header))

ids = []
tokenized_lens = []
for s in source['conversations']:
tokenized_sentence = tokenizer.text_to_ids(s["value"])
ids.append(torch.tensor(tokenized_sentence))
# remove one token as it adds an empty token in front
tokenized_lens.append(len(tokenized_sentence) - 1)
speakers = [sentence["from"] for sentence in source['conversations']]
assert mask_role in speakers, "mask role not in the conversation"
target = torch.LongTensor(target)
# not going to train on the header
target[:header_len] = IGNORE_INDEX
input_ids = torch.LongTensor(input_ids)

_mask_targets(target, tokenized_lens, speakers, header_len, ids, tokenizer, mask_role)
mask = (target != IGNORE_INDEX).bool()
assert mask.sum().item() != 0, "mask is empty"
return dict(input_ids=input_ids, mask=mask)


class GPTSFTChatDataset(GPTSFTDataset):
def _build_samples_mapping(self):
super()._build_samples_mapping()
assert hasattr(self.tokenizer, "vocab"), "tokenizer should have vocab property, not supported"
assert '<extra_id_0>' in self.tokenizer.vocab, "<extra_id_0> not in the tokenizer vocab. not supported"
assert '<extra_id_1>' in self.tokenizer.vocab, "<extra_id_1> not in the tokenizer vocab. not supported"

def _process_example(self, example):
"""
Create an example by concatenating text and answer.
Truncation is carried out when needed, but it is performed only on the prompt side.
BOS, EOS, and SEP, are added if specified.
"""
result = preprocess(example, self.tokenizer)

return result

def collate_fn(self, batch):
input_ids = [item['input_ids'][:-1].tolist() for item in batch]
labels = [item['input_ids'][1:].tolist() for item in batch]
loss_mask = [item['mask'][1:].tolist() for item in batch]

max_length = max([len(x) for x in input_ids])
if max_length > self.max_seq_length:
# truncate the sequences if it is longer than max_seq_length
input_ids = [x[: self.max_seq_length] for x in input_ids]
labels = [x[: self.max_seq_length] for x in labels]
loss_mask = [x[: self.max_seq_length] for x in loss_mask]
# increase max length to nearest multiple of 4 or 8
if self.pad_to_max_length:
max_length = self.max_seq_length
else:
max_length = min(self.max_seq_length, self._round_to_nearest(max_length, 8))
assert max_length <= self.max_seq_length

attention_mask = [self._create_attention_mask(max_length) for _ in batch]
attention_mask = torch.stack(attention_mask)
position_ids = [list(range(max_length)) for _ in batch]
position_ids = torch.LongTensor(position_ids)
input_ids = torch.LongTensor(
self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id)
)
labels = torch.LongTensor(self._collate_item(labels, max_length=max_length, pad_id=self.tokenizer.eos_id))
loss_mask = torch.LongTensor(self._collate_item(loss_mask, max_length=max_length, pad_id=0))

processed_batch = {
'tokens': input_ids,
'labels': labels,
'attention_mask': attention_mask,
'loss_mask': loss_mask,
'position_ids': position_ids,
}

return processed_batch
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
get_datasets_weights_and_num_samples,
)
from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import (
MegatronPretrainingBatchSampler,
Expand Down Expand Up @@ -234,7 +235,11 @@ def _build_dataset(self, data_cfg, is_train=True):
num_train_samples_per_dataset = [[None]] * len(data_cfg.file_names)

for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset):
dataset = GPTSFTDataset(
if self.cfg.data.chat:
dataset_cls = GPTSFTChatDataset
else:
dataset_cls = GPTSFTDataset
dataset = dataset_cls(
file_path=file_path,
tokenizer=self.tokenizer,
max_seq_length=data_cfg.max_seq_length,
Expand Down
84 changes: 84 additions & 0 deletions nemo/collections/nlp/modules/common/chat_css.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

CSS = """
#chatbot .hll { background-color: #ffffcc }
#chatbot .c { color: #408080; font-style: italic }
#chatbot .err { border: 1px solid #FF0000 }
#chatbot .k { color: #008000; font-weight: bold }
#chatbot .o { color: #666666 }
#chatbot .ch { color: #408080; font-style: italic }
#chatbot .cm { color: #408080; font-style: italic }
#chatbot .cp { color: #BC7A00 }
#chatbot .cpf { color: #408080; font-style: italic }
#chatbot .c1 { color: #408080; font-style: italic }
#chatbot .cs { color: #408080; font-style: italic }
#chatbot .gd { color: #A00000 }
#chatbot .ge { font-style: italic }
#chatbot .gr { color: #FF0000 }
#chatbot .gh { color: #000080; font-weight: bold }
#chatbot .gi { color: #00A000 }
#chatbot .go { color: #888888 }
#chatbot .gp { color: #000080; font-weight: bold }
#chatbot .gs { font-weight: bold }
#chatbot .gu { color: #800080; font-weight: bold }
#chatbot .gt { color: #0044DD }
#chatbot .kc { color: #008000; font-weight: bold }
#chatbot .kd { color: #008000; font-weight: bold }
#chatbot .kn { color: #008000; font-weight: bold }
#chatbot .kp { color: #008000 }
#chatbot .kr { color: #008000; font-weight: bold }
#chatbot .kt { color: #B00040 }
#chatbot .m { color: #666666 }
#chatbot .s { color: #BA2121 }
#chatbot .na { color: #7D9029 }
#chatbot .nb { color: #008000 }
#chatbot .nc { color: #0000FF; font-weight: bold }
#chatbot .no { color: #880000 }
#chatbot .nd { color: #AA22FF }
#chatbot .ni { color: #999999; font-weight: bold }
#chatbot .ne { color: #D2413A; font-weight: bold }
#chatbot .nf { color: #0000FF }
#chatbot .nl { color: #A0A000 }
#chatbot .nn { color: #0000FF; font-weight: bold }
#chatbot .nt { color: #008000; font-weight: bold }
#chatbot .nv { color: #19177C }
#chatbot .ow { color: #AA22FF; font-weight: bold }
#chatbot .w { color: #bbbbbb }
#chatbot .mb { color: #666666 }
#chatbot .mf { color: #666666 }
#chatbot .mh { color: #666666 }
#chatbot .mi { color: #666666 }
#chatbot .mo { color: #666666 }
#chatbot .sa { color: #BA2121 }
#chatbot .sb { color: #BA2121 }
#chatbot .sc { color: #BA2121 }
#chatbot .dl { color: #BA2121 }
#chatbot .sd { color: #BA2121; font-style: italic }
#chatbot .s2 { color: #BA2121 }
#chatbot .se { color: #BB6622; font-weight: bold }
#chatbot .sh { color: #BA2121 }
#chatbot .si { color: #BB6688; font-weight: bold }
#chatbot .sx { color: #008000 }
#chatbot .sr { color: #BB6688 }
#chatbot .s1 { color: #BA2121 }
#chatbot .ss { color: #19177C }
#chatbot .bp { color: #008000 }
#chatbot .fm { color: #0000FF }
#chatbot .vc { color: #19177C }
#chatbot .vg { color: #19177C }
#chatbot .vi { color: #19177C }
#chatbot .vm { color: #19177C }
#chatbot .il { color: #666666 }
"""
Loading