Generalize superbooga to chat mode

This commit is contained in:
oobabooga 2023-05-07 15:01:14 -03:00
parent ec1cda0e1f
commit 6b67cb6611
2 changed files with 65 additions and 17 deletions

View file

@ -8,9 +8,10 @@ import posthog
import torch
from bs4 import BeautifulSoup
from chromadb.config import Settings
from modules import shared
from sentence_transformers import SentenceTransformer
from modules import chat, shared
print('Intercepting all calls to posthog :)')
posthog.capture = lambda *args, **kwargs: None
@ -53,6 +54,10 @@ class ChromaCollector(Collecter):
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['documents'][0]
return result
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
return list(map(lambda x : int(x[2:]), result))
    def clear(self):
        """Remove every previously added chunk from the Chroma collection."""
        # self.ids records the ids assigned when chunks were added; deleting
        # by those ids empties the collection without recreating it.
        self.collection.delete(ids=self.ids)
@ -68,18 +73,24 @@ collector = ChromaCollector(embedder)
chunk_count = 5
def feed_data_into_collector(corpus, chunk_len):
def add_chunks_to_collector(chunks):
    """Replace the collector's current contents with `chunks`.

    Clears the module-level collector first so the database only ever holds
    the most recently supplied list of text chunks.
    """
    global collector
    collector.clear()
    collector.add(chunks)
def feed_data_into_collector(corpus, chunk_len):
    """Split `corpus` into fixed-size chunks and load them into the database.

    A generator: yields a cumulative progress string after each step so the
    caller (the Gradio UI) can stream status updates while the work runs.
    `chunk_len` may arrive as a string from the UI and is coerced to int.
    """
    size = int(chunk_len)

    progress = "Breaking the input dataset...\n\n"
    yield progress

    pieces = [corpus[start:start + size] for start in range(0, len(corpus), size)]
    progress += f"{len(pieces)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
    yield progress

    pieces and None  # no-op placeholder removed below; chunks go straight in
    add_chunks_to_collector(pieces)
    progress += "Done."
    yield progress
@ -123,6 +134,8 @@ def apply_settings(_chunk_count):
def input_modifier(string):
if shared.is_chat():
return string
# Find the user input
pattern = re.compile(r"<\|begin-user-input\|>(.*?)<\|end-user-input\|>", re.DOTALL)
@ -143,6 +156,23 @@ def input_modifier(string):
return string
def custom_generate_chat_prompt(user_input, state, **kwargs):
    """Build the chat prompt with past exchanges reordered by relevance.

    When the user typed something and there is enough history, every past
    input/reply pair except the most recent one is embedded and ranked
    against the current exchange; `state['history']` is then set to those
    pairs from least to most relevant, with the latest pair kept last.
    Finally delegates to the regular chat prompt generator.
    """
    history = shared.history['internal']
    if user_input != '' and len(history) > 2:
        # One text chunk per past exchange, excluding the most recent pair.
        add_chunks_to_collector(['\n'.join(pair) for pair in history[:-1]])

        # Rank the old pairs against the latest pair plus the new input.
        query = '\n'.join(history[-1] + [user_input])
        best_ids = collector.get_ids(query, n_results=len(history) - 1)

        # Least relevant first, most relevant last, latest pair at the end.
        state['history'] = [history[i] for i in best_ids[::-1]] + [history[-1]]

    return chat.generate_chat_prompt(user_input, state, **kwargs)
def ui():
with gr.Accordion("Click for more information...", open=False):
gr.Markdown(textwrap.dedent("""
@ -156,7 +186,9 @@ def ui():
It is a modified version of the superbig extension by kaiokendev: https://github.com/kaiokendev/superbig
## How to use it
## Notebook/default modes
### How to use it
1) Paste your input text (of whatever length) into the text box below.
2) Click on "Load data" to feed this text into the Chroma database.
@ -166,7 +198,7 @@ def ui():
The special tokens mentioned above (`<|begin-user-input|>`, `<|end-user-input|>`, and `<|injection-point|>`) are removed when the injection happens.
## Example
### Example
For your convenience, you can use the following prompt as a starting point (for Alpaca models):
@ -186,14 +218,25 @@ def ui():
### Response:
```
## Chat mode
In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
That is, the prompt will include (starting from the end):
* Your input
* The latest input/reply pair
* The #1 most relevant input/reply pair prior to the latest
* The #2 most relevant input/reply pair prior to the latest
* Etc
This way, the bot can have a long term history.
*This extension is currently experimental and under development.*
"""))
if shared.is_chat():
# Chat mode has to be handled differently, probably using a custom_generate_chat_prompt
pass
else:
if not shared.is_chat():
with gr.Row():
with gr.Column():
with gr.Tab("Text input"):

View file

@ -27,6 +27,11 @@ def replace_all(text, dic):
def generate_chat_prompt(user_input, state, **kwargs):
# Check if an extension is sending its modified history.
# If not, use the regular history
history = state['history'] if 'history' in state else shared.history['internal']
# Define some variables
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
@ -61,14 +66,14 @@ def generate_chat_prompt(user_input, state, **kwargs):
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
# Building the prompt
i = len(shared.history['internal']) - 1
i = len(history) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
if _continue and i == len(history) - 1:
rows.insert(1, bot_turn_stripped + history[i][1].strip())
else:
rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip()))
rows.insert(1, bot_turn.replace('<|bot-message|>', history[i][1].strip()))
string = shared.history['internal'][i][0]
string = history[i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
@ -80,7 +85,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
elif not _continue:
# Adding the user message
if len(user_input) > 0:
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))}))
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
# Adding the Character prefix
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))