Further refactor

This commit is contained in:
oobabooga 2023-02-23 14:31:28 -03:00
parent e46c43afa6
commit 364529d0c7

View file

@ -25,11 +25,27 @@ from modules.text_generation import generate_reply
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")
# Loading custom settings
if shared.args.settings is not None and Path(shared.args.settings).exists():
new_settings = json.loads(open(Path(shared.args.settings), 'r').read())
for item in new_settings:
shared.settings[item] = new_settings[item]
def get_available_models():
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
def get_available_characters():
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
def get_available_extensions():
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
def get_available_softprompts():
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
def load_model_wrapper(selected_model):
if selected_model != shared.model_name:
shared.model_name = selected_model
@ -82,21 +98,6 @@ def upload_soft_prompt(file):
return name
def get_available_models():
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
def get_available_characters():
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
def get_available_extensions():
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
def get_available_softprompts():
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
def create_extensions_block():
extensions_ui_elements = []
default_values = []
@ -171,7 +172,6 @@ def create_settings_menus():
upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu])
return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
# Global variables
available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()
@ -192,7 +192,7 @@ else:
i = 0
else:
print("The following models are available:\n")
for i,model in enumerate(available_models):
for i, model in enumerate(available_models):
print(f"{i+1}. {model}")
print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
i = int(input())-1
@ -201,20 +201,18 @@ else:
shared.model, shared.tokenizer = load_model(shared.model_name)
# UI settings
buttons = {}
gen_events = []
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
default_text = shared.settings['prompt_gpt4chan']
elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None:
default_text = 'User: \n'
else:
default_text = shared.settings['prompt']
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
buttons = {}
gen_events = []
if shared.args.chat or shared.args.cai_chat:
if Path(f'logs/persistent.json').exists():
chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'])