From 19f78684e6d32d43cd5ce3ae82d6f2216421b9ae Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 2 Jun 2023 13:58:08 -0300 Subject: [PATCH] Add "Start reply with" feature to chat mode --- modules/chat.py | 11 ++++++++--- modules/shared.py | 1 + server.py | 10 +++++++--- settings-template.yaml | 1 + 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 2a25327a..f3388737 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -277,7 +277,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa yield output -def impersonate_wrapper(text, state): +def impersonate_wrapper(text, start_with, state): if shared.model_name == 'None' or shared.model is None: logger.error("No model is loaded! Select one in the Model tab.") yield '' @@ -322,8 +322,13 @@ def generate_chat_reply(text, history, state, regenerate=False, _continue=False, yield history -# Same as above but returns HTML -def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): +# Same as above but returns HTML for the UI +def generate_chat_reply_wrapper(text, start_with, state, regenerate=False, _continue=False): + if start_with != '' and _continue == False: + _continue = True + send_dummy_message(text) + send_dummy_reply(start_with) + for i, history in enumerate(generate_chat_reply(text, shared.history, state, regenerate, _continue, loading_message=True)): if i != 0: shared.history = copy.deepcopy(history) diff --git a/modules/shared.py b/modules/shared.py index a7df12e1..9a025587 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -55,6 +55,7 @@ settings = { 'truncation_length_min': 0, 'truncation_length_max': 8192, 'mode': 'chat', + 'start_with': '', 'chat_style': 'cai-chat', 'instruction_template': 'None', 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', diff --git a/server.py b/server.py index 4cb235d6..1c8a5fe0 100644 --- a/server.py +++ b/server.py @@ -626,8 +626,12 @@ def create_interface(): shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant='stop', visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) - shared.gradio['mode'] = gr.Radio(choices=['chat', 'chat-instruct', 'instruct'], value=shared.settings['mode'] if shared.settings['mode'] in ['chat', 'instruct', 'chat-instruct'] else 'chat', label='Mode', info='Defines how the chat prompt is generated. In instruct and chat-instruct modes, the instruction template selected under "Chat settings" must match the current model.') - shared.gradio['chat_style'] = gr.Dropdown(choices=utils.get_available_chat_styles(), label='Chat style', value=shared.settings['chat_style'], visible=shared.settings['mode'] != 'instruct') + with gr.Row(): + shared.gradio['start_with'] = gr.Textbox(label='Start reply with', placeholder='Sure thing!', value=shared.settings['start_with']) + + with gr.Row(): + shared.gradio['mode'] = gr.Radio(choices=['chat', 'chat-instruct', 'instruct'], value=shared.settings['mode'] if shared.settings['mode'] in ['chat', 'instruct', 'chat-instruct'] else 'chat', label='Mode', info='Defines how the chat prompt is generated. In instruct and chat-instruct modes, the instruction template selected under "Chat settings" must match the current model.') + shared.gradio['chat_style'] = gr.Dropdown(choices=utils.get_available_chat_styles(), label='Chat style', value=shared.settings['chat_style'], visible=shared.settings['mode'] != 'instruct') with gr.Tab('Chat settings', elem_id='chat-settings'): with gr.Row(): @@ -825,7 +829,7 @@ def create_interface(): # chat mode event handlers if shared.is_chat(): - shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']] + shared.input_params = [shared.gradio[k] for k in ['Chat input', 'start_with', 'interface_state']] clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] shared.reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode', 'chat_style']] diff --git a/settings-template.yaml b/settings-template.yaml index aba2f489..84cf0105 100644 --- a/settings-template.yaml +++ b/settings-template.yaml @@ -22,6 +22,7 @@ truncation_length: 2048 truncation_length_min: 0 truncation_length_max: 8192 mode: chat +start_with: '' chat_style: cai-chat instruction_template: None chat-instruct_command: 'Continue the chat dialogue below. Write a single reply for