From 959feba602c2b447159694c5a3c0e77fefaf362d Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 1 Aug 2023 06:10:09 -0700
Subject: [PATCH] When saving model settings, only save the settings for the
 current loader

---
 modules/models_settings.py | 10 ++++++----
 server.py                  |  4 ++--
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/modules/models_settings.py b/modules/models_settings.py
index 9319582e..00a6b90f 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -3,7 +3,7 @@ from pathlib import Path
 
 import yaml
 
-from modules import shared, ui
+from modules import loaders, shared, ui
 
 
 def get_model_settings_from_yamls(model):
@@ -126,10 +126,12 @@ def save_model_settings(model, state):
         user_config[model_regex] = {}
 
     for k in ui.list_model_elements():
-        user_config[model_regex][k] = state[k]
-        shared.model_config[model_regex][k] = state[k]
+        if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
+            user_config[model_regex][k] = state[k]
+            shared.model_config[model_regex][k] = state[k]
 
+    output = yaml.dump(user_config, sort_keys=False)
     with open(p, 'w') as f:
-        f.write(yaml.dump(user_config, sort_keys=False))
+        f.write(output)
 
     yield (f"Settings for {model} saved to {p}")
diff --git a/server.py b/server.py
index 86e52466..4a757664 100644
--- a/server.py
+++ b/server.py
@@ -220,8 +220,8 @@ def create_model_menus():
                         shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx", value=shared.args.n_ctx)
                         shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32, value=shared.args.threads)
                         shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch)
-                        shared.gradio['n_gqa'] = gr.Slider(minimum=0, maximum=16, step=1, label="n_gqa", value=shared.args.n_gqa, info='grouped-query attention. Must be 8 for llama2 70b.')
-                        shared.gradio['rms_norm_eps'] = gr.Slider(minimum=0, maximum=1e-5, step=1e-6, label="rms_norm_eps", value=shared.args.n_gqa, info='5e-6 is a good value for llama2 70b.')
+                        shared.gradio['n_gqa'] = gr.Slider(minimum=0, maximum=16, step=1, label="n_gqa", value=shared.args.n_gqa, info='grouped-query attention. Must be 8 for llama-2 70b.')
+                        shared.gradio['rms_norm_eps'] = gr.Slider(minimum=0, maximum=1e-5, step=1e-6, label="rms_norm_eps", value=shared.args.n_gqa, info='5e-6 is a good value for llama-2 models.')
 
                         shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None")
                         shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")
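
For context, a minimal standalone sketch of the loader-aware filtering that save_model_settings() performs after this patch. The loaders_and_params mapping, the helper signature, the output path, and the sample state below are illustrative stand-ins (the real table lives in modules/loaders.py and the real function keys config-user.yaml by a model-name regex and also updates shared.model_config); the point is that only the 'loader' key plus the parameters registered for the currently selected loader reach the per-model YAML entry.

    from pathlib import Path

    import yaml

    # Illustrative stand-in for loaders.loaders_and_params: loader name -> UI
    # elements that apply to that loader. The project's real table is larger.
    loaders_and_params = {
        'llama.cpp': ['n_ctx', 'threads', 'n_batch', 'n_gqa', 'rms_norm_eps'],
        'AutoGPTQ': ['wbits', 'groupsize'],
    }


    def save_model_settings(model, state, user_config, path):
        """Write only the settings relevant to the currently selected loader."""
        # The real code keys the config by a regex built from the model name;
        # a plain model-name key keeps this sketch self-contained.
        settings = user_config.setdefault(model, {})

        for k, v in state.items():
            # Keep the loader itself plus any parameter that loader understands.
            if k == 'loader' or k in loaders_and_params[state['loader']]:
                settings[k] = v

        output = yaml.dump(user_config, sort_keys=False)
        Path(path).write_text(output)
        return f"Settings for {model} saved to {path}"


    if __name__ == '__main__':
        state = {'loader': 'llama.cpp', 'n_ctx': 4096, 'threads': 8, 'wbits': 4}
        # 'wbits' is dropped because it belongs to a different loader.
        print(save_model_settings('llama-2-7b', state, {}, 'config-user.yaml'))

In the example run, 'wbits' never reaches the YAML file because it is not listed under the selected loader, which is the behavior the commit message describes.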