Dialogue dataset #6654

Merged — 32 commits, merged May 16, 2023
Changes from 1 commit
Commits (32)
38fa1b6
chatbot interface
yidong72 Apr 4, 2023
0069f3a
latest gradio
yidong72 Apr 4, 2023
87815b4
default greedy
yidong72 Apr 4, 2023
254007f
better chatbot
yidong72 Apr 5, 2023
9599032
Merge branch 'main' into chatbot_ui
yidong72 Apr 6, 2023
37f1905
handle preamble
yidong72 Apr 10, 2023
4770dac
Merge branch 'main' into chatbot_ui
yidong72 Apr 11, 2023
91993ef
added chatbot training capablity
yidong72 Apr 12, 2023
3f7cf33
added chatbot ui
yidong72 Apr 15, 2023
864faec
remove debug code
yidong72 Apr 15, 2023
757122e
default human
yidong72 Apr 16, 2023
be4b7eb
use special token for roles
yidong72 Apr 18, 2023
2ddbaa1
special tokens
yidong72 Apr 19, 2023
1c9260f
fix name
yidong72 Apr 25, 2023
6361295
new chat dataset
yidong72 May 4, 2023
93ecc33
fix the system token
yidong72 May 5, 2023
67ff1c8
upgrade gradio
yidong72 May 6, 2023
8020920
save the chat history
yidong72 May 8, 2023
57a97d6
Merge branch 'chatbot_ds' of github.com:NVIDIA/NeMo into chatbot_ds
yidong72 May 8, 2023
d3f91ee
update ui
May 9, 2023
48ef830
update chat interface
yidong72 May 9, 2023
21d476b
handles canonical form
yidong72 May 11, 2023
c550faf
Merge branch 'chatbot_ds' of github.com:NVIDIA/NeMo into chatbot_ds
yidong72 May 11, 2023
9bb735e
new sft chatbot
yidong72 May 15, 2023
6e47a60
Merge branch 'main' into chatbot_ds
yidong72 May 15, 2023
61b5094
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2023
2850396
change format
yidong72 May 15, 2023
91a968c
Merge branch 'chatbot_ds' of github.com:NVIDIA/NeMo into chatbot_ds
yidong72 May 15, 2023
33d1767
check extra_id in the tokenizer
yidong72 May 15, 2023
693b4a4
added vocab property check
yidong72 May 15, 2023
da64c7b
added missing file
yidong72 May 15, 2023
eac1407
Merge branch 'main' into chatbot_ds
MaximumEntropy May 16, 2023
handle preamble
Signed-off-by: Yi Dong <[email protected]>
yidong72 committed Apr 10, 2023
commit 37f190537260366f75469383deac9d9c9c12ae57
3 changes: 2 additions & 1 deletion examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -260,7 +260,8 @@ def main(cfg) -> None:
     if cfg.web_server:
         loop = asyncio.new_event_loop()
         thread = threading.Thread(
-            target=get_chatbot_demo,
+            # target=get_chatbot_demo,
+            target=get_demo,
             daemon=True,
             args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop),
         )
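Note: the hunk above hard-wires the plain completion UI by commenting out the chat target. A minimal sketch of a config-driven alternative that keeps both demos reachable — the `chat` flag here is hypothetical and not part of this PR:

# Sketch only: `cfg.chat` is a hypothetical config option, not introduced by this commit.
demo_target = get_chatbot_demo if cfg.get("chat", False) else get_demo
thread = threading.Thread(
    target=demo_target,
    daemon=True,
    args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop),
)
thread.start()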
83 changes: 56 additions & 27 deletions nemo/collections/nlp/modules/common/megatron_web_server.py
@@ -26,26 +26,57 @@
__all__ = ['RetroDemoWebApp', 'get_demo']


-def create_gen_function(port=5555):
-    def get_generation(prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition, end_strings):
-        data = {
-            "sentences": [prompt],
-            "tokens_to_generate": int(token_to_gen),
-            "temperature": temp,
-            "add_BOS": add_BOS,
-            "top_k": top_k,
-            "top_p": top_p,
-            "greedy": greedy,
-            "all_probs": False,
-            "repetition_penalty": repetition,
-            "min_tokens_to_generate": int(min_tokens),
-            "end_strings": [i.strip() for i in end_strings.split(',') if len(i) != 0],
-        }
-        response = text_generation(data, port=port)
-        sentences = response['sentences']
-        return sentences[0]
+DEFAULT_SYSTEM = "As an AI assistant, you strive to be helpful, polite, honest, sophisticated, emotionally aware, and humble yet knowledgeable. You are always happy to assist with almost anything, and you will do your best to understand exactly what is needed. Your goal is to avoid providing false or misleading information, and you will indicate when you are unsure about the correct response. However, you are practical and genuinely try to perform your duties to the best of your abilities, and you don't let excessive caution get in the way of being useful to the user."
+SYSTEM_TOKEN = '<extra_id_0>'
+HUMAN_TOKEN = '<extra_id_1>'
+ASSITANT_TOKEN = '<extra_id_2>'
+
+
+def create_gen_function(port=5555, chat=False):
+    if chat:
+        def get_generation(prompt, preamble, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition, end_strings):
+            if preamble is not None and preamble != '':
+                prompt = SYSTEM_TOKEN + preamble + prompt
+            data = {
+                "sentences": [prompt],
+                "tokens_to_generate": int(token_to_gen),
+                "temperature": temp,
+                "add_BOS": add_BOS,
+                "top_k": top_k,
+                "top_p": top_p,
+                "greedy": greedy,
+                "all_probs": False,
+                "repetition_penalty": repetition,
+                "min_tokens_to_generate": int(min_tokens),
+                "end_strings": [i.strip() for i in end_strings.split(',') if len(i) != 0],
+            }
+            response = text_generation(data, port=port)
+            sentences = response['sentences']
+            bot_message = sentences[0]
+            bot_message = bot_message[len(prompt):]
+            return bot_message
+    else:
+        def get_generation(prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition, end_strings):
+            data = {
+                "sentences": [prompt],
+                "tokens_to_generate": int(token_to_gen),
+                "temperature": temp,
+                "add_BOS": add_BOS,
+                "top_k": top_k,
+                "top_p": top_p,
+                "greedy": greedy,
+                "all_probs": False,
+                "repetition_penalty": repetition,
+                "min_tokens_to_generate": int(min_tokens),
+                "end_strings": [i.strip() for i in end_strings.split(',') if len(i) != 0],
+            }
+            response = text_generation(data, port=port)
+            sentences = response['sentences']
+            bot_message = sentences[0]
+            bot_message = bot_message[len(prompt):]
+            return bot_message
+    return get_generation



def get_demo(share, username, password, server_port=5555, web_port=9889, loop=None):
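Note: the factory above returns a closure that posts the request to the model's text-generation server. A minimal sketch of calling the chat-mode closure directly, outside the Gradio UI; the port, preamble, and sampling values below are arbitrary examples, and it assumes a text-generation server (e.g. one started by megatron_gpt_eval.py) is already listening:

# Sketch: direct use of the chat-mode closure; values are illustrative only.
generate = create_gen_function(port=5555, chat=True)

reply = generate(
    HUMAN_TOKEN + "What is NeMo?\n" + ASSITANT_TOKEN,  # dialogue so far, ending with the assistant token
    "You are a helpful assistant.",                     # preamble, prepended after SYSTEM_TOKEN
    False,              # greedy
    False,              # add_BOS
    256,                # tokens to generate
    1,                  # min tokens to generate
    0.7,                # temperature
    0.95,               # top_p
    1,                  # top_k
    1.2,                # repetition penalty
    "<|endoftext|>,",   # end strings (comma separated)
)
print(reply)  # generated text with the full prompt prefix already stripped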
@@ -73,7 +104,7 @@ def get_demo(share, username, password, server_port=5555, web_port=9889, loop=None):
                 output_box = gr.Textbox(value="", label="Output")
                 btn = gr.Button(value="Submit")
                 btn.click(
-                    create_gen_function(server_port),
+                    create_gen_function(server_port, chat=False),
                     inputs=[
                         input_prompt,
                         greedy_flag,
@@ -108,34 +139,32 @@ def get_chatbot_demo(share, username, password, server_port=5555, web_port=9889,
                 )
                 end_strings = gr.Textbox(label="End strings (comma separated)", value="<|endoftext|>,", lines=1,)
             with gr.Column(scale=1, min_width=800):
+                preamble = gr.Textbox(label="System", value=DEFAULT_SYSTEM, lines=2,)
                 chatbot = Chatbot()
-                msg = gr.Textbox()
+                msg = gr.Textbox(label="User", value="", lines=1,)
                 clear = gr.Button("Clear")
-                HUMAN_TOKEN = '<extra_id_1>'
-                ASSITANT_TOKEN = '<extra_id_2>'
 
                 def user(user_message, history):
                     return "", history + [[user_message, None]]
 
-                def bot(history, greedy_flag, add_BOS, token_to_gen, min_token_to_gen, temperature, top_p, top_k, repetition_penality, end_strings):
+                def bot(history, preamble, greedy_flag, add_BOS, token_to_gen, min_token_to_gen, temperature, top_p, top_k, repetition_penality, end_strings):
                     prompts = history[:-1]
                     prompt_text = ''
                     for prompt in prompts:
                         prompt_text += HUMAN_TOKEN + prompt[0].replace('<br>', '\n') + '\n' + ASSITANT_TOKEN + prompt[1].replace('<br>', '\n') + '\n'
                     prompt_text += HUMAN_TOKEN + history[-1][0].replace('<br>', '\n') + '\n' + ASSITANT_TOKEN
-                    bot_message = create_gen_function(server_port)(
-                        prompt_text, greedy_flag, add_BOS,
+                    bot_message = create_gen_function(server_port, chat=True)(
+                        prompt_text, preamble, greedy_flag, add_BOS,
                         token_to_gen, min_token_to_gen,
                         temperature, top_p, top_k,
                         repetition_penality, end_strings)
-                    bot_message = bot_message[len(prompt_text):]
                     history[-1][1] = bot_message
                     return history
 
                 msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-                    bot, [chatbot, greedy_flag, add_BOS, token_to_gen, min_token_to_gen, temperature, top_p, top_k, repetition_penality, end_strings], chatbot
+                    bot, [chatbot, preamble, greedy_flag, add_BOS, token_to_gen, min_token_to_gen, temperature, top_p, top_k, repetition_penality, end_strings], chatbot
                 )
                 clear.click(lambda: None, None, chatbot, queue=False)
     demo.launch(share=share, server_port=web_port, server_name='0.0.0.0', auth=(username, password))
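Note: for reference, a minimal sketch of the prompt string the chat UI assembles for a two-turn conversation, mirroring the bot() and get_generation() logic above; the preamble text and chat history are made-up example values:

# Illustrative only: shows the prompt layout built from the special role tokens; values are made up.
SYSTEM_TOKEN = '<extra_id_0>'
HUMAN_TOKEN = '<extra_id_1>'
ASSITANT_TOKEN = '<extra_id_2>'

preamble = "You are a helpful assistant."            # contents of the "System" textbox
history = [["What is NeMo?", "An NVIDIA toolkit for conversational AI."],
           ["Does it support GPT models?", None]]    # last turn still awaits the bot reply

prompt_text = ''
for user_msg, bot_msg in history[:-1]:
    prompt_text += HUMAN_TOKEN + user_msg + '\n' + ASSITANT_TOKEN + bot_msg + '\n'
prompt_text += HUMAN_TOKEN + history[-1][0] + '\n' + ASSITANT_TOKEN

# In chat mode, get_generation() prepends the system section before sending the request:
full_prompt = SYSTEM_TOKEN + preamble + prompt_text
print(full_prompt)
# <extra_id_0>You are a helpful assistant.<extra_id_1>What is NeMo?
# <extra_id_2>An NVIDIA toolkit for conversational AI.
# <extra_id_1>Does it support GPT models?
# <extra_id_2>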
Expand Down