From ee674afa50d9389659c683a2840bbd0853cc3777 Mon Sep 17 00:00:00 2001 From: Luis Lopez <29587742+toast22a@users.noreply.github.com> Date: Thu, 25 May 2023 21:22:45 +0800 Subject: [PATCH] Add superbooga time weighted history retrieval (#2080) --- extensions/superbooga/chromadb.py | 40 ++++++++++++++++++++++++------- extensions/superbooga/script.py | 30 +++++++++++++---------- 2 files changed, 49 insertions(+), 21 deletions(-) diff --git a/extensions/superbooga/chromadb.py b/extensions/superbooga/chromadb.py index 75efe70b..8120e798 100644 --- a/extensions/superbooga/chromadb.py +++ b/extensions/superbooga/chromadb.py @@ -47,34 +47,58 @@ class ChromaCollector(Collecter): self.ids = [f"id{i}" for i in range(len(texts))] self.collection.add(documents=texts, ids=self.ids) - def get_documents_and_ids(self, search_strings: list[str], n_results: int): + def get_documents_ids_distances(self, search_strings: list[str], n_results: int): n_results = min(len(self.ids), n_results) if n_results == 0: return [], [] - result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents']) + result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents', 'distances']) documents = result['documents'][0] ids = list(map(lambda x: int(x[2:]), result['ids'][0])) - return documents, ids + distances = result['distances'][0] + return documents, ids, distances # Get chunks by similarity def get(self, search_strings: list[str], n_results: int) -> list[str]: - documents, _ = self.get_documents_and_ids(search_strings, n_results) + documents, _, _ = self.get_documents_ids_distances(search_strings, n_results) return documents # Get ids by similarity def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: - _, ids = self.get_documents_and_ids(search_strings, n_results) + _, ids, _ = self.get_documents_ids_distances(search_strings, n_results) return ids # Get chunks by similarity and then sort by insertion order def get_sorted(self, search_strings: list[str], n_results: int) -> list[str]: - documents, ids = self.get_documents_and_ids(search_strings, n_results) + documents, ids, _ = self.get_documents_ids_distances(search_strings, n_results) return [x for _, x in sorted(zip(ids, documents))] + # Multiply distance by factor within [0, time_weight] where more recent is lower + def apply_time_weight_to_distances(self, ids: list[int], distances: list[float], time_weight: float = 1.0) -> list[float]: + if len(self.ids) <= 1: + return distances.copy() + + return [distance * (1 - _id / (len(self.ids) - 1) * time_weight) for _id, distance in zip(ids, distances)] + # Get ids by similarity and then sort by insertion order - def get_ids_sorted(self, search_strings: list[str], n_results: int) -> list[str]: - _, ids = self.get_documents_and_ids(search_strings, n_results) + def get_ids_sorted(self, search_strings: list[str], n_results: int, n_initial: int = None, time_weight: float = 1.0) -> list[str]: + do_time_weight = time_weight > 0 + if not (do_time_weight and n_initial is not None): + n_initial = n_results + elif n_initial == -1: + n_initial = len(self.ids) + + if n_initial < n_results: + raise ValueError(f"n_initial {n_initial} should be >= n_results {n_results}") + + _, ids, distances = self.get_documents_ids_distances(search_strings, n_initial) + if do_time_weight: + distances_w = self.apply_time_weight_to_distances(ids, distances, time_weight=time_weight) + results = zip(ids, distances, distances_w) + results = sorted(results, key=lambda x: x[2])[:n_results] + results = sorted(results, key=lambda x: x[0]) + ids = [x[0] for x in results] + return sorted(ids) def clear(self): diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py index f36f6b01..66c79b3d 100644 --- a/extensions/superbooga/script.py +++ b/extensions/superbooga/script.py @@ -12,6 +12,8 @@ from .download_urls import download_urls params = { 'chunk_count': 5, + 'chunk_count_initial': 10, + 'time_weight': 0, 'chunk_length': 700, 'chunk_separator': '', 'strong_cleanup': False, @@ -20,7 +22,6 @@ params = { collector = make_collector() chat_collector = make_collector() -chunk_count = 5 def feed_data_into_collector(corpus, chunk_len, chunk_sep): @@ -83,13 +84,12 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads) yield i -def apply_settings(_chunk_count): - global chunk_count - chunk_count = int(_chunk_count) - settings_to_display = { - 'chunk_count': chunk_count, - } - +def apply_settings(chunk_count, chunk_count_initial, time_weight): + global params + params['chunk_count'] = int(chunk_count) + params['chunk_count_initial'] = int(chunk_count_initial) + params['time_weight'] = time_weight + settings_to_display = {k: params[k] for k in params if k in ['chunk_count', 'chunk_count_initial', 'time_weight']} yield f"The following settings are now active: {str(settings_to_display)}" @@ -97,7 +97,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): global chat_collector if state['mode'] == 'instruct': - results = collector.get_sorted(user_input, n_results=chunk_count) + results = collector.get_sorted(user_input, n_results=params['chunk_count']) additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results) user_input += additional_context else: @@ -108,7 +108,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n" return output - if len(shared.history['internal']) > chunk_count and user_input != '': + if len(shared.history['internal']) > params['chunk_count'] and user_input != '': chunks = [] hist_size = len(shared.history['internal']) for i in range(hist_size-1): @@ -117,7 +117,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): add_chunks_to_collector(chunks, chat_collector) query = '\n'.join(shared.history['internal'][-1] + [user_input]) try: - best_ids = chat_collector.get_ids_sorted(query, n_results=chunk_count) + best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight']) additional_context = '\n' for id_ in best_ids: if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>': @@ -151,7 +151,7 @@ def input_modifier(string): user_input = match.group(1).strip() # Get the most similar chunks - results = collector.get_sorted(user_input, n_results=chunk_count) + results = collector.get_sorted(user_input, n_results=params['chunk_count']) # Make the injection string = string.replace('<|injection-point|>', '\n'.join(results)) @@ -240,6 +240,10 @@ def ui(): with gr.Tab("Generation settings"): chunk_count = gr.Number(value=params['chunk_count'], label='Chunk count', info='The number of closest-matching chunks to include in the prompt.') + gr.Markdown('Time weighting (optional, used in to make recently added chunks more likely to appear)') + time_weight = gr.Slider(0, 1, value=params['time_weight'], label='Time weight', info='Defines the strength of the time weighting. 0 = no time weighting.') + chunk_count_initial = gr.Number(value=params['chunk_count_initial'], label='Initial chunk count', info='The number of closest-matching chunks retrieved for time weight reordering in chat mode. This should be >= chunk count. -1 = All chunks are retrieved. Only used if time_weight > 0.') + update_settings = gr.Button('Apply changes') chunk_len = gr.Number(value=params['chunk_length'], label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".') @@ -250,4 +254,4 @@ def ui(): update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_sep], last_updated, show_progress=False) update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_sep, strong_cleanup, threads], last_updated, show_progress=False) update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_sep], last_updated, show_progress=False) - update_settings.click(apply_settings, [chunk_count], last_updated, show_progress=False) + update_settings.click(apply_settings, [chunk_count, chunk_count_initial, time_weight], last_updated, show_progress=False)