From 2cb5b68ad98806a56580de2ae5db4c230bb7b9a1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 15 Dec 2023 01:01:45 -0300 Subject: [PATCH] Bug fix: when generation fails, save the sent message (#4915) --- modules/chat.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index ab0d70a9..7a44c03e 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -215,40 +215,47 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess yield output return - just_started = True visible_text = None stopping_strings = get_stopping_strings(state) is_stream = state['stream'] # Prepare the input - if not any((regenerate, _continue)): + if not (regenerate or _continue): visible_text = html.escape(text) # Apply extensions text, visible_text = apply_extensions('chat_input', text, visible_text, state) text = apply_extensions('input', text, state, is_chat=True) + output['internal'].append([text, '']) + output['visible'].append([visible_text, '']) + # *Is typing...* if loading_message: - yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} + yield { + 'visible': output['visible'][:-1] + [[output['visible'][-1][0], shared.processing_message]], + 'internal': output['internal'] + } else: text, visible_text = output['internal'][-1][0], output['visible'][-1][0] if regenerate: - output['visible'].pop() - output['internal'].pop() - - # *Is typing...* if loading_message: - yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']} + yield { + 'visible': output['visible'][:-1] + [[visible_text, shared.processing_message]], + 'internal': output['internal'][:-1] + [[text, '']] + } elif _continue: last_reply = [output['internal'][-1][1], output['visible'][-1][1]] if loading_message: - yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']} + yield { + 'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], + 'internal': output['internal'] + } # Generate the prompt kwargs = { '_continue': _continue, - 'history': output, + 'history': output if _continue else {k: v[:-1] for k, v in output.items()} } prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs) if prompt is None: @@ -270,12 +277,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess yield output return - if just_started: - just_started = False - if not _continue: - output['internal'].append(['', '']) - output['visible'].append(['', '']) - if _continue: output['internal'][-1] = [text, last_reply[0] + reply] output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]