diff --git a/modules/extensions.py b/modules/extensions.py index 1a936aa6..1ab07761 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,5 +1,6 @@ import extensions import modules.shared as shared +import gradio as gr extension_state = {} available_extensions = [] @@ -38,3 +39,26 @@ def update_extensions_parameters(*kwargs): def get_params(name): return eval(f"extensions.{name}.script.params") + +def create_extensions_block(): + extensions_ui_elements = [] + default_values = [] + if not (shared.args.chat or shared.args.cai_chat): + gr.Markdown('## Extensions parameters') + for ext in sorted(extension_state, key=lambda x : extension_state[x][1]): + if extension_state[ext][0] == True: + params = get_params(ext) + for param in params: + _id = f"{ext}-{param}" + default_value = shared.settings[_id] if _id in shared.settings else params[param] + default_values.append(default_value) + if type(params[param]) == str: + extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}")) + elif type(params[param]) in [int, float]: + extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}")) + elif type(params[param]) == bool: + extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}")) + + update_extensions_parameters(*default_values) + btn_extensions = gr.Button("Apply") + btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], []) diff --git a/server.py b/server.py index d5439710..414b1dc6 100644 --- a/server.py +++ b/server.py @@ -14,7 +14,6 @@ import modules.chat as chat import modules.extensions as extensions_module import modules.shared as shared import modules.ui as ui -from modules.extensions import extension_state, load_extensions, update_extensions_parameters from modules.html_generator import generate_chat_html from modules.models import load_model, load_soft_prompt from modules.text_generation import generate_reply @@ -95,29 +94,6 @@ def upload_soft_prompt(file): return name -def create_extensions_block(): - extensions_ui_elements = [] - default_values = [] - if not (shared.args.chat or shared.args.cai_chat): - gr.Markdown('## Extensions parameters') - for ext in sorted(extension_state, key=lambda x : extension_state[x][1]): - if extension_state[ext][0] == True: - params = extensions_module.get_params(ext) - for param in params: - _id = f"{ext}-{param}" - default_value = shared.settings[_id] if _id in shared.settings else params[param] - default_values.append(default_value) - if type(params[param]) == str: - extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}")) - elif type(params[param]) in [int, float]: - extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}")) - elif type(params[param]) == bool: - extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}")) - - update_extensions_parameters(*default_values) - btn_extensions = gr.Button("Apply") - btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], []) - def create_settings_menus(): generate_params = load_preset_values(shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', return_dict=True) @@ -176,7 +152,7 @@ available_softprompts = get_available_softprompts() extensions_module.available_extensions = get_available_extensions() if shared.args.extensions is not None: - load_extensions() + extensions_module.load_extensions() # Choosing the default model if shared.args.model is not None: @@ -279,7 +255,7 @@ if shared.args.chat or shared.args.cai_chat: if shared.args.extensions is not None: with gr.Tab("Extensions"): - create_extensions_block() + extensions_module.create_extensions_block() input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider] if shared.args.picture: @@ -340,7 +316,7 @@ elif shared.args.notebook: preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() if shared.args.extensions is not None: - create_extensions_block() + extensions_module.create_extensions_block() gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen")) gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream)) @@ -362,7 +338,7 @@ else: preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() if shared.args.extensions is not None: - create_extensions_block() + extensions_module.create_extensions_block() with gr.Column(): with gr.Tab('Raw'):