Skip to content

Commit

Permalink
Bug fix in cli_demo.py.
Browse files (browse the repository at this point in the history)
  • Loading branch information
fyabc committed Mar 14, 2024
1 parent 3e3c205 commit be9a9e1
Showing 1 changed file with 51 additions and 28 deletions.
79 changes: 51 additions & 28 deletions examples/demo/cli_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,42 @@
'''
_HELP_MSG = '''\
Commands:
:help / :h Show this help message 显示帮助信息
:exit / :quit / :q Exit the demo 退出Demo
:clear / :cl Clear screen 清屏
:clear-his / :clh Clear history 清除对话历史
:history / :his Show history 显示对话历史
:seed Show current random seed 显示当前随机种子
:seed <N> Set random seed to <N> 设置随机种子
:conf Show current generation config 显示生成配置
:conf <key>=<value> Change generation config 修改生成配置
:reset-conf Reset generation config 重置生成配置
:help / :h Show this help message 显示帮助信息
:exit / :quit / :q Exit the demo 退出Demo
:clear / :cl Clear screen 清屏
:clear-history / :clh Clear history 清除对话历史
:history / :his Show history 显示对话历史
:seed Show current random seed 显示当前随机种子
:seed <N> Set random seed to <N> 设置随机种子
:conf Show current generation config 显示生成配置
:conf <key>=<value> Change generation config 修改生成配置
:reset-conf Reset generation config 重置生成配置
'''
# Every command name listed in the help message (without the leading ':'),
# long forms and short aliases alike. Used as the candidate list for readline
# TAB completion.
_ALL_COMMAND_NAMES = [
    'help', 'h',
    'exit', 'quit', 'q',
    'clear', 'cl',
    'clear-history', 'clh',
    'history', 'his',
    'seed',
    'conf',
    'reset-conf',
]


def _setup_readline():
    """Enable TAB completion of demo command names at the input prompt.

    Installs a readline completer that offers every entry of
    ``_ALL_COMMAND_NAMES`` starting with the word being typed. If the
    ``readline`` module cannot be imported, the function returns silently and
    the demo simply runs without completion.
    """
    try:
        import readline
    except ImportError:
        return

    # Candidates for the completion request in progress. readline invokes the
    # completer with state = 0, 1, 2, ... and stops at the first None, so the
    # list is computed once (state == 0) and then indexed.
    _matches = []

    def _completer(text, state):
        nonlocal _matches

        if state == 0:
            _matches = [name for name in _ALL_COMMAND_NAMES if name.startswith(text)]
        if 0 <= state < len(_matches):
            return _matches[state]
        return None

    # '-' is a word delimiter by default, which would split 'clear-history'
    # and 'reset-conf' so that only the part after the dash reaches the
    # completer. Drop it from the delimiter set so hyphenated commands are
    # completed as a single word. ':' stays a delimiter, so the command
    # prefix is still stripped from ``text``.
    readline.set_completer_delims(readline.get_completer_delims().replace('-', ''))

    readline.set_completer(_completer)

    # libedit (the readline replacement used by e.g. the macOS system Python)
    # does not understand GNU readline's 'tab: complete' syntax and needs its
    # own bind command instead.
    uses_libedit = (
        getattr(readline, 'backend', None) == 'editline'
        or 'libedit' in (readline.__doc__ or '')
    )
    if uses_libedit:
        readline.parse_and_bind('bind ^I rl_complete')
    else:
        readline.parse_and_bind('tab: complete')


def _load_model_tokenizer(args):
Expand All @@ -57,13 +82,9 @@ def _load_model_tokenizer(args):
device_map=device_map,
resume_download=True,
).eval()
model.generation_config.max_new_tokens = 2048 # For chat.

config = GenerationConfig.from_pretrained(
args.checkpoint_path, resume_download=True,
)
config.max_new_tokens = 512 # For chat.

return model, tokenizer, config
return model, tokenizer


def _gc():
Expand Down Expand Up @@ -103,26 +124,26 @@ def _get_input() -> str:
print('[ERROR] Query is empty')


def _chat_stream(model, tokenizer, query, history, config):
def _chat_stream(model, tokenizer, query, history):
conversation = [
{'role': 'system', 'message': 'You are a helpful assistant.'},
{'role': 'system', 'content': 'You are a helpful assistant.'},
]
for query_h, response_h in history:
conversation.append({'role': 'user', 'message': query_h})
conversation.append({'role': 'assistant', 'message': response_h})
conversation.append({'role': 'user', 'message': query})
conversation.append({'role': 'user', 'content': query_h})
conversation.append({'role': 'assistant', 'content': response_h})
conversation.append({'role': 'user', 'content': query})
inputs = tokenizer.apply_chat_template(
conversation,
add_generation_prompt=True,
return_tensors='pt'
return_tensors='pt',
)
inputs = inputs.to(model.device)
streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, timeout=60.0, skip_special_tokens=True)
generation_kwargs = dict(
inputs,
input_ids=inputs,
streamer=streamer,
generation_config=config,
)
thread = Thread(model.generate, kwargs=generation_kwargs)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

for new_text in streamer:
Expand All @@ -140,9 +161,11 @@ def main():

history, response = [], ''

model, tokenizer, config = _load_model_tokenizer(args)
model, tokenizer = _load_model_tokenizer(args)
orig_gen_config = deepcopy(model.generation_config)

_setup_readline()

_clear_screen()
print(_WELCOME_MSG)

Expand Down Expand Up @@ -226,8 +249,8 @@ def main():
print(f"\nQwen1.5-Chat: ", end="")
try:
partial_text = ''
for new_text in _chat_stream(model, tokenizer, query, history, config):
print(new_text, end='')
for new_text in _chat_stream(model, tokenizer, query, history):
print(new_text, end='', flush=True)
partial_text += new_text
response = partial_text
print()
Expand Down

0 comments on commit be9a9e1

Please sign in to comment.