diff --git a/modules/chat.py b/modules/chat.py index 4f0434ba..8756adb3 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -205,7 +205,7 @@ def get_stopping_strings(state): return list(set(stopping_strings)) -def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True): +def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False): history = state['history'] output = copy.deepcopy(history) output = apply_extensions('history', output) @@ -256,7 +256,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess # Generate reply = None - for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)): + for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True, for_ui=for_ui)): # Extract the reply visible_reply = reply @@ -311,7 +311,7 @@ def impersonate_wrapper(text, state): return -def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True): +def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False): history = state['history'] if regenerate or _continue: text = '' @@ -319,7 +319,7 @@ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_ yield history return - for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message): + for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui): yield history @@ -351,7 +351,7 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False): send_dummy_message(text, state) send_dummy_reply(state['start_with'], state) - for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True)): + for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)): yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history diff --git a/modules/text_generation.py b/modules/text_generation.py index 3815fe70..f576ba83 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -33,7 +33,7 @@ def generate_reply(*args, **kwargs): shared.generation_lock.release() -def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False): +def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False): # Find the appropriate generation function generate_func = apply_extensions('custom_generate_reply') @@ -96,7 +96,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap # Limit updates to 24 or 5 per second to avoid lag in the Gradio UI # API updates are not limited else: - min_update_interval = 0 if not escape_html else 0.2 if (shared.args.listen or shared.args.share) else 0.0417 + min_update_interval = 0 if not for_ui else 0.2 if (shared.args.listen or shared.args.share) else 0.0417 if cur_time - last_update > min_update_interval: last_update = cur_time yield reply @@ -178,7 +178,7 @@ def generate_reply_wrapper(question, state, stopping_strings=None): reply = question if not shared.is_seq2seq else '' yield formatted_outputs(reply, shared.model_name) - for reply in generate_reply(question, state, stopping_strings, is_chat=False, escape_html=True): + for reply in generate_reply(question, state, stopping_strings, is_chat=False, escape_html=True, for_ui=True): if not shared.is_seq2seq: reply = question + reply