diff --git a/characters/instruction-following/Alpaca.yaml b/characters/instruction-following/Alpaca.yaml
new file mode 100644
index 00000000..30373242
--- /dev/null
+++ b/characters/instruction-following/Alpaca.yaml
@@ -0,0 +1,3 @@
+name: "### Response:"
+your_name: "### Instruction:"
+context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
diff --git a/characters/instruction-following/Open Assistant.yaml b/characters/instruction-following/Open Assistant.yaml
new file mode 100644
index 00000000..5b3320ff
--- /dev/null
+++ b/characters/instruction-following/Open Assistant.yaml
@@ -0,0 +1,3 @@
+name: "<|assistant|>"
+your_name: "<|prompter|>"
+end_of_turn: "<|endoftext|>"
diff --git a/css/html_instruct_style.css b/css/html_instruct_style.css
new file mode 100644
index 00000000..f50b64d4
--- /dev/null
+++ b/css/html_instruct_style.css
@@ -0,0 +1,56 @@
+.chat {
+  margin-left: auto;
+  margin-right: auto;
+  max-width: 800px;
+  height: 66.67vh;
+  overflow-y: auto;
+  padding-right: 20px;
+  display: flex;
+  flex-direction: column-reverse;
+}
+
+.message {
+  display: grid;
+  grid-template-columns: 60px 1fr;
+  padding-bottom: 25px;
+  font-size: 15px;
+  font-family: Helvetica, Arial, sans-serif;
+  line-height: 1.428571429;
+}
+
+.text p {
+  margin-top: 5px;
+}
+
+.username {
+  display: none;
+}
+
+.message-body {}
+
+.message-body p {
+  margin-bottom: 0 !important;
+  font-size: 15px !important;
+  line-height: 1.428571429 !important;
+}
+
+.dark .message-body p em {
+  color: rgb(138, 138, 138) !important;
+}
+
+.message-body p em {
+  color: rgb(110, 110, 110) !important;
+}
+
+.assistant-message {
+  padding: 10px;
+}
+
+.user-message {
+  padding: 10px;
+  background-color: #f1f1f1;
+}
+
+.dark .user-message {
+  background-color: #ffffff1a;
+}
\ No newline at end of file
diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py
index cc85f3b3..df07ef2d 100644
--- a/extensions/sd_api_pictures/script.py
+++ b/extensions/sd_api_pictures/script.py
@@ -176,4 +176,4 @@ def ui():
     force_btn.click(force_pic)
     generate_now_btn.click(force_pic)
-    generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
\ No newline at end of file
+    generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
\ No newline at end of file
diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py
index 556a88e5..b6305bdc 100644
--- a/extensions/send_pictures/script.py
+++ b/extensions/send_pictures/script.py
@@ -36,13 +36,11 @@ def generate_chat_picture(picture, name1, name2):
 
 def ui():
     picture_select = gr.Image(label='Send a picture', type='pil')
-    function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
-
     # Prepare the hijack with custom inputs
     picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
 
     # Call the generation function
-    picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+    picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
 
     # Clear the picture from the upload field
     picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
diff --git a/modules/chat.py b/modules/chat.py
index 21d9d16d..978a08f2 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -12,46 +12,51 @@ from PIL import Image
 
 import modules.extensions as extensions_module
 import modules.shared as shared
 from modules.extensions import apply_extensions
-from modules.html_generator import (fix_newlines, generate_chat_html,
+from modules.html_generator import (fix_newlines, chat_html_wrapper,
                                     make_thumbnail)
 from modules.text_generation import (encode, generate_reply,
                                      get_max_prompt_length)
 
-def generate_chat_output(history, name1, name2):
-    if shared.args.cai_chat:
-        return generate_chat_html(history, name1, name2)
-    else:
-        return history
-
-def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
+def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn="", impersonate=False, also_return_rows=False):
     user_input = fix_newlines(user_input)
     rows = [f"{context.strip()}\n"]
 
+    # Finding the maximum prompt size
     if shared.soft_prompt:
         chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
     max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
 
+    if is_instruct:
+        prefix1 = f"{name1}\n"
+        prefix2 = f"{name2}\n"
+    else:
+        prefix1 = f"{name1}: "
+        prefix2 = f"{name2}: "
+
     i = len(shared.history['internal'])-1
     while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
-        rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
-        prev_user_input = shared.history['internal'][i][0]
-        if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
-            rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
+        rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
+        string = shared.history['internal'][i][0]
+        if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
+            rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
         i -= 1
 
-    if not impersonate:
-        if len(user_input) > 0:
-            rows.append(f"{name1}: {user_input}\n")
-        rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
-        limit = 3
-    else:
-        rows.append(f"{name1}:")
+    if impersonate:
+        rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
         limit = 2
+    else:
+
+        # Adding the user message
+        if len(user_input) > 0:
+            rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
+
+        # Adding the Character prefix
+        rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
+        limit = 3
 
     while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
         rows.pop(1)
-
     prompt = ''.join(rows)
     if also_return_rows:
@@ -86,7 +91,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
     reply = fix_newlines(reply)
     return reply, next_character_found
 
-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
+def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False, mode="cai-chat", end_of_turn=""):
     just_started = True
     eos_token = '\n' if stop_at_newline else None
     name1_original = name1
@@ -105,14 +110,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
     if visible_text is None:
         visible_text = text
-    if shared.args.chat:
-        visible_text = visible_text.replace('\n', '<br>')
     text = apply_extensions(text, "input")
 
+    is_instruct = mode == 'instruct'
     if custom_generate_chat_prompt is None:
-        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn)
     else:
-        prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+        prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn)
 
     # Yield *Is typing...*
     if not regenerate:
@@ -129,8 +133,6 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
         reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
         visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
         visible_reply = apply_extensions(visible_reply, "output")
-        if shared.args.chat:
-            visible_reply = visible_reply.replace('\n', '<br>')
 
         # We need this global variable to handle the Stop event,
         # otherwise gradio gets confused
@@ -153,13 +155,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
     yield shared.history['visible']
 
-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
+def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""):
     eos_token = '\n' if stop_at_newline else None
 
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
 
-    prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
+    prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True, end_of_turn=end_of_turn)
 
     # Yield *Is typing...*
     yield shared.processing_message
@@ -179,36 +181,30 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
         yield reply
 
-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
-    for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts):
-        yield generate_chat_html(history, name1, name2)
+def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""):
+    for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=False, mode=mode, end_of_turn=end_of_turn):
+        yield chat_html_wrapper(history, name1, name2, mode)
 
-def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
+def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""):
     if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
-        yield generate_chat_output(shared.history['visible'], name1, name2)
+        yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
     else:
         last_visible = shared.history['visible'].pop()
         last_internal = shared.history['internal'].pop()
 
         # Yield '*Is typing...*'
-        yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2)
-        for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True):
-            if shared.args.cai_chat:
-                shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
-            else:
-                shared.history['visible'][-1] = (last_visible[0], history[-1][1])
-            yield generate_chat_output(shared.history['visible'], name1, name2)
+        yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
+        for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True, mode=mode, end_of_turn=end_of_turn):
+            shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
+            yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
-def remove_last_message(name1, name2):
+def remove_last_message(name1, name2, mode):
     if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
         last = shared.history['visible'].pop()
         shared.history['internal'].pop()
     else:
         last = ['', '']
 
-    if shared.args.cai_chat:
-        return generate_chat_html(shared.history['visible'], name1, name2), last[0]
-    else:
-        return shared.history['visible'], last[0]
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
 
 def send_last_reply_to_input():
     if len(shared.history['internal']) > 0:
@@ -216,20 +212,17 @@ def send_last_reply_to_input():
     else:
         return ''
 
-def replace_last_reply(text, name1, name2):
+def replace_last_reply(text, name1, name2, mode):
     if len(shared.history['visible']) > 0:
-        if shared.args.cai_chat:
-            shared.history['visible'][-1][1] = text
-        else:
-            shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
+        shared.history['visible'][-1][1] = text
         shared.history['internal'][-1][1] = apply_extensions(text, "input")
 
-    return generate_chat_output(shared.history['visible'], name1, name2)
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
 def clear_html():
-    return generate_chat_html([], "", "")
+    return chat_html_wrapper([], "", "")
 
-def clear_chat_log(name1, name2, greeting):
+def clear_chat_log(name1, name2, greeting, mode):
     shared.history['visible'] = []
     shared.history['internal'] = []
 
@@ -237,12 +230,12 @@
         shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
         shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
 
-    return generate_chat_output(shared.history['visible'], name1, name2)
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
-def redraw_html(name1, name2):
-    return generate_chat_html(shared.history['visible'], name1, name2)
+def redraw_html(name1, name2, mode):
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
-def tokenize_dialogue(dialogue, name1, name2):
+def tokenize_dialogue(dialogue, name1, name2, mode):
     history = []
 
     dialogue = re.sub('<START>', '', dialogue)
@@ -339,11 +332,12 @@ def generate_pfp_cache(character):
             return img
     return None
 
-def load_character(character, name1, name2):
+def load_character(character, name1, name2, instruct=False):
    shared.character = character
    shared.history['internal'] = []
    shared.history['visible'] = []
-    greeting = ""
+    context = greeting = end_of_turn = ""
+    greeting_field = 'greeting'
    picture = None
 
    # Deleting the profile picture cache, if any
@@ -351,9 +345,10 @@
        Path("cache/pfp_character.png").unlink()
 
    if character != 'None':
+        folder = "characters" if not instruct else "characters/instruction-following"
        picture = generate_pfp_cache(character)
        for extension in ["yml", "yaml", "json"]:
-            filepath = Path(f'characters/{character}.{extension}')
+            filepath = Path(f'{folder}/{character}.{extension}')
            if filepath.exists():
                break
        file_contents = open(filepath, 'r', encoding='utf-8').read()
@@ -369,19 +364,21 @@
        if 'context' in data:
            context = f"{data['context'].strip()}\n\n"
-            greeting_field = 'greeting'
-        else:
+        elif "char_persona" in data:
            context = build_pygmalion_style_context(data)
            greeting_field = 'char_greeting'
 
-        if 'example_dialogue' in data and data['example_dialogue'] != '':
+        if 'example_dialogue' in data:
            context += f"{data['example_dialogue'].strip()}\n"
-        if greeting_field in data and len(data[greeting_field].strip()) > 0:
+        if greeting_field in data:
            greeting = data[greeting_field]
+        if 'end_of_turn' in data:
+            end_of_turn = data['end_of_turn']
    else:
        context = shared.settings['context']
        name2 = shared.settings['name2']
        greeting = shared.settings['greeting']
+        end_of_turn = shared.settings['end_of_turn']
 
    if Path(f'logs/{shared.character}_persistent.json').exists():
        load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
@@ -389,10 +386,7 @@
            shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
            shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
 
-    if shared.args.cai_chat:
-        return name1, name2, picture, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, reset_cache=True)
-    else:
-        return name1, name2, picture, greeting, context, shared.history['visible']
+    return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, reset_cache=True)
 
 def load_default_history(name1, name2):
     load_character("None", name1, name2)
@@ -423,7 +417,7 @@
        _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
    return upload_character(json.dumps(_json), img, tavern=True)
 
-def upload_your_profile_picture(img, name1, name2):
+def upload_your_profile_picture(img, name1, name2, mode):
    cache_folder = Path("cache")
 
    if not cache_folder.exists():
        cache_folder.mkdir()
@@ -436,7 +430,4 @@ def upload_your_profile_picture(img, name1, name2):
        img.save(Path('cache/pfp_me.png'))
        print('Profile picture saved to "cache/pfp_me.png"')
 
-    if shared.args.cai_chat:
-        return generate_chat_html(shared.history['visible'], name1, name2, reset_cache=True)
-    else:
-        return shared.history['visible']
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
diff --git a/modules/html_generator.py b/modules/html_generator.py
index e1c085a6..6fb8457f 100644
--- a/modules/html_generator.py
+++ b/modules/html_generator.py
@@ -21,6 +21,8 @@ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r')
     _4chan_css = css_f.read()
 with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
     cai_css = f.read()
+with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
+    instruct_css = f.read()
 
 def fix_newlines(string):
     string = string.replace('\n', '\n\n')
@@ -117,7 +119,39 @@ def get_image_cache(path):
     return image_cache[path][1]
 
-def generate_chat_html(history, name1, name2, reset_cache=False):
+def generate_instruct_html(history):
+    output = f'<div class="chat" id="chat">'
+    for i,_row in enumerate(history[::-1]):
+        row = [convert_to_markdown(entry) for entry in _row]
+
+        output += f"""
+              <div class="assistant-message">
+                <div class="text">
+                  <div class="message-body">
+                    {row[1]}
+                  </div>
+                </div>
+              </div>
+            """
+
+        if len(row[0]) == 0: # don't display empty user messages
+            continue
+
+        output += f"""
+              <div class="user-message">
+                <div class="text">
+                  <div class="message-body">
+                    {row[0]}
+                  </div>
+                </div>
+              </div>
+            """
+
+    output += "</div>"
+
+    return output
+
+def generate_cai_chat_html(history, name1, name2, reset_cache=False):
+    output = f'<div class="chat" id="chat">'
 
     # The time.time() is to prevent the brower from caching the image
@@ -165,3 +199,17 @@ def generate_chat_html(history, name1, name2, reset_cache=False):
     output += "</div>"
 
     return output
+
+def generate_chat_html(history, name1, name2):
+    return generate_cai_chat_html(history, name1, name2)
+
+def chat_html_wrapper(history, name1, name2, mode="cai-chat", reset_cache=False):
+
+    if mode == "cai-chat":
+        return generate_cai_chat_html(history, name1, name2, reset_cache)
+    elif mode == "chat":
+        return generate_chat_html(history, name1, name2)
+    elif mode == "instruct":
+        return generate_instruct_html(history)
+    else:
+        return ''
diff --git a/modules/shared.py b/modules/shared.py
index 6c183a81..a6f58778 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -33,6 +33,7 @@ settings = {
     'name2': 'Assistant',
     'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
     'greeting': 'Hello there!',
+    'end_of_turn': '',
     'stop_at_newline': False,
     'chat_prompt_size': 2048,
     'chat_prompt_size_min': 0,
@@ -73,8 +74,8 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
 
 # Basic settings
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
-parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
-parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
+parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
+parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.')
 parser.add_argument('--model', type=str, help='Name of the model to load by default.')
 parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
@@ -131,12 +132,17 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
 
 args = parser.parse_args()
 
-# Provisional, this will be deleted later
+# Deprecation warnings for parameters that have been renamed
 deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
 for k in deprecated_dict:
     if eval(f"args.{k}") != deprecated_dict[k][1]:
         print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
         exec(f"args.{deprecated_dict[k][0]} = args.{k}")
 
+# Deprecation warnings for parameters that have been removed
+if args.cai_chat:
+    print("Warning: --cai-chat is deprecated. Use --chat instead.")
Use --chat instead.") + args.chat = True + def is_chat(): - return any((args.chat, args.cai_chat)) + return args.chat diff --git a/server.py b/server.py index a34c86f7..f367ca0d 100644 --- a/server.py +++ b/server.py @@ -12,7 +12,7 @@ from PIL import Image import modules.extensions as extensions_module from modules import chat, shared, training, ui -from modules.html_generator import generate_chat_html +from modules.html_generator import chat_html_wrapper from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt from modules.text_generation import (clear_torch_cache, generate_reply, @@ -48,6 +48,10 @@ def get_available_prompts(): def get_available_characters(): paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) + return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower) + +def get_available_instruction_templates(): + paths = (x for x in Path('characters/instruction-following').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower) def get_available_extensions(): @@ -145,7 +149,7 @@ def load_prompt(fname): if text[-1] == '\n': text = text[:-1] return text - + def create_prompt_menus(): with gr.Row(): with gr.Column(): @@ -296,10 +300,7 @@ def create_interface(): if shared.is_chat(): shared.gradio['Chat input'] = gr.State() with gr.Tab("Text generation", elem_id="main"): - if shared.args.cai_chat: - shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'])) - else: - shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot") + shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'])) shared.gradio['textbox'] = gr.Textbox(label='Input') with gr.Row(): shared.gradio['Generate'] = gr.Button('Generate') @@ -316,13 +317,17 @@ def create_interface(): shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) + shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode") + shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False) + with gr.Tab("Character", elem_id="chat-settings"): with gr.Row(): with gr.Column(scale=8): shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') - shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting') - shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context') + shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting') + shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context') + shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string') with gr.Column(scale=1): shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil") shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", 
value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None) @@ -367,31 +372,31 @@ def create_interface(): create_settings_menus(default_preset) - function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' - shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] + shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts', 'Chat mode', 'end_of_turn']] def set_chat_input(textbox): return textbox, "" gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False)) - gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False)) - gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) - shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream) + shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream) # Clear history with confirmation clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, 
clear_arr) - shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display']) + shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']) shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) + shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']) - shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) + shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']]) shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) @@ -404,18 +409,20 @@ def create_interface(): shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'display']]) - shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) + shared.gradio['Instruction templates'].change(lambda character, name1, name2: chat.load_character(character, name1, name2, instruct=True), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) + shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], []) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) - shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2']], shared.gradio['display']) + shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display']) - reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] - reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] - shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) - shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) + reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']] + shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']]) + shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']]) + shared.gradio['Instruction 
templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']]) + shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']]) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None) - shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) + shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True) elif shared.args.notebook: with gr.Tab("Text generation", elem_id="main"):