Add superbooga time weighted history retrieval (#2080)

Author: Luis Lopez, 2023-05-25 21:22:45 +08:00 (committed by GitHub)
parent a04266161d
commit ee674afa50
2 changed files with 49 additions and 21 deletions

extensions/superbooga/chromadb.py

@@ -47,34 +47,58 @@ class ChromaCollector(Collecter):
         self.ids = [f"id{i}" for i in range(len(texts))]
         self.collection.add(documents=texts, ids=self.ids)
 
-    def get_documents_and_ids(self, search_strings: list[str], n_results: int):
+    def get_documents_ids_distances(self, search_strings: list[str], n_results: int):
         n_results = min(len(self.ids), n_results)
         if n_results == 0:
             return [], []
 
-        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])
+        result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents', 'distances'])
         documents = result['documents'][0]
         ids = list(map(lambda x: int(x[2:]), result['ids'][0]))
-        return documents, ids
+        distances = result['distances'][0]
+        return documents, ids, distances
 
     # Get chunks by similarity
     def get(self, search_strings: list[str], n_results: int) -> list[str]:
-        documents, _ = self.get_documents_and_ids(search_strings, n_results)
+        documents, _, _ = self.get_documents_ids_distances(search_strings, n_results)
         return documents
 
     # Get ids by similarity
     def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
-        _, ids = self.get_documents_and_ids(search_strings, n_results)
+        _, ids, _ = self.get_documents_ids_distances(search_strings, n_results)
         return ids
 
     # Get chunks by similarity and then sort by insertion order
     def get_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
-        documents, ids = self.get_documents_and_ids(search_strings, n_results)
+        documents, ids, _ = self.get_documents_ids_distances(search_strings, n_results)
         return [x for _, x in sorted(zip(ids, documents))]
 
+    # Multiply distance by factor within [0, time_weight] where more recent is lower
+    def apply_time_weight_to_distances(self, ids: list[int], distances: list[float], time_weight: float = 1.0) -> list[float]:
+        if len(self.ids) <= 1:
+            return distances.copy()
+
+        return [distance * (1 - _id / (len(self.ids) - 1) * time_weight) for _id, distance in zip(ids, distances)]
+
     # Get ids by similarity and then sort by insertion order
-    def get_ids_sorted(self, search_strings: list[str], n_results: int) -> list[str]:
-        _, ids = self.get_documents_and_ids(search_strings, n_results)
+    def get_ids_sorted(self, search_strings: list[str], n_results: int, n_initial: int = None, time_weight: float = 1.0) -> list[str]:
+        do_time_weight = time_weight > 0
+        if not (do_time_weight and n_initial is not None):
+            n_initial = n_results
+        elif n_initial == -1:
+            n_initial = len(self.ids)
+
+        if n_initial < n_results:
+            raise ValueError(f"n_initial {n_initial} should be >= n_results {n_results}")
+
+        _, ids, distances = self.get_documents_ids_distances(search_strings, n_initial)
+        if do_time_weight:
+            distances_w = self.apply_time_weight_to_distances(ids, distances, time_weight=time_weight)
+            results = zip(ids, distances, distances_w)
+            results = sorted(results, key=lambda x: x[2])[:n_results]
+            results = sorted(results, key=lambda x: x[0])
+            ids = [x[0] for x in results]
+
         return sorted(ids)
 
     def clear(self):
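Since chunk ids encode insertion order, apply_time_weight_to_distances scales each raw embedding distance by a factor between 1 - time_weight (newest chunk) and 1 (oldest chunk), so recent chunks rank higher. Below is a minimal standalone sketch of that re-ranking with toy values; the ids, distances, and time_weight are hypothetical, not taken from the commit.

# Toy re-implementation of the weighting above, outside the class.
# Ids encode insertion order: a higher id means a more recently added chunk.
def apply_time_weight(ids, distances, n_total, time_weight):
    # Scale each distance by a factor in [1 - time_weight, 1];
    # the newest chunk (id == n_total - 1) receives the largest discount.
    if n_total <= 1:
        return distances.copy()

    return [d * (1 - i / (n_total - 1) * time_weight) for i, d in zip(ids, distances)]

ids = [0, 2, 4]                 # insertion-order ids returned by the similarity query
distances = [0.30, 0.32, 0.34]  # raw distances (lower = more similar)
weighted = apply_time_weight(ids, distances, n_total=5, time_weight=0.5)
print(weighted)  # ~[0.30, 0.24, 0.17]: the newest chunk now wins despite the worst raw distance

# get_ids_sorted then keeps the n_results best weighted entries and returns
# their ids in insertion order, so the selected chunks read chronologically.
best = sorted(zip(ids, weighted), key=lambda x: x[1])[:2]
print(sorted(i for i, _ in best))  # [2, 4]

With time_weight=0 every factor is 1 and the ranking reduces to plain similarity order, which is why the feature defaults off.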

extensions/superbooga/script.py

@@ -12,6 +12,8 @@ from .download_urls import download_urls
 
 params = {
     'chunk_count': 5,
+    'chunk_count_initial': 10,
+    'time_weight': 0,
     'chunk_length': 700,
     'chunk_separator': '',
     'strong_cleanup': False,
@@ -20,7 +22,6 @@ params = {
 
 collector = make_collector()
 chat_collector = make_collector()
-chunk_count = 5
 
 
 def feed_data_into_collector(corpus, chunk_len, chunk_sep):
@@ -83,13 +84,12 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
         yield i
 
 
-def apply_settings(_chunk_count):
-    global chunk_count
-    chunk_count = int(_chunk_count)
-    settings_to_display = {
-        'chunk_count': chunk_count,
-    }
-
+def apply_settings(chunk_count, chunk_count_initial, time_weight):
+    global params
+    params['chunk_count'] = int(chunk_count)
+    params['chunk_count_initial'] = int(chunk_count_initial)
+    params['time_weight'] = time_weight
+    settings_to_display = {k: params[k] for k in params if k in ['chunk_count', 'chunk_count_initial', 'time_weight']}
     yield f"The following settings are now active: {str(settings_to_display)}"
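Note that apply_settings is a generator: it yields its status string, which is how Gradio streams the confirmation into the output element. A minimal sketch of what a click on "Apply changes" produces, with hypothetical argument values:

# Hypothetical invocation, mimicking the button callback.
msg = next(apply_settings(chunk_count=5, chunk_count_initial=10, time_weight=0.5))
print(msg)
# The following settings are now active: {'chunk_count': 5, 'chunk_count_initial': 10, 'time_weight': 0.5}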
@@ -97,7 +97,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
     global chat_collector
 
     if state['mode'] == 'instruct':
-        results = collector.get_sorted(user_input, n_results=chunk_count)
+        results = collector.get_sorted(user_input, n_results=params['chunk_count'])
         additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
         user_input += additional_context
     else:
@@ -108,7 +108,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
             output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
         return output
 
-    if len(shared.history['internal']) > chunk_count and user_input != '':
+    if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
         chunks = []
         hist_size = len(shared.history['internal'])
         for i in range(hist_size-1):
@@ -117,7 +117,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
         add_chunks_to_collector(chunks, chat_collector)
         query = '\n'.join(shared.history['internal'][-1] + [user_input])
         try:
-            best_ids = chat_collector.get_ids_sorted(query, n_results=chunk_count)
+            best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
             additional_context = '\n'
             for id_ in best_ids:
                 if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
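Chat-history retrieval is now two-stage: chunk_count_initial candidates are fetched by similarity, re-ranked with the time weighting, and only the best chunk_count survive. For clarity, here is the same call with hypothetical literal values substituted for the params lookups (chat_collector and query come from the surrounding function):

best_ids = chat_collector.get_ids_sorted(
    query,
    n_results=5,      # params['chunk_count']: exchanges that end up in the prompt
    n_initial=10,     # params['chunk_count_initial']: candidate pool to re-rank; -1 = all chunks
    time_weight=0.5,  # 0 disables time weighting and the pool widening entirely
)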
@@ -151,7 +151,7 @@ def input_modifier(string):
         user_input = match.group(1).strip()
 
         # Get the most similar chunks
-        results = collector.get_sorted(user_input, n_results=chunk_count)
+        results = collector.get_sorted(user_input, n_results=params['chunk_count'])
 
         # Make the injection
         string = string.replace('<|injection-point|>', '\n'.join(results))
@@ -240,6 +240,10 @@ def ui():
 
     with gr.Tab("Generation settings"):
         chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
+        gr.Markdown('Time weighting (optional, used to make recently added chunks more likely to appear)')
+        time_weight = gr.Slider(0, 1, value=params['time_weight'], label='Time weight', info='Defines the strength of the time weighting. 0 = no time weighting.')
+        chunk_count_initial = gr.Number(value=params['chunk_count_initial'], label='Initial chunk count', info='The number of closest-matching chunks retrieved for time weight reordering in chat mode. This should be >= chunk count. -1 = All chunks are retrieved. Only used if time_weight > 0.')
         update_settings = gr.Button('Apply changes')
         chunk_len = gr.Number(value=params['chunk_length'], label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
@@ -250,4 +254,4 @@ def ui():
     update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False)
     update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False)
     update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False)
-    update_settings.click(apply_settings, [chunk_count], last_updated, show_progress=False)
+    update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)
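The click wiring passes the three components' values positionally into apply_settings, so the input list order must match the function signature. A minimal self-contained sketch of the same pattern, assuming only the gradio package (a hypothetical demo app, not part of the commit):

import gradio as gr

def apply_settings(chunk_count, chunk_count_initial, time_weight):
    # Generator callback, mirroring the extension's streaming status message.
    yield f"Active: chunk_count={int(chunk_count)}, chunk_count_initial={int(chunk_count_initial)}, time_weight={time_weight}"

with gr.Blocks() as demo:
    chunk_count = gr.Number(value=5, label='Chunk count')
    chunk_count_initial = gr.Number(value=10, label='Initial chunk count')
    time_weight = gr.Slider(0, 1, value=0, label='Time weight')
    update_settings = gr.Button('Apply changes')
    last_updated = gr.Markdown()
    # Component values are delivered to the callback in list order.
    update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)

demo.launch()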