From 0dd1409f24e21f25e695ed788f42703e615b7d77 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 11 Feb 2023 14:48:12 -0300 Subject: [PATCH] Add penalty_alpha parameter (contrastive search) --- README.md | 1 + presets/Instruct-Joi.txt | 5 +++++ server.py | 47 +++++++++++++++++++++------------------- 3 files changed, 31 insertions(+), 22 deletions(-) create mode 100644 presets/Instruct-Joi.txt diff --git a/README.md b/README.md index 468f74f3..5727f3ba 100644 --- a/README.md +++ b/README.md @@ -186,4 +186,5 @@ For these two, please try commenting on an existing issue instead of creating a - NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets - Pygmalion preset, code for early stopping in chat mode, code for some of the sliders: https://github.com/PygmalionAI/gradio-ui/ - Verbose preset: Anonymous 4chan user. +- Instruct-Joi preset: https://huggingface.co/Rallio67/joi\_12B\_instruct\_alpha - Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui diff --git a/presets/Instruct-Joi.txt b/presets/Instruct-Joi.txt new file mode 100644 index 00000000..cc21ba11 --- /dev/null +++ b/presets/Instruct-Joi.txt @@ -0,0 +1,5 @@ +top_p=0.95, +temperature=0.5, +penalty_alpha=0.6, +top_k=4, +repetition_penalty=1.03, diff --git a/server.py b/server.py index e5df3056..20cd3572 100644 --- a/server.py +++ b/server.py @@ -174,6 +174,7 @@ def load_preset_values(preset_menu, return_dict=False): 'repetition_penalty': 1, 'top_k': 50, 'num_beams': 1, + 'penalty_alpha': 0, 'min_length': 0, 'length_penalty': 1, 'no_repeat_ngram_size': 0, @@ -191,7 +192,7 @@ def load_preset_values(preset_menu, return_dict=False): if return_dict: return generate_params else: - return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['length_penalty'], generate_params['early_stopping'] + return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] # Removes empty replies from gpt4chan outputs def fix_gpt4chan(s): @@ -237,7 +238,7 @@ def formatted_outputs(reply, model_name): else: return reply -def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=None, stopping_string=None): +def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): global model_name, model, tokenizer original_question = question @@ -274,6 +275,7 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top f"min_length={min_length if args.no_stream else 0}", f"no_repeat_ngram_size={no_repeat_ngram_size}", f"num_beams={num_beams}", + f"penalty_alpha={penalty_alpha}", f"length_penalty={length_penalty}", f"early_stopping={early_stopping}", ] @@ -392,6 +394,7 @@ def create_settings_menus(): repetition_penalty = gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label="repetition_penalty") top_k = gr.Slider(0,200,value=generate_params['top_k'],step=1,label="top_k") no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size") + penalty_alpha = gr.Slider(0, 5, value=generate_params["penalty_alpha"], label="penalty_alpha") gr.Markdown("Special parameters (only use them if you really need them):") with gr.Row(): @@ -403,8 +406,8 @@ def create_settings_menus(): early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping") model_menu.change(load_model_wrapper, [model_menu], []) - preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping]) - return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping + preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping]) + return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping # This gets the new line characters right. def clean_chat_message(text): @@ -475,14 +478,14 @@ def extract_message_from_reply(question, reply, current, other, check, extension return reply, next_character_found, substring_found -def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size): +def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size): original_text = text text = apply_extensions(text, "input") question = generate_chat_prompt(text, tokens, name1, name2, context, history_size) history['internal'].append(['', '']) history['visible'].append(['', '']) eos_token = '\n' if check else None - for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): + for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name2, name1, check, extensions=True) history['internal'][-1] = [text, reply] history['visible'][-1] = [original_text, apply_extensions(reply, "output")] @@ -492,10 +495,10 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, break yield history['visible'] -def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size): +def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size): question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True) eos_token = '\n' if check else None - for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): + for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False) if not substring_found: yield apply_extensions(reply, "output") @@ -503,19 +506,19 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to break yield apply_extensions(reply, "output") -def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size): - for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size): +def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size): + for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size): yield generate_chat_html(_history, name1, name2, character) -def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size): +def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size): last = history['visible'].pop() history['internal'].pop() text = last[0] if args.cai_chat: - for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size): + for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size): yield i else: - for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size): + for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size): yield i def remove_last_message(name1, name2): @@ -775,7 +778,7 @@ if args.chat or args.cai_chat: with gr.Column(): history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size in prompt (0 for no limit)', value=settings['history_size']) - preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus() + preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name') name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name') @@ -813,7 +816,7 @@ if args.chat or args.cai_chat: if args.extensions is not None: create_extensions_block() - input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping, name1, name2, context, check, history_size_slider] + input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size_slider] if args.cai_chat: gen_events.append(buttons["Generate"].click(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream, api_name="textgen")) gen_events.append(textbox.submit(cai_chatbot_wrapper, input_params, display, show_progress=args.no_stream)) @@ -860,13 +863,13 @@ elif args.notebook: max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) - preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus() + preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() if args.extensions is not None: create_extensions_block() - gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")) - gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream)) + gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")) + gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=args.no_stream)) buttons["Stop"].click(None, None, None, cancels=gen_events) else: @@ -883,7 +886,7 @@ else: with gr.Column(): buttons["Stop"] = gr.Button("Stop") - preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping = create_settings_menus() + preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() if args.extensions is not None: create_extensions_block() @@ -895,9 +898,9 @@ else: with gr.Tab('HTML'): html = gr.HTML() - gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")) - gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream)) - gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream)) + gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")) + gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream)) + gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=args.no_stream)) buttons["Stop"].click(None, None, None, cancels=gen_events) interface.queue()