From ac189011cb0503cccb58c2c3c8cabe54916de7b4 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 15 Apr 2023 12:54:02 -0300
Subject: [PATCH] Add "Save current settings for this model" button

---
 .gitignore        |  1 +
 modules/models.py | 15 ++++++---------
 server.py         | 36 +++++++++++++++++++++++++++++++++---
 3 files changed, 40 insertions(+), 12 deletions(-)

diff --git a/.gitignore b/.gitignore
index 2d007efe..12f94ba0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,3 +24,4 @@ settings.json
 img_bot*
 img_me*
 prompts/[0-9]*
+models/config-user.yaml
diff --git a/modules/models.py b/modules/models.py
index 3467f4f2..2a9007e0 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -45,17 +45,14 @@ def load_model(model_name):
     shared.is_RWKV = 'rwkv-' in model_name.lower()
     shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
 
-    # Default settings
+    # Load the model in simple 16-bit mode by default
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
-        if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
-            model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
+        model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+        if torch.has_mps:
+            device = torch.device('mps')
+            model = model.to(device)
         else:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
-            if torch.has_mps:
-                device = torch.device('mps')
-                model = model.to(device)
-            else:
-                model = model.cuda()
+            model = model.cuda()
 
     # FlexGen
     elif shared.args.flexgen:
diff --git a/server.py b/server.py
index fee8902a..ecc89320 100644
--- a/server.py
+++ b/server.py
@@ -21,6 +21,7 @@ from pathlib import Path
 import gradio as gr
 import psutil
 import torch
+import yaml
 from PIL import Image
 
 import modules.extensions as extensions_module
@@ -233,7 +234,7 @@ def get_model_specific_settings(model):
 
     model_settings = {}
     for pat in settings:
-        if re.match(pat, model.lower()):
+        if re.match(pat.lower(), model.lower()):
             for k in settings[pat]:
                 model_settings[k] = settings[pat][k]
 
@@ -249,6 +250,29 @@ def load_model_specific_settings(model, state, return_dict=False):
     return state
 
 
+def save_model_settings(model, state):
+    if model == 'None':
+        yield ("Not saving the settings because no model is loaded.")
+        return
+
+    with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
+        if p.exists():
+            user_config = yaml.safe_load(open(p, 'r').read())
+        else:
+            user_config = {}
+
+        if model not in user_config:
+            user_config[model] = {}
+
+        for k in ui.list_model_elements():
+            user_config[model][k] = state[k]
+
+        with open(p, 'w') as f:
+            f.write(yaml.dump(user_config))
+
+        yield (f"Settings for {model} saved to {p}")
+
+
 def create_model_menus():
     # Finding the default values for the GPU and CPU memories
     total_mem = []
@@ -285,10 +309,12 @@ def create_model_menus():
                         ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
 
             with gr.Column():
-                shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
+                with gr.Row():
+                    shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
                 with gr.Row():
                     unload = gr.Button("Unload the model")
                     reload = gr.Button("Reload the model")
+                    save_settings = gr.Button("Save current settings for this model")
 
     with gr.Row():
         with gr.Column():
@@ -344,7 +370,11 @@ def create_model_menus():
         unload_model, None, None).then(
         ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
         update_model_parameters, shared.gradio['interface_state'], None).then(
-        load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
+        load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
+
+    save_settings.click(
+        ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
+        save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False)
 
     shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
     shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
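
Note on the round trip this patch enables (not part of the patch itself):
save_model_settings() writes the exact model name as a top-level key in
models/config-user.yaml, and get_model_specific_settings() later treats every
key in that file as a regular expression, matched case-insensitively against
the model name thanks to the pat.lower() fix above. Below is a minimal,
self-contained sketch of that matching behavior; get_settings() is a stand-in
for the real server.py function, the model names and setting values are
made-up examples, and it assumes PyYAML is installed:

import re

import yaml

# Hypothetical contents of models/config-user.yaml, in the shape that
# save_model_settings() writes (model names and values are invented here).
user_config = yaml.safe_load("""
LLaMA-7b:
  groupsize: 128
  wbits: 4
llama-13b:
  load_in_8bit: true
""")


def get_settings(model, settings):
    # Same matching rule as the patched get_model_specific_settings():
    # each top-level key acts as a regex, compared case-insensitively.
    model_settings = {}
    for pat in settings:
        if re.match(pat.lower(), model.lower()):
            model_settings.update(settings[pat])
    return model_settings


print(get_settings('llama-7b', user_config))   # {'groupsize': 128, 'wbits': 4}
print(get_settings('LLaMA-13b', user_config))  # {'load_in_8bit': True}

One consequence of this design: since the saved keys are interpreted as
regexes, a model name containing metacharacters (for example a literal dot)
can match more models than intended, and a saved name that is a prefix of
another model's name will also match it via re.match.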