From b1ee674d75fca92c638e967118addb796a34a9b0 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 24 Apr 2023 03:05:47 -0300
Subject: [PATCH] Make interface state (mostly) persistent on page reload

---
 modules/shared.py |  3 +++
 modules/ui.py     | 24 ++++++++++++++++++++----
 server.py         |  6 +++++-
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/modules/shared.py b/modules/shared.py
index 82acf3c0..6b0c6f06 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -20,6 +20,9 @@ processing_message = '*Is typing...*'
 # UI elements (buttons, sliders, HTML, etc)
 gradio = {}
 
+# For keeping the values of UI elements on page reload
+persistent_interface_state = {}
+
 # Generation input parameters
 input_params = []
 
diff --git a/modules/ui.py b/modules/ui.py
index 5db36b3e..0ddcc833 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -33,9 +33,10 @@ def list_model_elements():
 
 
 def list_interface_input_elements(chat=False):
-    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens']
+    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
     if chat:
-        elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template']
+        elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']
+
     elements += list_model_elements()
     return elements
 
@@ -44,11 +45,26 @@ def gather_interface_values(*args):
     output = {}
     for i, element in enumerate(shared.input_elements):
         output[element] = args[i]
+
+    shared.persistent_interface_state = output
     return output
 
 
-def apply_interface_values(state):
-    return [state[i] for i in list_interface_input_elements(chat=shared.is_chat())]
+def apply_interface_values(state, use_persistent=False):
+    if use_persistent:
+        state = shared.persistent_interface_state
+
+    elements = list_interface_input_elements(chat=shared.is_chat())
+    if len(state) == 0:
+        return [gr.update() for k in elements]  # Dummy, do nothing
+    else:
+        if use_persistent and 'mode' in state:
+            if state['mode'] == 'instruct':
+                return [state[k] if k not in ['character_menu'] else gr.update() for k in elements]
+            else:
+                return [state[k] if k not in ['instruction_template'] else gr.update() for k in elements]
+        else:
+            return [state[k] for k in elements]
 
 
 class ToolButton(gr.Button, gr.components.FormComponent):
diff --git a/server.py b/server.py
index 573552ff..8308c9b8 100644
--- a/server.py
+++ b/server.py
@@ -32,6 +32,7 @@ import time
 import traceback
 import zipfile
 from datetime import datetime
+from functools import partial
 from pathlib import Path
 
 import psutil
@@ -846,6 +847,8 @@ def create_interface():
         shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
 
     shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
+    shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False)
+
     # Launch the interface
     shared.gradio['interface'].queue()
     if shared.args.listen:
@@ -855,7 +858,6 @@
 
 
 if __name__ == "__main__":
-
     # Loading custom settings
     settings_file = None
     if shared.args.settings is not None and Path(shared.args.settings).exists():
@@ -900,9 +902,11 @@ if __name__ == "__main__":
         print('The following models are available:\n')
         for i, model in enumerate(available_models):
             print(f'{i+1}. {model}')
+
         print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
         i = int(input()) - 1
         print()
+
         shared.model_name = available_models[i]
 
     # If any model has been selected, load it
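
Note on the mechanism, outside the patch itself: gather_interface_values() now snapshots every tracked widget into shared.persistent_interface_state whenever the interface state is collected, and the new interface.load(partial(ui.apply_interface_values, {}, use_persistent=True), ...) hook replays that snapshot into the Gradio components on every page reload; while restoring, 'character_menu' is skipped in instruct mode and 'instruction_template' is skipped otherwise so the two menus do not overwrite each other. Below is a minimal, self-contained sketch of the same store-and-replay pattern, assuming the Gradio 3.x Blocks API in use at the time; the widget list and the persistent_state dict are illustrative stand-ins for ui.list_interface_input_elements() and shared.persistent_interface_state, not code from the repository.

import gradio as gr

# Stand-in for shared.persistent_interface_state
persistent_state = {}

# Stand-in for ui.list_interface_input_elements()
elements = ['temperature', 'seed']

def gather(*args):
    # Mirrors gather_interface_values(): snapshot the current widget values
    persistent_state.update(dict(zip(elements, args)))

def restore():
    # Mirrors apply_interface_values(..., use_persistent=True): on page load,
    # push the stored values back into the widgets, or emit no-op updates
    # if nothing has been stored yet.
    if not persistent_state:
        return [gr.update() for _ in elements]
    return [persistent_state[k] for k in elements]

with gr.Blocks() as demo:
    temperature = gr.Slider(0.01, 1.99, value=0.7, label='temperature')
    seed = gr.Number(value=-1, label='seed')
    widgets = [temperature, seed]

    # Snapshot whenever a value changes; replay the snapshot on every page load
    for w in widgets:
        w.change(gather, widgets, None)
    demo.load(restore, None, widgets)

demo.launch()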