diff --git a/modules/chat.py b/modules/chat.py
index bddc3132..c431d2d0 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -166,53 +166,54 @@ def generate_chat_prompt(user_input, state, **kwargs):
         prompt = remove_extra_bos(prompt)
         return prompt
 
-    # Handle truncation
-    max_length = get_max_prompt_length(state)
     prompt = make_prompt(messages)
 
-    encoded_length = get_encoded_length(prompt)
-    while len(messages) > 0 and encoded_length > max_length:
+    # Handle truncation
+    if shared.tokenizer is not None:
+        max_length = get_max_prompt_length(state)
+        encoded_length = get_encoded_length(prompt)
+        while len(messages) > 0 and encoded_length > max_length:
 
-        # Remove old message, save system message
-        if len(messages) > 2 and messages[0]['role'] == 'system':
-            messages.pop(1)
+            # Remove old message, save system message
+            if len(messages) > 2 and messages[0]['role'] == 'system':
+                messages.pop(1)
 
-        # Remove old message when no system message is present
-        elif len(messages) > 1 and messages[0]['role'] != 'system':
-            messages.pop(0)
+            # Remove old message when no system message is present
+            elif len(messages) > 1 and messages[0]['role'] != 'system':
+                messages.pop(0)
 
-        # Resort to truncating the user input
-        else:
+            # Resort to truncating the user input
+            else:
 
-            user_message = messages[-1]['content']
+                user_message = messages[-1]['content']
 
-            # Bisect the truncation point
-            left, right = 0, len(user_message) - 1
+                # Bisect the truncation point
+                left, right = 0, len(user_message) - 1
 
-            while right - left > 1:
-                mid = (left + right) // 2
+                while right - left > 1:
+                    mid = (left + right) // 2
 
-                messages[-1]['content'] = user_message[mid:]
+                    messages[-1]['content'] = user_message[mid:]
+                    prompt = make_prompt(messages)
+                    encoded_length = get_encoded_length(prompt)
+
+                    if encoded_length <= max_length:
+                        right = mid
+                    else:
+                        left = mid
+
+                messages[-1]['content'] = user_message[right:]
                 prompt = make_prompt(messages)
                 encoded_length = get_encoded_length(prompt)
-
-                if encoded_length <= max_length:
-                    right = mid
+                if encoded_length > max_length:
+                    logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
+                    raise ValueError
                 else:
-                    left = mid
+                    logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}, available context length: {max_length}.")
+                    break
 
-            messages[-1]['content'] = user_message[right:]
             prompt = make_prompt(messages)
             encoded_length = get_encoded_length(prompt)
-            if encoded_length > max_length:
-                logger.error(f"Failed to build the chat prompt. The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n")
-                raise ValueError
-            else:
-                logger.warning(f"The input has been truncated. Context length: {state['truncation_length']}, max_new_tokens: {state['max_new_tokens']}, available context length: {max_length}.")
-                break
-
-        prompt = make_prompt(messages)
-        encoded_length = get_encoded_length(prompt)
 
     if also_return_rows:
         return prompt, [message['content'] for message in messages]
diff --git a/modules/ui_chat.py b/modules/ui_chat.py
index a1b1af97..42e5cae2 100644
--- a/modules/ui_chat.py
+++ b/modules/ui_chat.py
@@ -109,7 +109,7 @@ def create_chat_settings_ui():
         with gr.Row():
             with gr.Column():
                 with gr.Row():
-                    shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', value='Select template to load...', elem_classes='slim-dropdown')
+                    shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', info="After selecting the template, click on \"Load\" to load and apply it.", value='Select template to load...', elem_classes='slim-dropdown')
                     ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
                     shared.gradio['load_template'] = gr.Button("Load", elem_classes='refresh-button')
                     shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu)