Make chunk length/count customizable

This commit is contained in:
oobabooga 2023-05-07 05:02:04 -03:00
parent 8c06eeaf84
commit 04eca9b65b

View file

@ -65,13 +65,15 @@ class SentenceTransformerEmbedder(Embedder):
embedder = SentenceTransformerEmbedder()
collector = ChromaCollector(embedder)
chunk_count = 5
def feed_data_into_collector(corpus):
global collector
def feed_data_into_collector(corpus, chunk_len, _chunk_count):
global collector, chunk_count
chunk_count = int(_chunk_count)
chunk_len = int(chunk_len)
cumulative = ''
chunk_len = 700
cumulative += "Breaking the input dataset...\n\n"
yield cumulative
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
@ -83,14 +85,14 @@ def feed_data_into_collector(corpus):
yield cumulative
def feed_file_into_collector(file):
def feed_file_into_collector(file, chunk_len, chunk_count):
yield 'Reading the input dataset...\n\n'
text = file.decode('utf-8')
for i in feed_data_into_collector(text):
for i in feed_data_into_collector(text, chunk_len, chunk_count):
yield i
def feed_url_into_collector(url):
def feed_url_into_collector(url, chunk_len, chunk_count):
yield 'Loading the URL...'
html = urlopen(url).read()
soup = BeautifulSoup(html, features="html.parser")
@ -101,7 +103,7 @@ def feed_url_into_collector(url):
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = '\n\n'.join(chunk for chunk in chunks if chunk)
for i in feed_data_into_collector(text):
for i in feed_data_into_collector(text, chunk_len, chunk_count):
yield i
@ -115,8 +117,8 @@ def input_modifier(string):
else:
user_input = ''
# Get the 5 most similar chunks
results = collector.get(user_input, n_results=5)
# Get the most similar chunks
results = collector.get(user_input, n_results=chunk_count)
# Make the replacements
string = string.replace('<|begin-user-input|>', '')
@ -178,9 +180,13 @@ def ui():
file_input = gr.File(label='Input file', type='binary')
update_file = gr.Button('Apply')
with gr.Row():
chunk_len = gr.Number(value=700, label='Chunk length', info='In characters, not tokens')
chunk_count = gr.Number(value=5, label='Chunk count', info='The number of closest-matching chunks to include in the prompt')
with gr.Column():
last_updated = gr.Markdown()
update_data.click(feed_data_into_collector, data_input, last_updated, show_progress=False)
update_url.click(feed_url_into_collector, url_input, last_updated, show_progress=False)
update_file.click(feed_file_into_collector, file_input, last_updated, show_progress=False)
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_count], last_updated, show_progress=False)
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_count], last_updated, show_progress=False)
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_count], last_updated, show_progress=False)