Precise prompts for instruct mode

This commit is contained in:
oobabooga 2023-04-26 03:21:53 -03:00 committed by GitHub
parent a8409426d7
commit a777c058af
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 71 additions and 34 deletions

View file

@ -1,3 +1,4 @@
name: "### Response:"
your_name: "### Instruction:"
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
turn_template: "<|user|>\n<|user-message|>\n\n<|bot|>\n<|bot-message|>\n\n"

View file

@ -1,3 +1,4 @@
name: "答:"
your_name: "[Round <|round|>]\n问:"
context: ""
turn_template: "<|user|><|user-message|>\n<|bot|><|bot-message|>\n"

View file

@ -1,3 +1,4 @@
name: "GPT:"
your_name: "USER:"
context: "BEGINNING OF CONVERSATION:"
context: "BEGINNING OF CONVERSATION: "
turn_template: "<|user|> <|user-message|> <|bot|><|bot-message|></s>"

View file

@ -1,3 +1,4 @@
name: "### Assistant"
your_name: "### Human"
context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"
context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"
turn_template: "<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n"

View file

@ -1,3 +1,3 @@
name: "<|assistant|>"
your_name: "<|prompter|>"
end_of_turn: "<|endoftext|>"
turn_template: "<|user|><|user-message|><|endoftext|><|bot|><|bot-message|><|endoftext|>"

View file

@ -0,0 +1,3 @@
name: "Alice:"
your_name: "Bob:"
turn_template: "<|user|> <|user-message|>\n\n<|bot|><|bot-message|>\n\n"

View file

@ -0,0 +1,4 @@
name: "### Assistant:"
your_name: "### Human:"
context: "A chat between a human and an assistant.\n\n"
turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n"

View file

@ -1,3 +1,4 @@
name: "### Assistant:"
your_name: "### Human:"
context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
name: "ASSISTANT:"
your_name: "USER:"
context: "A chat between a user and an assistant.\n\n"
turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|></s>\n"

View file

@ -52,3 +52,6 @@ llama-[0-9]*b-4bit$:
mode: 'instruct'
model_type: 'llama'
instruction_template: 'LLaVA'
.*raven:
mode: 'instruct'
instruction_template: 'RWKV-Raven'

View file

@ -17,12 +17,20 @@ from modules.text_generation import (encode, generate_reply,
get_max_prompt_length)
# Replace multiple string pairs in a string
def replace_all(text, dic):
for i, j in dic.items():
text = text.replace(i, j)
return text
def generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
is_instruct = state['mode'] == 'instruct'
rows = [f"{state['context'].strip()}\n"]
rows = [state['context'] if is_instruct else f"{state['context'].strip()}\n"]
min_rows = 3
# Finding the maximum prompt size
@ -31,38 +39,50 @@ def generate_chat_prompt(user_input, state, **kwargs):
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size)
if is_instruct:
prefix1 = f"{state['name1']}\n"
prefix2 = f"{state['name2']}\n"
else:
prefix1 = f"{state['name1']}: "
prefix2 = f"{state['name2']}: "
# Building the turn templates
if 'turn_template' not in state or state['turn_template'] == '':
if is_instruct:
template = '<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n'
else:
template = '<|user|>: <|user-message|>\n<|bot|>: <|bot-message|>\n'
else:
template = state['turn_template'].replace(r'\n', '\n')
replacements = {
'<|user|>': state['name1'].strip(),
'<|bot|>': state['name2'].strip(),
}
user_turn = replace_all(template.split('<|bot|>')[0], replacements)
bot_turn = replace_all('<|bot|>' + template.split('<|bot|>')[1], replacements)
user_turn_stripped = replace_all(user_turn.split('<|user-message|>')[0], replacements)
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
# Building the prompt
i = len(shared.history['internal']) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip()))
string = shared.history['internal'][i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
this_prefix1 = prefix1.replace('<|round|>', f'{i}') # for ChatGLM
rows.insert(1, f"{this_prefix1}{string.strip()}{state['end_of_turn']}\n")
rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
i -= 1
if impersonate:
min_rows = 2
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
rows.append(user_turn_stripped)
elif not _continue:
# Adding the user message
if len(user_input) > 0:
this_prefix1 = prefix1.replace('<|round|>', f'{len(shared.history["internal"])}') # for ChatGLM
rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n")
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))}))
# Adding the Character prefix
rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}"))
rows.append(apply_extensions("bot_prefix", bot_turn_stripped))
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
rows.pop(1)
@ -416,7 +436,7 @@ def generate_pfp_cache(character):
def load_character(character, name1, name2, mode):
shared.character = character
context = greeting = end_of_turn = ""
context = greeting = turn_template = ""
greeting_field = 'greeting'
picture = None
@ -445,7 +465,9 @@ def load_character(character, name1, name2, mode):
data[field] = replace_character_names(data[field], name1, name2)
if 'context' in data:
context = f"{data['context'].strip()}\n\n"
context = data['context']
if mode != 'instruct':
context = context.strip() + '\n\n'
elif "char_persona" in data:
context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting'
@ -456,14 +478,14 @@ def load_character(character, name1, name2, mode):
if greeting_field in data:
greeting = data[greeting_field]
if 'end_of_turn' in data:
end_of_turn = data['end_of_turn']
if 'turn_template' in data:
turn_template = data['turn_template']
else:
context = shared.settings['context']
name2 = shared.settings['name2']
greeting = shared.settings['greeting']
end_of_turn = shared.settings['end_of_turn']
turn_template = shared.settings['turn_template']
if mode != 'instruct':
shared.history['internal'] = []
@ -479,7 +501,7 @@ def load_character(character, name1, name2, mode):
# Create .json log files since they don't already exist
save_history(mode)
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode)
return name1, name2, picture, greeting, context, repr(turn_template)[1:-1], chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def upload_character(json_file, img, tavern=False):

View file

@ -39,7 +39,7 @@ settings = {
'name2': 'Assistant',
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
'greeting': '',
'end_of_turn': '',
'turn_template': '',
'custom_stopping_strings': '',
'stop_at_newline': False,
'add_bos_token': True,

View file

@ -35,7 +35,7 @@ def list_model_elements():
def list_interface_input_elements(chat=False):
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
if chat:
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']
elements += ['name1', 'name2', 'greeting', 'context', 'turn_template', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']
elements += list_model_elements()
return elements

View file

@ -553,7 +553,7 @@ def create_interface():
shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings['end_of_turn'], lines=1, label='End of turn string')
shared.gradio['turn_template'] = gr.Textbox(value=shared.settings['turn_template'], lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.')
with gr.Column(scale=1):
shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil')
@ -778,7 +778,7 @@ def create_interface():
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['instruction_template'].change(
chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'turn_template', 'display']]).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['upload_chat_history'].upload(
@ -791,7 +791,7 @@ def create_interface():
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'turn_template', 'display']])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")

View file

@ -8,7 +8,7 @@
"name2": "Assistant",
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
"greeting": "",
"end_of_turn": "",
"turn_template": "",
"custom_stopping_strings": "",
"stop_at_newline": false,
"add_bos_token": true,