From 2681f6f64049c48f43d3371072fe7ee0e101f6dd Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:03:18 -0300 Subject: [PATCH] Make superbooga & superboogav2 functional again (#5656) --- extensions/superbooga/chromadb.py | 56 ++------ extensions/superbooga/requirements.txt | 2 +- extensions/superboogav2/api.py | 33 ++--- extensions/superboogav2/benchmark.py | 14 +- extensions/superboogav2/chat_handler.py | 31 ++-- extensions/superboogav2/chromadb.py | 142 ++++++++----------- extensions/superboogav2/data_preprocessor.py | 40 +++--- extensions/superboogav2/data_processor.py | 25 ++-- extensions/superboogav2/download_urls.py | 5 +- extensions/superboogav2/notebook_handler.py | 7 +- extensions/superboogav2/optimize.py | 26 ++-- extensions/superboogav2/parameters.py | 12 +- extensions/superboogav2/requirements.txt | 4 +- extensions/superboogav2/script.py | 46 +++--- extensions/superboogav2/utils.py | 3 +- 15 files changed, 189 insertions(+), 257 deletions(-) diff --git a/extensions/superbooga/chromadb.py b/extensions/superbooga/chromadb.py index 1fb7a718..b16158e1 100644 --- a/extensions/superbooga/chromadb.py +++ b/extensions/superbooga/chromadb.py @@ -1,43 +1,24 @@ +import random + import chromadb import posthog -import torch from chromadb.config import Settings -from sentence_transformers import SentenceTransformer +from chromadb.utils import embedding_functions -from modules.logging_colors import logger - -logger.info('Intercepting all calls to posthog :)') +# Intercept calls to posthog posthog.capture = lambda *args, **kwargs: None -class Collecter(): +embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2") + + +class ChromaCollector(): def __init__(self): - pass + name = ''.join(random.choice('ab') for _ in range(10)) - def add(self, texts: list[str]): - pass - - def get(self, search_strings: list[str], n_results: int) -> list[str]: - pass - - def clear(self): - pass - - -class Embedder(): - def __init__(self): - pass - - def embed(self, text: str) -> list[torch.Tensor]: - pass - - -class ChromaCollector(Collecter): - def __init__(self, embedder: Embedder): - super().__init__() + self.name = name self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False)) - self.embedder = embedder - self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed) + self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder) self.ids = [] def add(self, texts: list[str]): @@ -102,24 +83,15 @@ class ChromaCollector(Collecter): return sorted(ids) def clear(self): - self.collection.delete(ids=self.ids) self.ids = [] - - -class SentenceTransformerEmbedder(Embedder): - def __init__(self) -> None: - self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2") - self.embed = self.model.encode + self.chroma_client.delete_collection(name=self.name) + self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder) def make_collector(): - global embedder - return ChromaCollector(embedder) + return ChromaCollector() def add_chunks_to_collector(chunks, collector): collector.clear() collector.add(chunks) - - -embedder = SentenceTransformerEmbedder() diff --git a/extensions/superbooga/requirements.txt b/extensions/superbooga/requirements.txt index 73a60078..4b166568 100644 --- a/extensions/superbooga/requirements.txt +++ b/extensions/superbooga/requirements.txt @@ 
-1,5 +1,5 @@ beautifulsoup4==4.12.2 -chromadb==0.3.18 +chromadb==0.4.24 pandas==2.0.3 posthog==2.4.2 sentence_transformers==2.2.2 diff --git a/extensions/superboogav2/api.py b/extensions/superboogav2/api.py index 993e2b7d..552c1c2c 100644 --- a/extensions/superboogav2/api.py +++ b/extensions/superboogav2/api.py @@ -12,17 +12,16 @@ This module is responsible for the VectorDB API. It currently supports: import json from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from urllib.parse import urlparse, parse_qs from threading import Thread +from urllib.parse import parse_qs, urlparse +import extensions.superboogav2.parameters as parameters from modules import shared from modules.logging_colors import logger from .chromadb import ChromaCollector from .data_processor import process_and_add_to_collector -import extensions.superboogav2.parameters as parameters - class CustomThreadingHTTPServer(ThreadingHTTPServer): def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True): @@ -38,7 +37,6 @@ class Handler(BaseHTTPRequestHandler): self.collector = collector super().__init__(request, client_address, server) - def _send_412_error(self, message): self.send_response(412) self.send_header("Content-type", "application/json") @@ -46,7 +44,6 @@ class Handler(BaseHTTPRequestHandler): response = json.dumps({"error": message}) self.wfile.write(response.encode('utf-8')) - def _send_404_error(self): self.send_response(404) self.send_header("Content-type", "application/json") @@ -54,14 +51,12 @@ class Handler(BaseHTTPRequestHandler): response = json.dumps({"error": "Resource not found"}) self.wfile.write(response.encode('utf-8')) - def _send_400_error(self, error_message: str): self.send_response(400) self.send_header("Content-type", "application/json") self.end_headers() response = json.dumps({"error": error_message}) self.wfile.write(response.encode('utf-8')) - def _send_200_response(self, message: str): self.send_response(200) @@ -75,24 +70,21 @@ class Handler(BaseHTTPRequestHandler): self.wfile.write(response.encode('utf-8')) - def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str): if sort_param == parameters.SORT_DISTANCE: results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count) elif sort_param == parameters.SORT_ID: results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count) - else: # Default is dist + else: # Default is dist results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count) - + return { "results": results } - def do_GET(self): self._send_404_error() - def do_POST(self): try: content_length = int(self.headers['Content-Length']) @@ -107,7 +99,7 @@ class Handler(BaseHTTPRequestHandler): if corpus is None: self._send_412_error("Missing parameter 'corpus'") return - + clear_before_adding = body.get('clear_before_adding', False) metadata = body.get('metadata') process_and_add_to_collector(corpus, self.collector, clear_before_adding, metadata) @@ -118,7 +110,7 @@ class Handler(BaseHTTPRequestHandler): if corpus is None: self._send_412_error("Missing parameter 'metadata'") return - + self.collector.delete(ids_to_delete=None, where=metadata) self._send_200_response("Data successfully deleted") @@ -127,15 +119,15 @@ class Handler(BaseHTTPRequestHandler): if search_strings is None: self._send_412_error("Missing parameter 'search_strings'") return - + n_results = body.get('n_results') if n_results is None: 
n_results = parameters.get_chunk_count() - + max_token_count = body.get('max_token_count') if max_token_count is None: max_token_count = parameters.get_max_token_count() - + sort_param = query_params.get('sort', ['distance'])[0] results = self._handle_get(search_strings, n_results, max_token_count, sort_param) @@ -146,7 +138,6 @@ class Handler(BaseHTTPRequestHandler): except Exception as e: self._send_400_error(str(e)) - def do_DELETE(self): try: parsed_path = urlparse(self.path) @@ -161,12 +152,10 @@ class Handler(BaseHTTPRequestHandler): except Exception as e: self._send_400_error(str(e)) - def do_OPTIONS(self): self.send_response(200) self.end_headers() - def end_headers(self): self.send_header('Access-Control-Allow-Origin', '*') self.send_header('Access-Control-Allow-Methods', '*') @@ -197,11 +186,11 @@ class APIManager: def stop_server(self): if self.server is not None: - logger.info(f'Stopping chromaDB API.') + logger.info('Stopping chromaDB API.') self.server.shutdown() self.server.server_close() self.server = None self.is_running = False def is_server_running(self): - return self.is_running \ No newline at end of file + return self.is_running diff --git a/extensions/superboogav2/benchmark.py b/extensions/superboogav2/benchmark.py index 46475a08..5d9331a7 100644 --- a/extensions/superboogav2/benchmark.py +++ b/extensions/superboogav2/benchmark.py @@ -9,23 +9,23 @@ The benchmark function will return the score as an integer. import datetime import json import os - from pathlib import Path -from .data_processor import process_and_add_to_collector, preprocess_text +from .data_processor import preprocess_text, process_and_add_to_collector from .parameters import get_chunk_count, get_max_token_count from .utils import create_metadata_source + def benchmark(config_path, collector): # Get the current system date sysdate = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"benchmark_{sysdate}.txt" - + # Open the log file in append mode with open(filename, 'a') as log: with open(config_path, 'r') as f: data = json.load(f) - + total_points = 0 max_points = 0 @@ -45,7 +45,7 @@ def benchmark(config_path, collector): for question_group in item["questions"]: question_variants = question_group["question_variants"] criteria = question_group["criteria"] - + for q in question_variants: max_points += len(criteria) processed_text = preprocess_text(q) @@ -54,7 +54,7 @@ def benchmark(config_path, collector): results = collector.get_sorted_by_dist(processed_text, n_results=get_chunk_count(), max_token_count=get_max_token_count()) points = 0 - + for c in criteria: for p in results: if c in p: @@ -69,4 +69,4 @@ def benchmark(config_path, collector): print(f'##Total points:\n\n{total_points}/{max_points}', file=log) - return total_points, max_points \ No newline at end of file + return total_points, max_points diff --git a/extensions/superboogav2/chat_handler.py b/extensions/superboogav2/chat_handler.py index 419b9264..01ff5894 100644 --- a/extensions/superboogav2/chat_handler.py +++ b/extensions/superboogav2/chat_handler.py @@ -4,16 +4,17 @@ This module is responsible for modifying the chat prompt and history. 
import re import extensions.superboogav2.parameters as parameters - +from extensions.superboogav2.utils import ( + create_context_text, + create_metadata_source +) from modules import chat, shared -from modules.text_generation import get_encoded_length -from modules.logging_colors import logger from modules.chat import load_character_memoized -from extensions.superboogav2.utils import create_context_text, create_metadata_source +from modules.logging_colors import logger +from modules.text_generation import get_encoded_length -from .data_processor import process_and_add_to_collector from .chromadb import ChromaCollector - +from .data_processor import process_and_add_to_collector CHAT_METADATA = create_metadata_source('automatic-chat-insert') @@ -21,17 +22,17 @@ CHAT_METADATA = create_metadata_source('automatic-chat-insert') def _remove_tag_if_necessary(user_input: str): if not parameters.get_is_manual(): return user_input - + return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input) def _should_query(input: str): if not parameters.get_is_manual(): return True - + if re.search(r'^\s*!c|!c\s*$', input, re.MULTILINE): return True - + return False @@ -69,7 +70,7 @@ def _concatinate_history(history: dict, state: dict): if len(exchange) >= 2: full_history_text += _format_single_exchange(bot_name, exchange[1]) - return full_history_text[:-1] # Remove the last new line. + return full_history_text[:-1] # Remove the last new line. def _hijack_last(context_text: str, history: dict, max_len: int, state: dict): @@ -82,20 +83,20 @@ def _hijack_last(context_text: str, history: dict, max_len: int, state: dict): for i, messages in enumerate(reversed(history['internal'])): for j, message in enumerate(reversed(messages)): num_message_tokens = get_encoded_length(_format_single_exchange(names[j], message)) - + # TODO: This is an extremely naive solution. A more robust implementation must be made. 
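            # This scan walks the history newest-to-oldest while accumulating token counts;
            # replace_position keeps being overwritten for as long as the injected context plus the
            # already-counted newer messages still fit, so it ends up pointing at the oldest message
            # that can be overwritten within max_len (the [-i - 1][-j - 1] indexing below undoes the
            # reversed() iteration).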
if history_tokens + num_context_tokens <= max_len: # This message can be replaced replace_position = (i, j) - + history_tokens += num_message_tokens - + if replace_position is None: logger.warn("The provided context_text is too long to replace any message in the history.") else: # replace the message at replace_position with context_text i, j = replace_position - history['internal'][-i-1][-j-1] = context_text + history['internal'][-i - 1][-j - 1] = context_text def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs): @@ -120,5 +121,5 @@ def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector user_input = create_context_text(results) + user_input elif parameters.get_injection_strategy() == parameters.HIJACK_LAST_IN_CONTEXT: _hijack_last(create_context_text(results), kwargs['history'], state['truncation_length'], state) - + return chat.generate_chat_prompt(user_input, state, **kwargs) diff --git a/extensions/superboogav2/chromadb.py b/extensions/superboogav2/chromadb.py index 0da2d8f9..3381fb14 100644 --- a/extensions/superboogav2/chromadb.py +++ b/extensions/superboogav2/chromadb.py @@ -1,42 +1,23 @@ -import threading -import chromadb -import posthog -import torch import math +import random +import threading +import chromadb import numpy as np -import extensions.superboogav2.parameters as parameters - +import posthog from chromadb.config import Settings -from sentence_transformers import SentenceTransformer +from chromadb.utils import embedding_functions +import extensions.superboogav2.parameters as parameters from modules.logging_colors import logger -from modules.text_generation import encode, decode +from modules.text_generation import decode, encode -logger.debug('Intercepting all calls to posthog.') +# Intercept calls to posthog posthog.capture = lambda *args, **kwargs: None -class Collecter(): - def __init__(self): - pass +embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2") - def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int]): - pass - - def get(self, search_strings: list[str], n_results: int) -> list[str]: - pass - - def clear(self): - pass - - -class Embedder(): - def __init__(self): - pass - - def embed(self, text: str) -> list[torch.Tensor]: - pass class Info: def __init__(self, start_index, text_with_context, distance, id): @@ -58,7 +39,7 @@ class Info: elif parameters.get_new_dist_strategy() == parameters.DIST_ARITHMETIC_STRATEGY: # Arithmetic mean return (self.distance + other_info.distance) / 2 - else: # Min is default + else: # Min is default return min(self.distance, other_info.distance) def merge_with(self, other_info): @@ -66,7 +47,7 @@ class Info: s2 = other_info.text_with_context s1_start = self.start_index s2_start = other_info.start_index - + new_dist = self.calculate_distance(other_info) if self.should_merge(s1, s2, s1_start, s2_start): @@ -84,55 +65,58 @@ class Info: return Info(s2_start, s2 + s1[overlap:], new_dist, other_info.id) return None - + @staticmethod def should_merge(s1, s2, s1_start, s2_start): # Check if s1 and s2 are adjacent or overlapping s1_end = s1_start + len(s1) s2_end = s2_start + len(s2) - + return not (s1_end < s2_start or s2_end < s1_start) -class ChromaCollector(Collecter): - def __init__(self, embedder: Embedder): - super().__init__() + +class ChromaCollector(): + def __init__(self): + name = ''.join(random.choice('ab') for _ in range(10)) + + self.name = name 
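        # Each collector now owns a randomly named collection (the old code reused a single
        # "context" collection), presumably to avoid name clashes when more than one collector
        # exists; clear() below simply drops and recreates the collection under this name instead
        # of resetting the whole client.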
self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False)) - self.embedder = embedder - self.collection = self.chroma_client.create_collection(name="context", embedding_function=self.embedder.embed) + self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder) + self.ids = [] self.id_to_info = {} self.embeddings_cache = {} - self.lock = threading.Lock() # Locking so the server doesn't break. + self.lock = threading.Lock() # Locking so the server doesn't break. def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None): with self.lock: assert metadatas is None or len(metadatas) == len(texts), "metadatas must be None or have the same length as texts" - - if len(texts) == 0: + + if len(texts) == 0: return new_ids = self._get_new_ids(len(texts)) (existing_texts, existing_embeddings, existing_ids, existing_metas), \ - (non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas) + (non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas) # If there are any already existing texts, add them all at once. if existing_texts: logger.info(f'Adding {len(existing_embeddings)} cached embeddings.') args = {'embeddings': existing_embeddings, 'documents': existing_texts, 'ids': existing_ids} - if metadatas is not None: + if metadatas is not None: args['metadatas'] = existing_metas self.collection.add(**args) # If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead. if non_existing_texts: - non_existing_embeddings = self.embedder.embed(non_existing_texts).tolist() + non_existing_embeddings = embedder(non_existing_texts) for text, embedding in zip(non_existing_texts, non_existing_embeddings): self.embeddings_cache[text] = embedding logger.info(f'Adding {len(non_existing_embeddings)} new embeddings.') args = {'embeddings': non_existing_embeddings, 'documents': non_existing_texts, 'ids': non_existing_ids} - if metadatas is not None: + if metadatas is not None: args['metadatas'] = non_existing_metas self.collection.add(**args) @@ -145,7 +129,6 @@ class ChromaCollector(Collecter): self.id_to_info.update(new_info) self.ids.extend(new_ids) - def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]): existing_texts, non_existing_texts = [], [] existing_embeddings = [] @@ -169,7 +152,6 @@ class ChromaCollector(Collecter): return (existing_texts, existing_embeddings, existing_ids, existing_metas), \ (non_existing_texts, non_existing_ids, non_existing_metas) - def _get_new_ids(self, num_new_ids: int): if self.ids: max_existing_id = max(int(id_) for id_ in self.ids) @@ -178,7 +160,6 @@ class ChromaCollector(Collecter): return [str(i + max_existing_id + 1) for i in range(num_new_ids)] - def _find_min_max_start_index(self): max_index, min_index = 0, float('inf') for _, val in self.id_to_info.items(): @@ -188,34 +169,34 @@ class ChromaCollector(Collecter): min_index = val['start_index'] return min_index, max_index - - # NB: Does not make sense to weigh excerpts from different documents. + # NB: Does not make sense to weigh excerpts from different documents. # But let's say that's the user's problem. Perfect world scenario: # Apply time weighing to different documents. For each document, then, add # separate time weighing. 
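    # Worked illustration of the weighting below (values assumed, not from the patch): with
    # document_len=8, time_steepness=1.0 and time_power=0.7, the reversed weights come out to
    # roughly [1.00, 1.00, 0.99, 0.86, 0.44, 0.31, 0.30, 0.30]. A chunk whose start_index is near
    # the beginning of the document keeps its distance (x1.0), while one near the end has it
    # multiplied by about 1 - time_power = 0.3 -- a smaller distance, i.e. recent chunks rank closer.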
+ def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float): - sigmoid = lambda x: 1 / (1 + np.exp(-x)) - + def sigmoid(x): + return 1 / (1 + np.exp(-x)) + weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len)) # Scale to [0,time_power] and shift it up to [1-time_power, 1] - weights = weights - min(weights) + weights = weights - min(weights) weights = weights * (time_power / max(weights)) - weights = weights + (1 - time_power) + weights = weights + (1 - time_power) # Reverse the weights - weights = weights[::-1] + weights = weights[::-1] for info in infos: index = info.start_index info.distance *= weights[index] - def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float): # Ensure there are infos to filter if not infos: return [] - + # Find info with minimum distance min_info = min(infos, key=lambda x: x.distance) @@ -231,7 +212,6 @@ class ChromaCollector(Collecter): return filtered_infos - def _merge_infos(self, infos: list[Info]): merged_infos = [] current_info = infos[0] @@ -247,8 +227,8 @@ class ChromaCollector(Collecter): merged_infos.append(current_info) return merged_infos - # Main function for retrieving chunks by distance. It performs merging, time weighing, and mean filtering. + def _get_documents_ids_distances(self, search_strings: list[str], n_results: int): n_results = min(len(self.ids), n_results) if n_results == 0: @@ -262,11 +242,11 @@ class ChromaCollector(Collecter): for search_string in search_strings: result = self.collection.query(query_texts=search_string, n_results=math.ceil(n_results / len(search_strings)), include=['distances']) - curr_infos = [Info(start_index=self.id_to_info[id]['start_index'], - text_with_context=self.id_to_info[id]['text_with_context'], - distance=distance, id=id) + curr_infos = [Info(start_index=self.id_to_info[id]['start_index'], + text_with_context=self.id_to_info[id]['text_with_context'], + distance=distance, id=id) for id, distance in zip(result['ids'][0], result['distances'][0])] - + self._apply_sigmoid_time_weighing(infos=curr_infos, document_len=max_start_index - min_start_index + 1, time_steepness=parameters.get_time_steepness(), time_power=parameters.get_time_power()) curr_infos = self._filter_outliers_by_median_distance(curr_infos, parameters.get_significant_level()) infos.extend(curr_infos) @@ -279,23 +259,23 @@ class ChromaCollector(Collecter): distances = [inf.distance for inf in infos] return texts_with_context, ids, distances - # Get chunks by similarity + def get(self, search_strings: list[str], n_results: int) -> list[str]: with self.lock: documents, _, _ = self._get_documents_ids_distances(search_strings, n_results) return documents - # Get ids by similarity + def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: with self.lock: _, ids, _ = self._get_documents_ids_distances(search_strings, n_results) return ids - - + # Cutoff token count + def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int): # TODO: Move to caller; We add delimiters there which might go over the limit. current_token_count = 0 @@ -308,7 +288,7 @@ class ChromaCollector(Collecter): # If adding this document would exceed the max token count, # truncate the document to fit within the limit. 
remaining_tokens = max_token_count - current_token_count - + truncated_doc = decode(doc_tokens[:remaining_tokens], skip_special_tokens=True) return_documents.append(truncated_doc) break @@ -317,29 +297,28 @@ class ChromaCollector(Collecter): current_token_count += doc_token_count return return_documents - # Get chunks by similarity and then sort by ids + def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]: with self.lock: documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results) sorted_docs = [x for _, x in sorted(zip(ids, documents))] return self._get_documents_up_to_token_count(sorted_docs, max_token_count) - - + # Get chunks by similarity and then sort by distance (lowest distance is last). + def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]: with self.lock: documents, _, distances = self._get_documents_ids_distances(search_strings, n_results) - sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])] # sorted lowest -> highest - + sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])] # sorted lowest -> highest + # If a document is truncated or competely skipped, it would be with high distance. return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count) - return_documents.reverse() # highest -> lowest + return_documents.reverse() # highest -> lowest return return_documents - def delete(self, ids_to_delete: list[str], where: dict): with self.lock: @@ -354,23 +333,16 @@ class ChromaCollector(Collecter): logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.') - def clear(self): with self.lock: self.chroma_client.reset() - self.collection = self.chroma_client.create_collection("context", embedding_function=self.embedder.embed) + self.ids = [] - self.id_to_info = {} + self.chroma_client.delete_collection(name=self.name) + self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder) logger.info('Successfully cleared all records and reset chromaDB.') -class SentenceTransformerEmbedder(Embedder): - def __init__(self) -> None: - logger.debug('Creating Sentence Embedder...') - self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2") - self.embed = self.model.encode - - def make_collector(): - return ChromaCollector(SentenceTransformerEmbedder()) \ No newline at end of file + return ChromaCollector() diff --git a/extensions/superboogav2/data_preprocessor.py b/extensions/superboogav2/data_preprocessor.py index cbd14b6b..1f354cf2 100644 --- a/extensions/superboogav2/data_preprocessor.py +++ b/extensions/superboogav2/data_preprocessor.py @@ -11,32 +11,29 @@ This module contains utils for preprocessing the text before converting it to em * removing specific parts of speech (adverbs and interjections) - TextSummarizer extracts the most important sentences from a long string using text-ranking. """ -import pytextrank -import string -import spacy import math -import nltk import re +import string +import nltk +import spacy from nltk.corpus import stopwords from nltk.stem import WordNetLemmatizer from num2words import num2words class TextPreprocessorBuilder: - # Define class variables as None initially + # Define class variables as None initially _stop_words = set(stopwords.words('english')) _lemmatizer = WordNetLemmatizer() - + # Some of the functions are expensive. We cache the results. 
_lemmatizer_cache = {} _pos_remove_cache = {} - def __init__(self, text: str): self.text = text - def to_lower(self): # Match both words and non-word characters tokens = re.findall(r'\b\w+\b|\W+', self.text) @@ -49,7 +46,6 @@ class TextPreprocessorBuilder: self.text = "".join(tokens) return self - def num_to_word(self, min_len: int = 1): # Match both words and non-word characters tokens = re.findall(r'\b\w+\b|\W+', self.text) @@ -58,11 +54,10 @@ class TextPreprocessorBuilder: if token.isdigit() and len(token) >= min_len: # This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers) # 740700 will become "seven hundred and forty thousand seven hundred". - tokens[i] = num2words(int(token)).replace(",","") # Remove commas from num2words. + tokens[i] = num2words(int(token)).replace(",", "") # Remove commas from num2words. self.text = "".join(tokens) return self - def num_to_char_long(self, min_len: int = 1): # Match both words and non-word characters tokens = re.findall(r'\b\w+\b|\W+', self.text) @@ -71,11 +66,13 @@ class TextPreprocessorBuilder: if token.isdigit() and len(token) >= min_len: # This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers) # 740700 will become HHHHHHEEEEEAAAAHHHAAA - convert_token = lambda token: ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1] + def convert_token(token): + return ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1] + tokens[i] = convert_token(tokens[i]) self.text = "".join(tokens) return self - + def num_to_char(self, min_len: int = 1): # Match both words and non-word characters tokens = re.findall(r'\b\w+\b|\W+', self.text) @@ -87,15 +84,15 @@ class TextPreprocessorBuilder: tokens[i] = ''.join(chr(int(digit) + 65) for digit in token) self.text = "".join(tokens) return self - + def merge_spaces(self): self.text = re.sub(' +', ' ', self.text) return self - + def strip(self): self.text = self.text.strip() return self - + def remove_punctuation(self): self.text = self.text.translate(str.maketrans('', '', string.punctuation)) return self @@ -103,7 +100,7 @@ class TextPreprocessorBuilder: def remove_stopwords(self): self.text = "".join([word for word in re.findall(r'\b\w+\b|\W+', self.text) if word not in TextPreprocessorBuilder._stop_words]) return self - + def remove_specific_pos(self): """ In the English language, adverbs and interjections rarely provide meaningul information. @@ -140,7 +137,7 @@ class TextPreprocessorBuilder: if processed_text: self.text = processed_text return self - + new_text = "".join([TextPreprocessorBuilder._lemmatizer.lemmatize(word) for word in re.findall(r'\b\w+\b|\W+', self.text)]) TextPreprocessorBuilder._lemmatizer_cache[self.text] = new_text self.text = new_text @@ -150,6 +147,7 @@ class TextPreprocessorBuilder: def build(self): return self.text + class TextSummarizer: _nlp_pipeline = None _cache = {} @@ -165,7 +163,7 @@ class TextSummarizer: @staticmethod def process_long_text(text: str, min_num_sent: int) -> list[str]: """ - This function applies a text summarization process on a given text string, extracting + This function applies a text summarization process on a given text string, extracting the most important sentences based on the principle that 20% of the content is responsible for 80% of the meaning (the Pareto Principle). 
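A minimal sketch of the textrank extraction this docstring describes (an illustration, not the module's own implementation); it assumes spaCy with the en_core_web_sm model plus pytextrank are installed, and the 20% ratio simply mirrors the Pareto wording above:

import math

import pytextrank  # noqa: F401  (registers the "textrank" spaCy component)
import spacy

nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("textrank")


def top_sentences(text: str, min_num_sent: int) -> list[str]:
    # Keep roughly the top 20% of ranked sentences, but never fewer than min_num_sent.
    doc = nlp(text)
    num_sent = max(min_num_sent, math.ceil(len(list(doc.sents)) * 0.2))
    return [sent.text for sent in doc._.textrank.summary(limit_sentences=num_sent, preserve_order=True)]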
@@ -193,7 +191,7 @@ class TextSummarizer: else: result = [text] - + # Store the result in cache before returning it TextSummarizer._cache[cache_key] = result - return result \ No newline at end of file + return result diff --git a/extensions/superboogav2/data_processor.py b/extensions/superboogav2/data_processor.py index f019f427..0a96d4a4 100644 --- a/extensions/superboogav2/data_processor.py +++ b/extensions/superboogav2/data_processor.py @@ -1,16 +1,17 @@ """ -This module is responsible for processing the corpus and feeding it into chromaDB. It will receive a corpus of text. +This module is responsible for processing the corpus and feeding it into chromaDB. It will receive a corpus of text. It will then split it into chunks of specified length. For each of those chunks, it will append surrounding context. It will only include full words. """ -import re import bisect +import re import extensions.superboogav2.parameters as parameters -from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer from .chromadb import ChromaCollector +from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer + def preprocess_text_no_summary(text) -> str: builder = TextPreprocessorBuilder(text) @@ -42,7 +43,7 @@ def preprocess_text_no_summary(text) -> str: builder.num_to_char(parameters.get_min_num_length()) elif parameters.get_num_conversion_strategy() == parameters.NUM_TO_CHAR_LONG_METHOD: builder.num_to_char_long(parameters.get_min_num_length()) - + return builder.build() @@ -53,10 +54,10 @@ def preprocess_text(text) -> list[str]: def _create_chunks_with_context(corpus, chunk_len, context_left, context_right): """ - This function takes a corpus of text and splits it into chunks of a specified length, - then adds a specified amount of context to each chunk. The context is added by first - going backwards from the start of the chunk and then going forwards from the end of the - chunk, ensuring that the context includes only whole words and that the total context length + This function takes a corpus of text and splits it into chunks of a specified length, + then adds a specified amount of context to each chunk. The context is added by first + going backwards from the start of the chunk and then going forwards from the end of the + chunk, ensuring that the context includes only whole words and that the total context length does not exceed the specified limit. This function uses binary search for efficiency. Returns: @@ -102,7 +103,7 @@ def _create_chunks_with_context(corpus, chunk_len, context_left, context_right): # Combine all the words in the context range (before, chunk, and after) chunk_with_context = ''.join(words[context_start_index:context_end_index]) chunks_with_context.append(chunk_with_context) - + # Determine the start index of the chunk with context chunk_with_context_start_index = word_start_indices[context_start_index] chunk_with_context_start_indices.append(chunk_with_context_start_index) @@ -125,9 +126,9 @@ def _clear_chunks(data_chunks, data_chunks_with_context, data_chunk_starting_ind seen_chunk_start = seen_chunks.get(chunk) if seen_chunk_start: # If we've already seen this exact chunk, and the context around it it very close to the seen chunk, then skip it. 
- if abs(seen_chunk_start-index) < parameters.get_delta_start(): + if abs(seen_chunk_start - index) < parameters.get_delta_start(): continue - + distinct_data_chunks.append(chunk) distinct_data_chunks_with_context.append(context) distinct_data_chunk_starting_indices.append(index) @@ -206,4 +207,4 @@ def process_and_add_to_collector(corpus: str, collector: ChromaCollector, clear_ if clear_collector_before_adding: collector.clear() - collector.add(data_chunks, data_chunks_with_context, data_chunk_starting_indices, [metadata]*len(data_chunks) if metadata is not None else None) \ No newline at end of file + collector.add(data_chunks, data_chunks_with_context, data_chunk_starting_indices, [metadata] * len(data_chunks) if metadata is not None else None) diff --git a/extensions/superboogav2/download_urls.py b/extensions/superboogav2/download_urls.py index ad2726b5..5b5a2e17 100644 --- a/extensions/superboogav2/download_urls.py +++ b/extensions/superboogav2/download_urls.py @@ -1,7 +1,7 @@ import concurrent.futures -import requests import re +import requests from bs4 import BeautifulSoup import extensions.superboogav2.parameters as parameters @@ -9,6 +9,7 @@ import extensions.superboogav2.parameters as parameters from .data_processor import process_and_add_to_collector from .utils import create_metadata_source + def _download_single(url): response = requests.get(url, timeout=5) if response.status_code == 200: @@ -62,4 +63,4 @@ def feed_url_into_collector(urls, collector): text = '\n'.join([s.strip() for s in strings]) all_text += text - process_and_add_to_collector(all_text, collector, False, create_metadata_source('url-download')) \ No newline at end of file + process_and_add_to_collector(all_text, collector, False, create_metadata_source('url-download')) diff --git a/extensions/superboogav2/notebook_handler.py b/extensions/superboogav2/notebook_handler.py index 7b864349..d07a2098 100644 --- a/extensions/superboogav2/notebook_handler.py +++ b/extensions/superboogav2/notebook_handler.py @@ -4,13 +4,12 @@ This module is responsible for handling and modifying the notebook text. import re import extensions.superboogav2.parameters as parameters - -from modules import shared -from modules.logging_colors import logger from extensions.superboogav2.utils import create_context_text +from modules.logging_colors import logger from .data_processor import preprocess_text + def _remove_special_tokens(string): pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)' return re.sub(pattern, '', string) @@ -37,4 +36,4 @@ def input_modifier_internal(string, collector, is_chat): # Make the injection string = string.replace('<|injection-point|>', create_context_text(results)) - return _remove_special_tokens(string) \ No newline at end of file + return _remove_special_tokens(string) diff --git a/extensions/superboogav2/optimize.py b/extensions/superboogav2/optimize.py index acebf212..ebdd03c6 100644 --- a/extensions/superboogav2/optimize.py +++ b/extensions/superboogav2/optimize.py @@ -3,22 +3,24 @@ This module implements a hyperparameter optimization routine for the embedding a Each run, the optimizer will set the default values inside the hyperparameters. At the end, it will output the best ones it has found. 
""" -import re +import hashlib import json -import optuna +import logging +import re + import gradio as gr import numpy as np -import logging -import hashlib -logging.getLogger('optuna').setLevel(logging.WARNING) +import optuna -import extensions.superboogav2.parameters as parameters +logging.getLogger('optuna').setLevel(logging.WARNING) from pathlib import Path +import extensions.superboogav2.parameters as parameters +from modules.logging_colors import logger + from .benchmark import benchmark from .parameters import Parameters -from modules.logging_colors import logger # Format the parameters into markdown format. @@ -28,7 +30,7 @@ def _markdown_hyperparams(): # Escape any markdown syntax param_name = re.sub(r"([_*\[\]()~`>#+-.!])", r"\\\1", param_name) param_value_default = re.sub(r"([_*\[\]()~`>#+-.!])", r"\\\1", str(param_value['default'])) if param_value['default'] else ' ' - + res.append('* {}: **{}**'.format(param_name, param_value_default)) return '\n'.join(res) @@ -49,13 +51,13 @@ def _convert_np_types(params): # Set the default values for the hyperparameters. def _set_hyperparameters(params): for param_name, param_value in params.items(): - if param_name in Parameters.getInstance().hyperparameters: + if param_name in Parameters.getInstance().hyperparameters: Parameters.getInstance().hyperparameters[param_name]['default'] = param_value # Check if the parameter is for optimization. def _is_optimization_param(val): - is_opt = val.get('should_optimize', False) # Either does not exist or is false + is_opt = val.get('should_optimize', False) # Either does not exist or is false return is_opt @@ -67,7 +69,7 @@ def _get_params_hash(params): def optimize(collector, progress=gr.Progress()): # Inform the user that something is happening. - progress(0, desc=f'Setting Up...') + progress(0, desc='Setting Up...') # Track the current step current_step = 0 @@ -132,4 +134,4 @@ def optimize(collector, progress=gr.Progress()): with open('best_params.json', 'w') as fp: json.dump(_convert_np_types(best_params), fp, indent=4) - return str_result \ No newline at end of file + return str_result diff --git a/extensions/superboogav2/parameters.py b/extensions/superboogav2/parameters.py index 1cada46a..8bb2d1a6 100644 --- a/extensions/superboogav2/parameters.py +++ b/extensions/superboogav2/parameters.py @@ -1,18 +1,16 @@ """ -This module provides a singleton class `Parameters` that is used to manage all hyperparameters for the embedding application. +This module provides a singleton class `Parameters` that is used to manage all hyperparameters for the embedding application. It expects a JSON file in `extensions/superboogav2/config.json`. -Each element in the JSON must have a `default` value which will be used for the current run. Elements can have `categories`. -These categories define the range in which the optimizer will search. If the element is tagged with `"should_optimize": false`, +Each element in the JSON must have a `default` value which will be used for the current run. Elements can have `categories`. +These categories define the range in which the optimizer will search. If the element is tagged with `"should_optimize": false`, then the optimizer will only ever use the default value. 
""" +import json from pathlib import Path -import json - from modules.logging_colors import logger - NUM_TO_WORD_METHOD = 'Number to Word' NUM_TO_CHAR_METHOD = 'Number to Char' NUM_TO_CHAR_LONG_METHOD = 'Number to Multi-Char' @@ -366,4 +364,4 @@ def set_api_port(value: int): def set_api_on(value: bool): - Parameters.getInstance().hyperparameters['api_on']['default'] = value \ No newline at end of file + Parameters.getInstance().hyperparameters['api_on']['default'] = value diff --git a/extensions/superboogav2/requirements.txt b/extensions/superboogav2/requirements.txt index 748bacf1..d9031167 100644 --- a/extensions/superboogav2/requirements.txt +++ b/extensions/superboogav2/requirements.txt @@ -1,5 +1,5 @@ beautifulsoup4==4.12.2 -chromadb==0.3.18 +chromadb==0.4.24 lxml optuna pandas==2.0.3 @@ -7,4 +7,4 @@ posthog==2.4.2 sentence_transformers==2.2.2 spacy pytextrank -num2words \ No newline at end of file +num2words diff --git a/extensions/superboogav2/script.py b/extensions/superboogav2/script.py index 66f56e29..77c5cced 100644 --- a/extensions/superboogav2/script.py +++ b/extensions/superboogav2/script.py @@ -7,28 +7,29 @@ from pathlib import Path # Point to where nltk will find the required data. os.environ['NLTK_DATA'] = str(Path("extensions/superboogav2/nltk_data").resolve()) -import textwrap import codecs +import textwrap + import gradio as gr import extensions.superboogav2.parameters as parameters - -from modules.logging_colors import logger from modules import shared +from modules.logging_colors import logger -from .utils import create_metadata_source -from .chromadb import make_collector -from .download_urls import feed_url_into_collector -from .data_processor import process_and_add_to_collector -from .benchmark import benchmark -from .optimize import optimize -from .notebook_handler import input_modifier_internal -from .chat_handler import custom_generate_chat_prompt_internal from .api import APIManager +from .benchmark import benchmark +from .chat_handler import custom_generate_chat_prompt_internal +from .chromadb import make_collector +from .data_processor import process_and_add_to_collector +from .download_urls import feed_url_into_collector +from .notebook_handler import input_modifier_internal +from .optimize import optimize +from .utils import create_metadata_source collector = None api_manager = None + def setup(): global collector global api_manager @@ -38,6 +39,7 @@ def setup(): if parameters.get_api_on(): api_manager.start_server(parameters.get_api_port()) + def _feed_data_into_collector(corpus): yield '### Processing data...' 
process_and_add_to_collector(corpus, collector, False, create_metadata_source('direct-text')) @@ -87,7 +89,7 @@ def _get_optimizable_settings() -> list: preprocess_pipeline.append('Merge Spaces') if parameters.should_strip(): preprocess_pipeline.append('Strip Edges') - + return [ parameters.get_time_power(), parameters.get_time_steepness(), @@ -104,8 +106,8 @@ def _get_optimizable_settings() -> list: ] -def _apply_settings(optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion, - preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count, +def _apply_settings(optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion, + preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count, chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup): logger.debug('Applying settings.') @@ -240,7 +242,7 @@ def ui(): with gr.Tab("File input"): file_input = gr.File(label='Input file', type='binary') update_file = gr.Button('Load data') - + with gr.Tab("Settings"): with gr.Accordion("Processing settings", open=True): chunk_len = gr.Textbox(value=parameters.get_chunk_len(), label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".') @@ -305,19 +307,16 @@ def ui(): optimize_button = gr.Button('Optimize') optimization_steps = gr.Number(value=parameters.get_optimization_steps(), label='Optimization Steps', info='For how many steps to optimize.', interactive=True) - clear_button = gr.Button('❌ Clear Data') - with gr.Column(): last_updated = gr.Markdown() - all_params = [optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion, - preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count, + all_params = [optimization_steps, time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion, + preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count, chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup] - optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion, - preprocess_pipeline, chunk_count, context_len, chunk_len] - + optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion, + preprocess_pipeline, chunk_count, context_len, chunk_len] update_data.click(_feed_data_into_collector, [data_input], last_updated, show_progress=False) update_url.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False) @@ -326,7 +325,6 @@ def ui(): optimize_button.click(_begin_optimization, [], [last_updated] + optimizable_params, show_progress=True) clear_button.click(_clear_data, [], last_updated, show_progress=False) - optimization_steps.input(fn=_apply_settings, inputs=all_params, show_progress=False) time_power.input(fn=_apply_settings, inputs=all_params, show_progress=False) 
time_steepness.input(fn=_apply_settings, inputs=all_params, show_progress=False) @@ -352,4 +350,4 @@ def ui(): chunk_regex.input(fn=_apply_settings, inputs=all_params, show_progress=False) chunk_len.input(fn=_apply_settings, inputs=all_params, show_progress=False) threads.input(fn=_apply_settings, inputs=all_params, show_progress=False) - strong_cleanup.input(fn=_apply_settings, inputs=all_params, show_progress=False) \ No newline at end of file + strong_cleanup.input(fn=_apply_settings, inputs=all_params, show_progress=False) diff --git a/extensions/superboogav2/utils.py b/extensions/superboogav2/utils.py index 89b367ea..df84650b 100644 --- a/extensions/superboogav2/utils.py +++ b/extensions/superboogav2/utils.py @@ -4,6 +4,7 @@ This module contains common functions across multiple other modules. import extensions.superboogav2.parameters as parameters + # Create the context using the prefix + data_separator + postfix from parameters. def create_context_text(results): context = parameters.get_prefix() + parameters.get_data_separator().join(results) + parameters.get_postfix() @@ -13,4 +14,4 @@ def create_context_text(results): # Create metadata with the specified source def create_metadata_source(source: str): - return {'source': source} \ No newline at end of file + return {'source': source}
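Below is a minimal, self-contained sketch of the chromadb 0.4 pattern this patch migrates both extensions to; the collection name, documents and query strings are illustrative, while the client, embedder and collection calls mirror the ones used above:

import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

# Same embedding function the patch wires into both superbooga versions.
embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")

client = chromadb.Client(Settings(anonymized_telemetry=False))
collection = client.create_collection(name="demo", embedding_function=embedder)

# Chunks are stored under string ids, optionally with metadata, as in ChromaCollector.add().
collection.add(documents=["first chunk of text", "second chunk of text"], ids=["0", "1"])

# Retrieval by semantic distance, as in _get_documents_ids_distances().
result = collection.query(query_texts=["text chunk"], n_results=2, include=["documents", "distances"])
print(result["documents"][0], result["distances"][0])

# clear() in the patch drops and recreates the collection rather than calling client.reset().
client.delete_collection(name="demo")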