From 218bd64bd1facc2bfe1ba04e8fbfeb27201c9e85 Mon Sep 17 00:00:00 2001
From: LaaZa
Date: Tue, 9 May 2023 18:52:35 +0000
Subject: [PATCH] Add the option to not automatically load the selected model (#1762)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
---
 modules/shared.py      |  1 +
 server.py              | 42 +++++++++++++++++++++++++++++++-----------
 settings-template.json |  1 +
 3 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/modules/shared.py b/modules/shared.py
index fd494b9c..ee69278b 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -31,6 +31,7 @@ input_params = []
 need_restart = False
 
 settings = {
+    'autoload_model': True,
     'max_new_tokens': 200,
     'max_new_tokens_min': 1,
     'max_new_tokens_max': 2000,
diff --git a/server.py b/server.py
index 6c94738a..cbced1ff 100644
--- a/server.py
+++ b/server.py
@@ -51,17 +51,24 @@ from modules.models import load_model, load_soft_prompt, unload_model
 from modules.text_generation import encode, generate_reply, stop_everything_event
 
 
-def load_model_wrapper(selected_model):
-    try:
-        yield f"Loading {selected_model}..."
-        shared.model_name = selected_model
-        unload_model()
-        if selected_model != '':
-            shared.model, shared.tokenizer = load_model(shared.model_name)
+def load_model_wrapper(selected_model, autoload=False):
+    if not autoload:
+        yield f"The settings for {selected_model} have been updated.\nClick on \"Load the model\" to load it."
+        return
 
-        yield f"Successfully loaded {selected_model}"
-    except:
-        yield traceback.format_exc()
+    if selected_model == 'None':
+        yield "No model selected"
+    else:
+        try:
+            yield f"Loading {selected_model}..."
+            shared.model_name = selected_model
+            unload_model()
+            if selected_model != '':
+                shared.model, shared.tokenizer = load_model(shared.model_name)
+
+            yield f"Successfully loaded {selected_model}"
+        except:
+            yield traceback.format_exc()
 
 
 def load_lora_wrapper(selected_loras):
@@ -292,6 +299,7 @@ def create_model_menus():
             with gr.Row():
                 shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
             with gr.Row():
+                load = gr.Button("Load the model", visible=not shared.settings['autoload_model'])
                 unload = gr.Button("Unload the model")
                 reload = gr.Button("Reload the model")
                 save_settings = gr.Button("Save settings for this model")
@@ -327,6 +335,9 @@
 
     with gr.Row():
         with gr.Column():
+            with gr.Row():
+                shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown. You can change the default with a settings.json file.')
+
             shared.gradio['custom_model_menu'] = gr.Textbox(label="Download custom model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main")
             shared.gradio['download_model_button'] = gr.Button("Download")
 
@@ -335,12 +346,20 @@ def create_model_menus():
 
     # In this event handler, the interface state is read and updated
     # with the model defaults (if any), and then the model is loaded
+    # unless "autoload_model" is unchecked
     shared.gradio['model_menu'].change(
         ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
         load_model_specific_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['interface_state']).then(
         ui.apply_interface_values, shared.gradio['interface_state'], [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then(
         update_model_parameters, shared.gradio['interface_state'], None).then(
-        load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
+        load_model_wrapper, [shared.gradio[k] for k in ['model_menu', 'autoload_model']], shared.gradio['model_status'], show_progress=False)
+
+    load.click(
+        ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+        ui.apply_interface_values, shared.gradio['interface_state'],
+        [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False).then(
+        update_model_parameters, shared.gradio['interface_state'], None).then(
+        partial(load_model_wrapper, autoload=True), shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
 
     unload.click(
         unload_model, None, None).then(
@@ -358,6 +377,7 @@ def create_model_menus():
 
     shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
     shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
+    shared.gradio['autoload_model'].change(lambda x : gr.update(visible=not x), shared.gradio['autoload_model'], load)
 
 
 def create_settings_menus(default_preset):
diff --git a/settings-template.json b/settings-template.json
index ebf751d7..f893389d 100644
--- a/settings-template.json
+++ b/settings-template.json
@@ -1,4 +1,5 @@
 {
+    "autoload_model": true,
     "max_new_tokens": 200,
     "max_new_tokens_min": 1,
     "max_new_tokens_max": 2000,
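
Usage note: the checkbox info above points at settings.json for changing the
default. A minimal settings.json that makes manual loading the default might
look like this (a sketch, assuming that, as with the other keys in
settings-template.json, keys omitted from settings.json keep their template
defaults):

    {
        "autoload_model": false
    }

With that in place, selecting a model in the dropdown only applies its saved
settings, and the new "Load the model" button performs the actual load.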
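For reference, the UI wiring added here reduces to two stock Gradio patterns:
a checkbox that toggles a button's visibility through gr.update(visible=...),
and a generator handler whose successive yields stream into a status
component. Below is a self-contained sketch of both patterns, with made-up
names (fake_load, model-a/model-b) standing in for the real webui objects; it
assumes a Gradio 3.x environment, where generator outputs require the queue:

    from functools import partial
    import time

    import gradio as gr

    def fake_load(selected_model, autoload=False):
        # Generator handler: each yield replaces the current status text,
        # mirroring load_model_wrapper in the patch above.
        if not autoload:
            yield f"The settings for {selected_model} have been updated.\nClick on \"Load the model\" to load it."
            return

        yield f"Loading {selected_model}..."
        time.sleep(1)  # stand-in for the real unload_model()/load_model() work
        yield f"Successfully loaded {selected_model}"

    with gr.Blocks() as demo:
        autoload = gr.Checkbox(value=True, label='Autoload the model')
        model = gr.Dropdown(choices=['model-a', 'model-b'], label='Model')
        load = gr.Button("Load the model", visible=False)  # hidden while autoload is checked
        status = gr.Markdown()

        # Same visibility toggle the patch registers on shared.gradio['autoload_model']
        autoload.change(lambda x: gr.update(visible=not x), autoload, load)

        # Selecting a model either loads it or only reports that settings changed
        model.change(fake_load, [model, autoload], status)

        # The explicit button always loads, like load.click(...) in the patch
        load.click(partial(fake_load, autoload=True), model, status)

    demo.queue().launch()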