diff --git a/chat.py b/chat.py index 14df7108..eb88a3a6 100644 --- a/chat.py +++ b/chat.py @@ -11,6 +11,8 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200) args = types.SimpleNamespace() +from prompt_toolkit import prompt + print('\n\nChatRWKV project: https://github.com/BlinkDL/ChatRWKV') ######################################################################################################## @@ -427,7 +429,7 @@ def on_message(message): print(f'Ready - {CHAT_LANG} {args.RUN_DEVICE} {args.FLOAT_MODE} QA_PROMPT={QA_PROMPT} {args.MODEL_NAME}\n') while True: - msg = input(f'{user}{interface} ') + msg = prompt(f'{user}{interface} ') if len(msg.strip()) > 0: on_message(msg) else: