Reorganize superbig ui

This commit is contained in:
oobabooga 2023-05-07 11:30:16 -03:00
parent befa307c42
commit a35a2fab02

View file

@ -68,9 +68,8 @@ collector = ChromaCollector(embedder)
chunk_count = 5
def feed_data_into_collector(corpus, chunk_len, _chunk_count):
global collector, chunk_count
chunk_count = int(_chunk_count)
def feed_data_into_collector(corpus, chunk_len):
global collector
chunk_len = int(chunk_len)
cumulative = ''
@ -85,14 +84,14 @@ def feed_data_into_collector(corpus, chunk_len, _chunk_count):
yield cumulative
def feed_file_into_collector(file, chunk_len, chunk_count):
def feed_file_into_collector(file, chunk_len):
yield 'Reading the input dataset...\n\n'
text = file.decode('utf-8')
for i in feed_data_into_collector(text, chunk_len, chunk_count):
for i in feed_data_into_collector(text, chunk_len):
yield i
def feed_url_into_collector(urls, chunk_len, chunk_count):
def feed_url_into_collector(urls, chunk_len):
urls = urls.strip().split('\n')
all_text = ''
cumulative = ''
@ -110,10 +109,19 @@ def feed_url_into_collector(urls, chunk_len, chunk_count):
text = '\n\n'.join(chunk for chunk in chunks if chunk)
all_text += text
for i in feed_data_into_collector(all_text, chunk_len, chunk_count):
for i in feed_data_into_collector(all_text, chunk_len):
yield i
def apply_settings(_chunk_count):
global chunk_count
chunk_count = _chunk_count
settings_to_display = {
'chunk_count': int(chunk_count),
}
yield f"The following settings are now active: {str(settings_to_display)}"
def input_modifier(string):
# Find the user input
@ -135,20 +143,27 @@ def input_modifier(string):
return string
def ui():
with gr.Accordion("Click for more information...", open=False):
gr.Markdown(textwrap.dedent("""
*This extension is currently experimental and under development.*
## About
This extension takes a dataset as input, breaks it into chunks, and adds the result to a local/offline Chroma database.
The database is then queried during inference time to get the excerpts that are closest to your input. The idea is to create
an arbitrarily large pseudocontext.
## How to use it
1) Paste your input text (of whatever length) into the text box below.
2) Click on the "Apply" button located below the text box
3) In your prompt, enter your question between <|begin-user-input|> and <|end-user-input|>, and specify the injection point with <|injection-point|>
2) Click on "Load data" to feed this text into the Chroma database.
3) In your prompt, enter your question between `<|begin-user-input|>` and `<|end-user-input|>`, and specify the injection point with `<|injection-point|>`.
## How it works
By default, the 5 closest chunks will be injected. You can customize this value in the "Generation settings" tab.
In the background, the 5 chunks in the input text most similar to the user input will be placed at the injection point, and the special tokens above will be removed. Then the text generation will proceed as usual.
The special tokens mentioned above (`<|begin-user-input|>`, `<|end-user-input|>`, and `<|injection-point|>`) are removed when the injection happens.
## Example
@ -168,7 +183,10 @@ def ui():
### Response:
```
*This extension is currently experimental and under development.*
"""))
if shared.is_chat():
# Chat mode has to be handled differently, probably using a custom_generate_chat_prompt
pass
@ -177,23 +195,26 @@ def ui():
with gr.Column():
with gr.Tab("Text input"):
data_input = gr.Textbox(lines=20, label='Input data')
update_data = gr.Button('Apply')
update_data = gr.Button('Load data')
with gr.Tab("URL input"):
url_input = gr.Textbox(lines=10, label='Input URL', info='Enter one or more URLs separated by newline characters')
update_url = gr.Button('Apply')
url_input = gr.Textbox(lines=10, label='Input URLs', info='Enter one or more URLs separated by newline characters.')
update_url = gr.Button('Load data')
with gr.Tab("File input"):
file_input = gr.File(label='Input file', type='binary')
update_file = gr.Button('Apply')
update_file = gr.Button('Load data')
with gr.Row():
chunk_len = gr.Number(value=700, label='Chunk length', info='In characters, not tokens')
chunk_count = gr.Number(value=5, label='Chunk count', info='The number of closest-matching chunks to include in the prompt')
with gr.Tab("Generation settings"):
chunk_count = gr.Number(value=5, label='Chunk count', info='The number of closest-matching chunks to include in the prompt.')
update_settings = gr.Button('Apply changes')
chunk_len = gr.Number(value=700, label='Chunk length', info='In characters, not tokens. This value is used when you click on "Load data".')
with gr.Column():
last_updated = gr.Markdown()
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_count], last_updated, show_progress=False)
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_count], last_updated, show_progress=False)
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_count], last_updated, show_progress=False)
update_data.click(feed_data_into_collector, [data_input, chunk_len], last_updated, show_progress=False)
update_url.click(feed_url_into_collector, [url_input, chunk_len], last_updated, show_progress=False)
update_file.click(feed_file_into_collector, [file_input, chunk_len], last_updated, show_progress=False)
update_settings.click(apply_settings, [chunk_count], last_updated, show_progress=False)