Make superbooga & superboogav2 functional again (#5656)

oobabooga 2024-03-07 15:03:18 -03:00 committed by GitHub
parent bae14c8f13
commit 2681f6f640
15 changed files with 189 additions and 257 deletions

extensions/superbooga/chromadb.py

@@ -1,43 +1,24 @@
+import random
+
 import chromadb
 import posthog
-import torch
 from chromadb.config import Settings
-from sentence_transformers import SentenceTransformer
+from chromadb.utils import embedding_functions

-from modules.logging_colors import logger
-
-logger.info('Intercepting all calls to posthog :)')
+# Intercept calls to posthog
 posthog.capture = lambda *args, **kwargs: None

+embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")

-class Collecter():
-    def __init__(self):
-        pass
-
-    def add(self, texts: list[str]):
-        pass
-
-    def get(self, search_strings: list[str], n_results: int) -> list[str]:
-        pass
-
-    def clear(self):
-        pass
-
-
-class Embedder():
-    def __init__(self):
-        pass
-
-    def embed(self, text: str) -> list[torch.Tensor]:
-        pass
-
-
-class ChromaCollector(Collecter):
-    def __init__(self, embedder: Embedder):
-        super().__init__()
+
+class ChromaCollector():
+    def __init__(self):
+        name = ''.join(random.choice('ab') for _ in range(10))
+
+        self.name = name
         self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
-        self.embedder = embedder
-        self.collection = self.chroma_client.create_collection(name="context", embedding_function=embedder.embed)
+        self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
         self.ids = []

     def add(self, texts: list[str]):
@@ -102,24 +83,15 @@ class ChromaCollector(Collecter):
         return sorted(ids)

     def clear(self):
-        self.collection.delete(ids=self.ids)
         self.ids = []
-
-
-class SentenceTransformerEmbedder(Embedder):
-    def __init__(self) -> None:
-        self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
-        self.embed = self.model.encode
+        self.chroma_client.delete_collection(name=self.name)
+        self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)


 def make_collector():
-    global embedder
-    return ChromaCollector(embedder)
+    return ChromaCollector()


 def add_chunks_to_collector(chunks, collector):
     collector.clear()
     collector.add(chunks)
-
-
-embedder = SentenceTransformerEmbedder()
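For orientation, here is a minimal, illustrative sketch of the chromadb 0.4.x pattern the rewritten collector relies on. It is not part of the commit; the collection name and sample texts are invented, and the model name simply mirrors the one used above.

    import chromadb
    from chromadb.config import Settings
    from chromadb.utils import embedding_functions

    # The embedding function object is handed straight to create_collection and is
    # also callable on a list of texts, returning one embedding per text.
    embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")

    client = chromadb.Client(Settings(anonymized_telemetry=False))
    collection = client.create_collection(name="demo", embedding_function=embedder)

    collection.add(documents=["first chunk", "second chunk"], ids=["0", "1"])
    print(collection.query(query_texts=["a chunk"], n_results=1)["documents"])

    # clear() above drops and recreates the collection rather than calling
    # client.reset(), which chromadb 0.4.x refuses unless Settings(allow_reset=True) is set.
    client.delete_collection(name="demo")
    collection = client.create_collection(name="demo", embedding_function=embedder)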

extensions/superbooga/requirements.txt

@@ -1,5 +1,5 @@
 beautifulsoup4==4.12.2
-chromadb==0.3.18
+chromadb==0.4.24
 pandas==2.0.3
 posthog==2.4.2
 sentence_transformers==2.2.2

extensions/superboogav2/api.py

@@ -12,17 +12,16 @@ This module is responsible for the VectorDB API. It currently supports:
 import json
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
-from urllib.parse import urlparse, parse_qs
 from threading import Thread
+from urllib.parse import parse_qs, urlparse

+import extensions.superboogav2.parameters as parameters
 from modules import shared
 from modules.logging_colors import logger

 from .chromadb import ChromaCollector
 from .data_processor import process_and_add_to_collector

-import extensions.superboogav2.parameters as parameters
-

 class CustomThreadingHTTPServer(ThreadingHTTPServer):
     def __init__(self, server_address, RequestHandlerClass, collector: ChromaCollector, bind_and_activate=True):
@@ -38,7 +37,6 @@ class Handler(BaseHTTPRequestHandler):
         self.collector = collector
         super().__init__(request, client_address, server)

-
     def _send_412_error(self, message):
         self.send_response(412)
         self.send_header("Content-type", "application/json")

@@ -46,7 +44,6 @@ class Handler(BaseHTTPRequestHandler):
         response = json.dumps({"error": message})
         self.wfile.write(response.encode('utf-8'))

-
     def _send_404_error(self):
         self.send_response(404)
         self.send_header("Content-type", "application/json")

@@ -54,7 +51,6 @@ class Handler(BaseHTTPRequestHandler):
         response = json.dumps({"error": "Resource not found"})
         self.wfile.write(response.encode('utf-8'))

-
     def _send_400_error(self, error_message: str):
         self.send_response(400)
         self.send_header("Content-type", "application/json")

@@ -62,7 +58,6 @@ class Handler(BaseHTTPRequestHandler):
         response = json.dumps({"error": error_message})
         self.wfile.write(response.encode('utf-8'))

-
     def _send_200_response(self, message: str):
         self.send_response(200)
         self.send_header("Content-type", "application/json")

@@ -75,24 +70,21 @@ class Handler(BaseHTTPRequestHandler):
         self.wfile.write(response.encode('utf-8'))

-
     def _handle_get(self, search_strings: list[str], n_results: int, max_token_count: int, sort_param: str):
         if sort_param == parameters.SORT_DISTANCE:
             results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)
         elif sort_param == parameters.SORT_ID:
             results = self.collector.get_sorted_by_id(search_strings, n_results, max_token_count)
         else: # Default is dist
             results = self.collector.get_sorted_by_dist(search_strings, n_results, max_token_count)

         return {
             "results": results
         }

-
     def do_GET(self):
         self._send_404_error()

-
     def do_POST(self):
         try:
             content_length = int(self.headers['Content-Length'])

@@ -146,7 +138,6 @@ class Handler(BaseHTTPRequestHandler):
         except Exception as e:
             self._send_400_error(str(e))

-
     def do_DELETE(self):
         try:
             parsed_path = urlparse(self.path)

@@ -161,12 +152,10 @@ class Handler(BaseHTTPRequestHandler):
         except Exception as e:
             self._send_400_error(str(e))

-
     def do_OPTIONS(self):
         self.send_response(200)
         self.end_headers()

-
     def end_headers(self):
         self.send_header('Access-Control-Allow-Origin', '*')
         self.send_header('Access-Control-Allow-Methods', '*')
@@ -197,7 +186,7 @@ class APIManager:
     def stop_server(self):
         if self.server is not None:
-            logger.info(f'Stopping chromaDB API.')
+            logger.info('Stopping chromaDB API.')
             self.server.shutdown()
             self.server.server_close()
             self.server = None

extensions/superboogav2/benchmark.py

@@ -9,13 +9,13 @@ The benchmark function will return the score as an integer.
 import datetime
 import json
 import os
 from pathlib import Path

-from .data_processor import process_and_add_to_collector, preprocess_text
+from .data_processor import preprocess_text, process_and_add_to_collector
 from .parameters import get_chunk_count, get_max_token_count
 from .utils import create_metadata_source


 def benchmark(config_path, collector):
     # Get the current system date
     sysdate = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

extensions/superboogav2/chat_handler.py

@@ -4,16 +4,17 @@ This module is responsible for modifying the chat prompt and history.
 import re

 import extensions.superboogav2.parameters as parameters
+from extensions.superboogav2.utils import (
+    create_context_text,
+    create_metadata_source
+)
 from modules import chat, shared
-from modules.text_generation import get_encoded_length
-from modules.logging_colors import logger
 from modules.chat import load_character_memoized
-from extensions.superboogav2.utils import create_context_text, create_metadata_source
+from modules.logging_colors import logger
+from modules.text_generation import get_encoded_length

-from .data_processor import process_and_add_to_collector
 from .chromadb import ChromaCollector
+from .data_processor import process_and_add_to_collector


 CHAT_METADATA = create_metadata_source('automatic-chat-insert')
@@ -69,7 +70,7 @@ def _concatinate_history(history: dict, state: dict):
         if len(exchange) >= 2:
             full_history_text += _format_single_exchange(bot_name, exchange[1])

-    return full_history_text[:-1] # Remove the last new line.
+    return full_history_text[:-1]  # Remove the last new line.


 def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):

@@ -95,7 +96,7 @@ def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
     else:
         # replace the message at replace_position with context_text
         i, j = replace_position
-        history['internal'][-i-1][-j-1] = context_text
+        history['internal'][-i - 1][-j - 1] = context_text


 def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):

extensions/superboogav2/chromadb.py

@@ -1,42 +1,23 @@
-import threading
-import chromadb
-import posthog
-import torch
 import math
+import random
+import threading

+import chromadb
 import numpy as np
-import extensions.superboogav2.parameters as parameters
-
+import posthog
 from chromadb.config import Settings
-from sentence_transformers import SentenceTransformer
+from chromadb.utils import embedding_functions

+import extensions.superboogav2.parameters as parameters
 from modules.logging_colors import logger
-from modules.text_generation import encode, decode
+from modules.text_generation import decode, encode

-logger.debug('Intercepting all calls to posthog.')
+# Intercept calls to posthog
 posthog.capture = lambda *args, **kwargs: None

-
-class Collecter():
-    def __init__(self):
-        pass
-
-    def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int]):
-        pass
-
-    def get(self, search_strings: list[str], n_results: int) -> list[str]:
-        pass
-
-    def clear(self):
-        pass
-
-
-class Embedder():
-    def __init__(self):
-        pass
-
-    def embed(self, text: str) -> list[torch.Tensor]:
-        pass
+embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")


 class Info:
     def __init__(self, start_index, text_with_context, distance, id):
@@ -58,7 +39,7 @@ class Info:
         elif parameters.get_new_dist_strategy() == parameters.DIST_ARITHMETIC_STRATEGY:
             # Arithmetic mean
             return (self.distance + other_info.distance) / 2
-        else: # Min is default
+        else:  # Min is default
             return min(self.distance, other_info.distance)

     def merge_with(self, other_info):
@@ -93,16 +74,19 @@
         return not (s1_end < s2_start or s2_end < s1_start)


-class ChromaCollector(Collecter):
-    def __init__(self, embedder: Embedder):
-        super().__init__()
+class ChromaCollector():
+    def __init__(self):
+        name = ''.join(random.choice('ab') for _ in range(10))
+
+        self.name = name
         self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
-        self.embedder = embedder
-        self.collection = self.chroma_client.create_collection(name="context", embedding_function=self.embedder.embed)
+        self.collection = self.chroma_client.create_collection(name=name, embedding_function=embedder)
         self.ids = []
         self.id_to_info = {}
         self.embeddings_cache = {}
         self.lock = threading.Lock() # Locking so the server doesn't break.

     def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None):
         with self.lock:

@@ -114,7 +98,7 @@ class ChromaCollector(Collecter):
         new_ids = self._get_new_ids(len(texts))

         (existing_texts, existing_embeddings, existing_ids, existing_metas), \
             (non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas)

         # If there are any already existing texts, add them all at once.
         if existing_texts:

@@ -126,7 +110,7 @@ class ChromaCollector(Collecter):
         # If there are any non-existing texts, compute their embeddings all at once. Each call to embed has significant overhead.
         if non_existing_texts:
-            non_existing_embeddings = self.embedder.embed(non_existing_texts).tolist()
+            non_existing_embeddings = embedder(non_existing_texts)
             for text, embedding in zip(non_existing_texts, non_existing_embeddings):
                 self.embeddings_cache[text] = embedding
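The dropped `.tolist()` is consistent with chromadb embedding functions being plain callables. A small stand-alone check, assuming chromadb 0.4.x as pinned in requirements.txt, with invented sample texts:

    from chromadb.utils import embedding_functions

    embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")

    # One embedding per input text (a plain Python list in the 0.4.x line), so the
    # values can be cached and later passed to collection.add(embeddings=...) as-is.
    vectors = embedder(["some text", "another text"])
    print(len(vectors), len(vectors[0]))  # 2 texts, 768-dimensional all-mpnet-base-v2 vectors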
@@ -145,7 +129,6 @@ class ChromaCollector(Collecter):
         self.id_to_info.update(new_info)
         self.ids.extend(new_ids)

-
     def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]):
         existing_texts, non_existing_texts = [], []
         existing_embeddings = []

@@ -169,7 +152,6 @@ class ChromaCollector(Collecter):
         return (existing_texts, existing_embeddings, existing_ids, existing_metas), \
             (non_existing_texts, non_existing_ids, non_existing_metas)

-
     def _get_new_ids(self, num_new_ids: int):
         if self.ids:
             max_existing_id = max(int(id_) for id_ in self.ids)

@@ -178,7 +160,6 @@ class ChromaCollector(Collecter):
         return [str(i + max_existing_id + 1) for i in range(num_new_ids)]

-
     def _find_min_max_start_index(self):
         max_index, min_index = 0, float('inf')
         for _, val in self.id_to_info.items():
@@ -188,13 +169,14 @@ class ChromaCollector(Collecter):
                 min_index = val['start_index']

         return min_index, max_index

     # NB: Does not make sense to weigh excerpts from different documents.
     # But let's say that's the user's problem. Perfect world scenario:
     # Apply time weighing to different documents. For each document, then, add
     # separate time weighing.
     def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float):
-        sigmoid = lambda x: 1 / (1 + np.exp(-x))
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))

         weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))
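The lambda-to-def change above is purely stylistic (the usual lint rule against assigning lambdas). For reference, a stand-alone sketch of what the weighing computes, using an invented five-chunk document and time_steepness=1.0:

    import numpy as np

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    time_steepness = 1.0
    document_len = 5
    weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))
    print(weights)  # ~[4.5e-05, 6.7e-03, 0.5, 0.993, 0.99995]
    # Each Info's distance is then multiplied by the weight at its start index, so
    # chunks near the start of the document get their distances scaled down sharply.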
@@ -210,7 +192,6 @@ class ChromaCollector(Collecter):
             index = info.start_index
             info.distance *= weights[index]

-
     def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float):
         # Ensure there are infos to filter
         if not infos:

@@ -231,7 +212,6 @@ class ChromaCollector(Collecter):
         return filtered_infos

-
     def _merge_infos(self, infos: list[Info]):
         merged_infos = []
         current_info = infos[0]

@@ -247,8 +227,8 @@ class ChromaCollector(Collecter):
         merged_infos.append(current_info)
         return merged_infos

     # Main function for retrieving chunks by distance. It performs merging, time weighing, and mean filtering.
     def _get_documents_ids_distances(self, search_strings: list[str], n_results: int):
         n_results = min(len(self.ids), n_results)

         if n_results == 0:
@@ -280,22 +260,22 @@ class ChromaCollector(Collecter):
         return texts_with_context, ids, distances

     # Get chunks by similarity
     def get(self, search_strings: list[str], n_results: int) -> list[str]:
         with self.lock:
             documents, _, _ = self._get_documents_ids_distances(search_strings, n_results)
             return documents

     # Get ids by similarity
     def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
         with self.lock:
             _, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
             return ids

     # Cutoff token count
     def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int):
         # TODO: Move to caller; We add delimiters there which might go over the limit.
         current_token_count = 0

@@ -318,8 +298,8 @@ class ChromaCollector(Collecter):
         return return_documents

     # Get chunks by similarity and then sort by ids
     def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
         with self.lock:
             documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
@@ -327,20 +307,19 @@ class ChromaCollector(Collecter):
             return self._get_documents_up_to_token_count(sorted_docs, max_token_count)

     # Get chunks by similarity and then sort by distance (lowest distance is last).
     def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
         with self.lock:
             documents, _, distances = self._get_documents_ids_distances(search_strings, n_results)
             sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])] # sorted lowest -> highest

             # If a document is truncated or competely skipped, it would be with high distance.
             return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count)
             return_documents.reverse() # highest -> lowest
             return return_documents

-
     def delete(self, ids_to_delete: list[str], where: dict):
         with self.lock:
             ids_to_delete = self.collection.get(ids=ids_to_delete, where=where)['ids']
@@ -354,23 +333,16 @@ class ChromaCollector(Collecter):
             logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.')

     def clear(self):
         with self.lock:
-            self.chroma_client.reset()
-            self.collection = self.chroma_client.create_collection("context", embedding_function=self.embedder.embed)
             self.ids = []
             self.id_to_info = {}
+            self.chroma_client.delete_collection(name=self.name)
+            self.collection = self.chroma_client.create_collection(name=self.name, embedding_function=embedder)

             logger.info('Successfully cleared all records and reset chromaDB.')


-class SentenceTransformerEmbedder(Embedder):
-    def __init__(self) -> None:
-        logger.debug('Creating Sentence Embedder...')
-        self.model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
-        self.embed = self.model.encode
-
-
 def make_collector():
-    return ChromaCollector(SentenceTransformerEmbedder())
+    return ChromaCollector()
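A note on the randomly generated collection name: each ChromaCollector now owns its own uniquely named collection (the ten 'a'/'b' characters also satisfy chromadb's 3-63 character name rule), which is what lets clear() and delete() operate on just that collection. The sketch below is illustrative only; the ids and the "source" metadata key are invented for the example, not taken from this codebase.

    import random

    import chromadb
    from chromadb.utils import embedding_functions

    embedder = embedding_functions.SentenceTransformerEmbeddingFunction("sentence-transformers/all-mpnet-base-v2")
    client = chromadb.Client()

    name = ''.join(random.choice('ab') for _ in range(10))  # e.g. 'abbaabbaba'
    collection = client.create_collection(name=name, embedding_function=embedder)
    collection.add(documents=["kept chunk", "url chunk"], ids=["0", "1"],
                   metadatas=[{"source": "direct-text"}, {"source": "url"}])

    # delete() above resolves ids through a metadata filter before deleting:
    matching = collection.get(ids=["0", "1"], where={"source": "url"})["ids"]
    collection.delete(ids=matching)
    print(collection.get()["ids"])  # -> ['0']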

extensions/superboogav2/data_preprocessor.py

@@ -11,20 +11,19 @@ This module contains utils for preprocessing the text before converting it to em
 * removing specific parts of speech (adverbs and interjections)
 - TextSummarizer extracts the most important sentences from a long string using text-ranking.
 """
-import pytextrank
-import string
-import spacy
 import math
-import nltk
 import re
+import string

+import nltk
+import spacy
 from nltk.corpus import stopwords
 from nltk.stem import WordNetLemmatizer
 from num2words import num2words


 class TextPreprocessorBuilder:
     # Define class variables as None initially
     _stop_words = set(stopwords.words('english'))
     _lemmatizer = WordNetLemmatizer()
@@ -32,11 +31,9 @@ class TextPreprocessorBuilder:
     _lemmatizer_cache = {}
     _pos_remove_cache = {}

-
     def __init__(self, text: str):
         self.text = text

-
     def to_lower(self):
         # Match both words and non-word characters
         tokens = re.findall(r'\b\w+\b|\W+', self.text)

@@ -49,7 +46,6 @@ class TextPreprocessorBuilder:
         self.text = "".join(tokens)
         return self

-
     def num_to_word(self, min_len: int = 1):
         # Match both words and non-word characters
         tokens = re.findall(r'\b\w+\b|\W+', self.text)

@@ -58,11 +54,10 @@ class TextPreprocessorBuilder:
             if token.isdigit() and len(token) >= min_len:
                 # This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
                 # 740700 will become "seven hundred and forty thousand seven hundred".
-                tokens[i] = num2words(int(token)).replace(",","") # Remove commas from num2words.
+                tokens[i] = num2words(int(token)).replace(",", "")  # Remove commas from num2words.

         self.text = "".join(tokens)
         return self

-
     def num_to_char_long(self, min_len: int = 1):
         # Match both words and non-word characters
         tokens = re.findall(r'\b\w+\b|\W+', self.text)
@@ -71,7 +66,9 @@ class TextPreprocessorBuilder:
             if token.isdigit() and len(token) >= min_len:
                 # This is done to pay better attention to numbers (e.g. ticket numbers, thread numbers, post numbers)
                 # 740700 will become HHHHHHEEEEEAAAAHHHAAA
-                convert_token = lambda token: ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]
+                def convert_token(token):
+                    return ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]
+
                 tokens[i] = convert_token(tokens[i])
         self.text = "".join(tokens)
         return self
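Same stylistic lambda-to-def conversion here; a quick stand-alone check that the helper still produces exactly the expansion promised by the comment above:

    def convert_token(token):
        return ''.join((chr(int(digit) + 65) * (i + 1)) for i, digit in enumerate(token[::-1]))[::-1]

    # Digits map to letters (0 -> A, 4 -> E, 7 -> H), and more significant digits
    # are repeated more times, so long numbers become long, distinctive strings.
    print(convert_token("740700"))  # -> HHHHHHEEEEEAAAAHHHAAA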
@@ -150,6 +147,7 @@
     def build(self):
         return self.text

+
 class TextSummarizer:
     _nlp_pipeline = None
     _cache = {}

extensions/superboogav2/data_processor.py

@@ -4,13 +4,14 @@ It will then split it into chunks of specified length. For each of those chunks,
 It will only include full words.
 """
-import re
 import bisect
+import re

 import extensions.superboogav2.parameters as parameters

-from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer
 from .chromadb import ChromaCollector
+from .data_preprocessor import TextPreprocessorBuilder, TextSummarizer
+

 def preprocess_text_no_summary(text) -> str:
     builder = TextPreprocessorBuilder(text)
@@ -125,7 +126,7 @@ def _clear_chunks(data_chunks, data_chunks_with_context, data_chunk_starting_ind
         seen_chunk_start = seen_chunks.get(chunk)
         if seen_chunk_start:
             # If we've already seen this exact chunk, and the context around it it very close to the seen chunk, then skip it.
-            if abs(seen_chunk_start-index) < parameters.get_delta_start():
+            if abs(seen_chunk_start - index) < parameters.get_delta_start():
                 continue

         distinct_data_chunks.append(chunk)

@@ -206,4 +207,4 @@ def process_and_add_to_collector(corpus: str, collector: ChromaCollector, clear_
     if clear_collector_before_adding:
         collector.clear()

-    collector.add(data_chunks, data_chunks_with_context, data_chunk_starting_indices, [metadata]*len(data_chunks) if metadata is not None else None)
+    collector.add(data_chunks, data_chunks_with_context, data_chunk_starting_indices, [metadata] * len(data_chunks) if metadata is not None else None)

extensions/superboogav2/download_urls.py

@@ -1,7 +1,7 @@
 import concurrent.futures
-import requests
 import re
+import requests

 from bs4 import BeautifulSoup

 import extensions.superboogav2.parameters as parameters

@@ -9,6 +9,7 @@ import extensions.superboogav2.parameters as parameters
 from .data_processor import process_and_add_to_collector
 from .utils import create_metadata_source

+
 def _download_single(url):
     response = requests.get(url, timeout=5)
     if response.status_code == 200:

extensions/superboogav2/notebook_handler.py

@@ -4,13 +4,12 @@ This module is responsible for handling and modifying the notebook text.
 import re

 import extensions.superboogav2.parameters as parameters
-from modules import shared
-from modules.logging_colors import logger
 from extensions.superboogav2.utils import create_context_text
+from modules.logging_colors import logger

 from .data_processor import preprocess_text


 def _remove_special_tokens(string):
     pattern = r'(<\|begin-user-input\|>|<\|end-user-input\|>|<\|injection-point\|>)'
     return re.sub(pattern, '', string)

extensions/superboogav2/optimize.py

@@ -3,22 +3,24 @@ This module implements a hyperparameter optimization routine for the embedding a
 Each run, the optimizer will set the default values inside the hyperparameters. At the end, it will output the best ones it has found.
 """
-import re
+import hashlib
 import json
-import optuna
+import logging
+import re

 import gradio as gr
 import numpy as np
-import logging
-import hashlib
+import optuna

-logging.getLogger('optuna').setLevel(logging.WARNING)
-
-import extensions.superboogav2.parameters as parameters
+logging.getLogger('optuna').setLevel(logging.WARNING)

 from pathlib import Path

+import extensions.superboogav2.parameters as parameters
+from modules.logging_colors import logger
+
 from .benchmark import benchmark
 from .parameters import Parameters
-from modules.logging_colors import logger


 # Format the parameters into markdown format.
@@ -55,7 +57,7 @@ def _set_hyperparameters(params):
 # Check if the parameter is for optimization.
 def _is_optimization_param(val):
-    is_opt = val.get('should_optimize', False) # Either does not exist or is false
+    is_opt = val.get('should_optimize', False)  # Either does not exist or is false
     return is_opt

@@ -67,7 +69,7 @@ def _get_params_hash(params):
 def optimize(collector, progress=gr.Progress()):
     # Inform the user that something is happening.
-    progress(0, desc=f'Setting Up...')
+    progress(0, desc='Setting Up...')

     # Track the current step
     current_step = 0

extensions/superboogav2/parameters.py

@@ -6,13 +6,11 @@ Each element in the JSON must have a `default` value which will be used for the
 These categories define the range in which the optimizer will search. If the element is tagged with `"should_optimize": false`,
 then the optimizer will only ever use the default value.
 """
+import json
 from pathlib import Path
-import json

 from modules.logging_colors import logger

-
 NUM_TO_WORD_METHOD = 'Number to Word'
 NUM_TO_CHAR_METHOD = 'Number to Char'
 NUM_TO_CHAR_LONG_METHOD = 'Number to Multi-Char'

extensions/superboogav2/requirements.txt

@@ -1,5 +1,5 @@
 beautifulsoup4==4.12.2
-chromadb==0.3.18
+chromadb==0.4.24
 lxml
 optuna
 pandas==2.0.3

extensions/superboogav2/script.py

@@ -7,28 +7,29 @@ from pathlib import Path
 # Point to where nltk will find the required data.
 os.environ['NLTK_DATA'] = str(Path("extensions/superboogav2/nltk_data").resolve())

-import textwrap
 import codecs
+import textwrap

 import gradio as gr

 import extensions.superboogav2.parameters as parameters
-from modules.logging_colors import logger
 from modules import shared
+from modules.logging_colors import logger

-from .utils import create_metadata_source
-from .chromadb import make_collector
-from .download_urls import feed_url_into_collector
-from .data_processor import process_and_add_to_collector
-from .benchmark import benchmark
-from .optimize import optimize
-from .notebook_handler import input_modifier_internal
-from .chat_handler import custom_generate_chat_prompt_internal
 from .api import APIManager
+from .benchmark import benchmark
+from .chat_handler import custom_generate_chat_prompt_internal
+from .chromadb import make_collector
+from .data_processor import process_and_add_to_collector
+from .download_urls import feed_url_into_collector
+from .notebook_handler import input_modifier_internal
+from .optimize import optimize
+from .utils import create_metadata_source

 collector = None
 api_manager = None


 def setup():
     global collector
     global api_manager
@@ -38,6 +39,7 @@ def setup():
     if parameters.get_api_on():
         api_manager.start_server(parameters.get_api_port())

+
 def _feed_data_into_collector(corpus):
     yield '### Processing data...'
     process_and_add_to_collector(corpus, collector, False, create_metadata_source('direct-text'))

@@ -305,10 +307,8 @@ def ui():
             optimize_button = gr.Button('Optimize')
             optimization_steps = gr.Number(value=parameters.get_optimization_steps(), label='Optimization Steps', info='For how many steps to optimize.', interactive=True)

             clear_button = gr.Button('❌ Clear Data')
-
-
         with gr.Column():
             last_updated = gr.Markdown()

@@ -316,8 +316,7 @@ def ui():
                   preprocess_pipeline, api_port, api_on, injection_strategy, add_chat_to_data, manual, postfix, data_separator, prefix, max_token_count,
                   chunk_count, chunk_sep, context_len, chunk_regex, chunk_len, threads, strong_cleanup]

     optimizable_params = [time_power, time_steepness, significant_level, min_sentences, new_dist_strat, delta_start, min_number_length, num_conversion,
                           preprocess_pipeline, chunk_count, context_len, chunk_len]

-
     update_data.click(_feed_data_into_collector, [data_input], last_updated, show_progress=False)
     update_url.click(_feed_url_into_collector, [url_input], last_updated, show_progress=False)

@@ -326,7 +325,6 @@ def ui():
     optimize_button.click(_begin_optimization, [], [last_updated] + optimizable_params, show_progress=True)
     clear_button.click(_clear_data, [], last_updated, show_progress=False)

-
     optimization_steps.input(fn=_apply_settings, inputs=all_params, show_progress=False)
     time_power.input(fn=_apply_settings, inputs=all_params, show_progress=False)
     time_steepness.input(fn=_apply_settings, inputs=all_params, show_progress=False)

extensions/superboogav2/utils.py

@@ -4,6 +4,7 @@ This module contains common functions across multiple other modules.
 import extensions.superboogav2.parameters as parameters

+
 # Create the context using the prefix + data_separator + postfix from parameters.
 def create_context_text(results):
     context = parameters.get_prefix() + parameters.get_data_separator().join(results) + parameters.get_postfix()