The soft prompt length must be considered here too

oobabooga 2023-02-17 12:35:30 -03:00
parent a6ddbbfc77
commit 596732a981


@@ -505,11 +505,17 @@ def clean_chat_message(text):
     return text
 
 def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
+    global soft_prompt, soft_prompt_tensor
+
     text = clean_chat_message(text)
     rows = [f"{context.strip()}\n"]
     i = len(history['internal'])-1
     count = 0
+
+    if soft_prompt:
+        chat_prompt_size -= soft_prompt_tensor.shape[1]
+
     max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
     while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
         rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
         count += 1
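
Why the subtraction matters: a soft prompt is a tensor of learned embedding vectors prepended to the token embeddings before generation, so every position it occupies (its shape[1]) is one the chat history can no longer use. The sketch below illustrates the budget arithmetic; only the names soft_prompt and soft_prompt_tensor come from the diff, while the embedding shapes, the length 20, and the apply_soft_prompt helper are illustrative assumptions, not the repository's actual code.

import torch

# Hypothetical soft prompt: 20 learned vectors of hidden size 768,
# shaped (batch, soft_prompt_len, hidden_size).
soft_prompt_tensor = torch.zeros(1, 20, 768)

def apply_soft_prompt(token_embeds):
    # Prepending along the sequence dimension makes the model input
    # soft_prompt_tensor.shape[1] positions longer than the text alone.
    return torch.cat((soft_prompt_tensor, token_embeds), dim=1)

# Budget arithmetic mirroring the diff: with a 2048-token prompt budget
# and a 20-position soft prompt, only 2028 tokens remain for the
# context and chat-history rows assembled by generate_chat_prompt().
chat_prompt_size = 2048
chat_prompt_size -= soft_prompt_tensor.shape[1]
assert chat_prompt_size == 2028

Without the subtraction, the history loop fills the entire budget with text tokens, and the concatenated input overflows the model's context window once the soft prompt is prepended.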