Make interface state (mostly) persistent on page reload

This commit is contained in:
oobabooga 2023-04-24 03:05:47 -03:00
parent 47809e28aa
commit b1ee674d75
3 changed files with 28 additions and 5 deletions

View file

@ -20,6 +20,9 @@ processing_message = '*Is typing...*'
# UI elements (buttons, sliders, HTML, etc) # UI elements (buttons, sliders, HTML, etc)
gradio = {} gradio = {}
# For keeping the values of UI elements on page reload
persistent_interface_state = {}
# Generation input parameters # Generation input parameters
input_params = [] input_params = []

View file

@ -33,9 +33,10 @@ def list_model_elements():
def list_interface_input_elements(chat=False): def list_interface_input_elements(chat=False):
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens'] elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
if chat: if chat:
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template'] elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']
elements += list_model_elements() elements += list_model_elements()
return elements return elements
@ -44,11 +45,26 @@ def gather_interface_values(*args):
output = {} output = {}
for i, element in enumerate(shared.input_elements): for i, element in enumerate(shared.input_elements):
output[element] = args[i] output[element] = args[i]
shared.persistent_interface_state = output
return output return output
def apply_interface_values(state): def apply_interface_values(state, use_persistent=False):
return [state[i] for i in list_interface_input_elements(chat=shared.is_chat())] if use_persistent:
state = shared.persistent_interface_state
elements = list_interface_input_elements(chat=shared.is_chat())
if len(state) == 0:
return [gr.update() for k in elements] # Dummy, do nothing
else:
if use_persistent and 'mode' in state:
if state['mode'] == 'instruct':
return [state[k] if k not in ['character_menu'] else gr.update() for k in elements]
else:
return [state[k] if k not in ['instruction_template'] else gr.update() for k in elements]
else:
return [state[k] for k in elements]
class ToolButton(gr.Button, gr.components.FormComponent): class ToolButton(gr.Button, gr.components.FormComponent):

View file

@ -32,6 +32,7 @@ import time
import traceback import traceback
import zipfile import zipfile
from datetime import datetime from datetime import datetime
from functools import partial
from pathlib import Path from pathlib import Path
import psutil import psutil
@ -846,6 +847,8 @@ def create_interface():
shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False) shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False)
# Launch the interface # Launch the interface
shared.gradio['interface'].queue() shared.gradio['interface'].queue()
if shared.args.listen: if shared.args.listen:
@ -855,7 +858,6 @@ def create_interface():
if __name__ == "__main__": if __name__ == "__main__":
# Loading custom settings # Loading custom settings
settings_file = None settings_file = None
if shared.args.settings is not None and Path(shared.args.settings).exists(): if shared.args.settings is not None and Path(shared.args.settings).exists():
@ -900,9 +902,11 @@ if __name__ == "__main__":
print('The following models are available:\n') print('The following models are available:\n')
for i, model in enumerate(available_models): for i, model in enumerate(available_models):
print(f'{i+1}. {model}') print(f'{i+1}. {model}')
print(f'\nWhich one do you want to load? 1-{len(available_models)}\n') print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
i = int(input()) - 1 i = int(input()) - 1
print() print()
shared.model_name = available_models[i] shared.model_name = available_models[i]
# If any model has been selected, load it # If any model has been selected, load it