Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include the scripts for preprocessing OAST and unit tests for chat sft datasets #7112

Merged
merged 23 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
annotation handles lang
Signed-off-by: Yi Dong <[email protected]>
  • Loading branch information
yidong72 committed Jul 25, 2023
commit cb67e93573aebfc6bb64ea10b8973a6fb6642e4c
7 changes: 6 additions & 1 deletion examples/nlp/language_modeling/megatron_gpt_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,12 @@ def main(cfg) -> None:
'assistant': cfg.chatbot_config.assistant,
'system': cfg.chatbot_config.system,
}
web_ui = partial(get_chatbot_demo, defaults=defaults, value=cfg.chatbot_config.value, attributes=cfg.chatbot_config.attributes)
web_ui = partial(
get_chatbot_demo,
defaults=defaults,
value=cfg.chatbot_config.value,
attributes=cfg.chatbot_config.attributes,
)
else:
web_ui = get_demo
loop = asyncio.new_event_loop()
Expand Down
12 changes: 7 additions & 5 deletions nemo/collections/nlp/modules/common/megatron_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,14 @@ def get_chatbot_demo(
widgets = []
for item in attributes:
if item.type == 'int':
slider = gr.Slider(minimum=item.min, maximum=item.max, step=1, value=item.default, label=item.name)
slider = gr.Slider(
minimum=item.min, maximum=item.max, step=1, value=item.default, label=item.name
)
widgets.append(slider)
elif item.type == 'list':
dropdown = gr.Dropdown(item.choices, label=item.name, default=item.default, value=item.default)
dropdown = gr.Dropdown(
item.choices, label=item.name, default=item.default, value=item.default
)
widgets.append(dropdown)
used_value = gr.CheckboxGroup(keys, value=keys)

Expand All @@ -244,9 +248,7 @@ def change_visibility(x):
return values

used_value.change(
change_visibility,
inputs=[used_value],
outputs=widgets,
change_visibility, inputs=[used_value], outputs=widgets,
)

def set_sampling(x):
Expand Down
110 changes: 106 additions & 4 deletions scripts/nlp_language_modeling/sft/attribute_annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,44 @@

from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import text_generation

# ISO 639-1 codes of the languages recognized by the `lang` attribute
# annotation; model output not in this set is rejected downstream.
langs = (
    'ar bg bn ca cs da de el en eo es eu fa fi fr gl he hu id it '
    'ja ko nb nl pl pt ro ru sk sv th tr uk vi zh'
).split()

# Prompt prefix for the system turn of an SFT conversation; the caller
# fills in `system_message` via str.format.
SFT_PREFIX = """<extra_id_0>System
{system_message}"""

Expand Down Expand Up @@ -76,6 +114,7 @@
'fails_task',
'political_content',
'moral_judgement',
'lang',
]


Expand Down Expand Up @@ -139,10 +178,11 @@ def request(prompts, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top


class Worker(object):
def __init__(self, host='localhost', port=5555, progress_bar=None, output_file=None):
def __init__(self, host='localhost', port=5555, progress_bar=None, output_file=None, process_lang=False):
self.req = create_gen_function(host=host, port=port)
self.fout = open(output_file, "a", encoding='utf-8')

Check warning

Code scanning / CodeQL

File is not always closed Warning

File is opened but is not closed.
self.progress_bar = progress_bar
self.process_lang = process_lang

def process_result(self, batch):
while True:
Expand All @@ -151,7 +191,7 @@ def process_result(self, batch):
turns = [i['turn'] for i in batch]
prompts = [i['prompt'] for i in batch]

for label_id in range(1, len(selected_keys) + 1):
for label_id in range(1, len(selected_keys)):
results = self.req(
prompts,
greedy=True,
Expand Down Expand Up @@ -206,6 +246,61 @@ def process_result(self, batch):
prompts = filtered_prompts
current_values = filtered_current_values

if self.process_lang:
results = self.req(
prompts,
greedy=True,
add_BOS=False,
token_to_gen=1,
min_tokens=1,
temp=0.1,
top_p=1.0,
top_k=1,
repetition=1.0,
end_strings=["<extra_id_1>", "<|endoftext|>"],
)
# get current value from result
current_values = []
for result in results:
                # problem: result[-1] is '\n'
if result.endswith('\n'):
result = result[:-1] + '@'
current_values.append(result.split('\n')[-1])

nums = []
for result in results:
                # problem: result[-1] is '\n'
current_val = result.split('quality')[-1]
current_val = 'quality' + current_val
# remove whatever after new line
current_val = current_val.split('\n')[0].strip()
# remove everything that is >= selected_keys[label_id]
splits = current_val.split(',')
filtered = []
for item in splits:
filtered.append(item)
if item.split(':')[0] == selected_keys[label_id]:
nums.append(item.split(':')[1])
break
current_val = '<extra_id_2>' + ','.join(filtered)
current_values.append(current_val)

filtered_items = []
filtered_turns = []
filtered_prompts = []
filtered_current_values = []

for result, item, turn, num, current_value in zip(results, items, turns, nums, current_values):
if num not in langs:
print(f'error {num} not in langs')
continue
filtered_current_values.append(current_value)
filtered_items.append(item)
filtered_turns.append(turn)
items = filtered_items
turns = filtered_turns
current_values = filtered_current_values

batch = []
for item, turn, current_value in zip(items, turns, current_values):
response_text = current_value[12:]
Expand All @@ -232,7 +327,12 @@ def process_result(self, batch):


def main(
batch_size=1, host='localhost', input_file_name='input.jsonl', output_file_name='output.jsonl', port_num=1424
batch_size=1,
host='batch-block1-10453',
input_file_name='input.jsonl',
output_file_name='output.jsonl',
port_num=1424,
process_lang=True,
):
input_data = load_data(f'{input_file_name}')
output_path = f'{output_file_name}'
Expand All @@ -248,7 +348,9 @@ def main(

progress_bar = tqdm.tqdm(total=len(filter_data))

worker = Worker(host=host, port=port_num, progress_bar=progress_bar, output_file=output_path)
worker = Worker(
host=host, port=port_num, progress_bar=progress_bar, output_file=output_path, process_lang=process_lang
)
for batch_idx in range(0, len(filter_data), batch_size):
batch = [line for line in filter_data[batch_idx : batch_idx + batch_size]]
turns = [
Expand Down
Loading