From e283ddc5598fcb294f5449d14b8a4ac0dc0b6095 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 12 May 2023 12:50:29 -0300 Subject: [PATCH] Change how spaces are handled in continue/generation attempts --- modules/chat.py | 19 ++++++++++--------- modules/shared.py | 2 +- server.py | 2 +- settings-template.json | 2 +- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 6be3db53..5dbb517c 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -185,18 +185,13 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): # Generate for i in range(state['chat_generation_attempts']): reply = None - for j, reply in enumerate(generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)): + for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)): reply = cumulative_reply + reply # Extracting the reply reply, next_character_found = extract_message_from_reply(reply, state) visible_reply = re.sub("(||{{user}})", state['name1'], reply) visible_reply = apply_extensions("output", visible_reply) - if _continue: - sep = ' ' if last_reply[0][-1] not in [' ', '\n'] else '' - reply = last_reply[0] + sep + reply - sep = ' ' if last_reply[1][-1] not in [' ', '\n'] else '' - visible_reply = last_reply[1] + sep + visible_reply # We need this global variable to handle the Stop event, # otherwise gradio gets confused @@ -209,7 +204,11 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): shared.history['internal'].append(['', '']) shared.history['visible'].append(['', '']) - if not (j == 0 and visible_reply.strip() == ''): + if _continue: + shared.history['internal'][-1] = [text, last_reply[0] + reply] + shared.history['visible'][-1] = [visible_text, last_reply[1] + visible_reply] + yield shared.history['visible'] + elif not (j == 0 and visible_reply.strip() == ''): shared.history['internal'][-1] = [text, reply] shared.history['visible'][-1] = [visible_text, visible_reply] yield shared.history['visible'] @@ -217,7 +216,9 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): if next_character_found: break - if reply is not None: + if reply in [None, '']: + break + else: cumulative_reply = reply yield shared.history['visible'] @@ -239,7 +240,7 @@ def impersonate_wrapper(text, state): cumulative_reply = text for i in range(state['chat_generation_attempts']): reply = None - for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True): + for reply in generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True): reply = cumulative_reply + reply reply, next_character_found = extract_message_from_reply(reply, state) yield reply diff --git a/modules/shared.py b/modules/shared.py index 97ed9719..3c6ca4bc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -59,7 +59,7 @@ settings = { 'chat_prompt_size_max': 2048, 'chat_generation_attempts': 1, 'chat_generation_attempts_min': 1, - 'chat_generation_attempts_max': 5, + 'chat_generation_attempts_max': 10, 'default_extensions': [], 'chat_default_extensions': ["gallery"], 'presets': { diff --git a/server.py b/server.py index 6b89fff5..50c32895 100644 --- a/server.py +++ b/server.py @@ -602,7 +602,7 @@ def create_interface(): shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) with gr.Column(): - shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') + shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)', info='New generations will be called until either this number is reached or no new content is generated between two iterations') shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character') create_settings_menus(default_preset) diff --git a/settings-template.json b/settings-template.json index 635260e5..5afe6ed9 100644 --- a/settings-template.json +++ b/settings-template.json @@ -26,7 +26,7 @@ "chat_prompt_size_max": 2048, "chat_generation_attempts": 1, "chat_generation_attempts_min": 1, - "chat_generation_attempts_max": 5, + "chat_generation_attempts_max": 10, "default_extensions": [], "chat_default_extensions": [ "gallery"