Refactor several function calls and the API

oobabooga, 2023-04-06 01:22:15 -03:00 (committed by GitHub)
parent 378d21e80c
commit 3f3e42e26c
8 changed files with 147 additions and 118 deletions


@@ -36,6 +36,7 @@ async def run(context):
'early_stopping': False,
'seed': -1,
}
payload = json.dumps([context, params])
session = random_hash()
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
@@ -54,22 +55,7 @@ async def run(context):
"session_hash": session,
"fn_index": 12,
"data": [
context,
params['max_new_tokens'],
params['do_sample'],
params['temperature'],
params['top_p'],
params['typical_p'],
params['repetition_penalty'],
params['encoder_repetition_penalty'],
params['top_k'],
params['min_length'],
params['no_repeat_ngram_size'],
params['num_beams'],
params['penalty_alpha'],
params['length_penalty'],
params['early_stopping'],
params['seed'],
payload
]
}))
case "process_starts":


@@ -10,6 +10,8 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
allowing you to use the API remotely.
'''
import json
import requests
# Server address
@@ -38,24 +40,11 @@ params = {
# Input prompt
prompt = "What I would like to say is the following: "
payload = json.dumps([prompt, params])
response = requests.post(f"http://{server}:7860/run/textgen", json={
"data": [
prompt,
params['max_new_tokens'],
params['do_sample'],
params['temperature'],
params['top_p'],
params['typical_p'],
params['repetition_penalty'],
params['encoder_repetition_penalty'],
params['top_k'],
params['min_length'],
params['no_repeat_ngram_size'],
params['num_beams'],
params['penalty_alpha'],
params['length_penalty'],
params['early_stopping'],
params['seed'],
payload
]
}).json()
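
With this change the /run/textgen endpoint receives the prompt and all sampling parameters as one JSON-encoded string instead of a long positional list. A minimal end-to-end client sketch under that contract follows; the server address, port, and the exact response shape are assumptions based on the surrounding script and Gradio's usual {"data": [...]} REST format, not part of the commit:

# --- illustrative client sketch, not part of the commit ---
import json
import requests

server = "127.0.0.1"  # assumption: server running locally on the default port 7860
prompt = "What I would like to say is the following: "
params = {
    'max_new_tokens': 200,  # read directly via generate_state['max_new_tokens'], so include it
    'seed': -1,             # also read directly by generate_reply
    'do_sample': True,
    'temperature': 0.5,
    'top_p': 0.9,
}

payload = json.dumps([prompt, params])
response = requests.post(f"http://{server}:7860/run/textgen", json={"data": [payload]}).json()

# Assumption: Gradio returns {"data": [...]}, where the first item is the output textbox.
print(response["data"][0])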


@@ -40,24 +40,27 @@ class Handler(BaseHTTPRequestHandler):
prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines)
generate_params = {
'max_new_tokens': int(body.get('max_length', 200)),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)),
'typical_p': float(body.get('typical', 1)),
'repetition_penalty': float(body.get('rep_pen', 1.1)),
'encoder_repetition_penalty': 1,
'top_k': int(body.get('top_k', 0)),
'min_length': int(body.get('min_length', 0)),
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
'num_beams': int(body.get('num_beams',1)),
'penalty_alpha': float(body.get('penalty_alpha', 0)),
'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)),
'seed': int(body.get('seed', -1)),
}
generator = generate_reply(
question = prompt,
max_new_tokens = int(body.get('max_length', 200)),
do_sample=bool(body.get('do_sample', True)),
temperature=float(body.get('temperature', 0.5)),
top_p=float(body.get('top_p', 1)),
typical_p=float(body.get('typical', 1)),
repetition_penalty=float(body.get('rep_pen', 1.1)),
encoder_repetition_penalty=1,
top_k=int(body.get('top_k', 0)),
min_length=int(body.get('min_length', 0)),
no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
num_beams=int(body.get('num_beams',1)),
penalty_alpha=float(body.get('penalty_alpha', 0)),
length_penalty=float(body.get('length_penalty', 1)),
early_stopping=bool(body.get('early_stopping', False)),
seed=int(body.get('seed', -1)),
prompt,
generate_params,
stopping_strings=body.get('stopping_strings', []),
)
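
The handler now collects the request body fields into a single generate_params dict and forwards it to generate_reply together with any stopping_strings. A hedged sketch of a request it would accept; the host, port, and path are not shown in this hunk and are assumptions about the extension's defaults:

# --- illustrative request sketch, not part of the commit ---
import requests

url = "http://127.0.0.1:5000/api/v1/generate"  # assumed endpoint of the api extension

body = {
    "prompt": "Write a haiku about refactoring.",
    "max_length": 200,            # becomes max_new_tokens
    "temperature": 0.5,
    "rep_pen": 1.1,               # becomes repetition_penalty
    "typical": 1,                 # becomes typical_p
    "seed": -1,
    "stopping_strings": ["\nYou:"],
}
print(requests.post(url, json=body).json())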


@@ -2,12 +2,11 @@ import base64
from io import BytesIO
import gradio as gr
import modules.chat as chat
import modules.shared as shared
import torch
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor
from modules import chat, shared
# If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text]
input_hijack = {

modules/api.py (new file, 38 lines)

@@ -0,0 +1,38 @@
import json
import gradio as gr
from modules import shared
from modules.text_generation import generate_reply
def generate_reply_wrapper(string):
generate_params = {
'do_sample': True,
'temperature': 1,
'top_p': 1,
'typical_p': 1,
'repetition_penalty': 1,
'encoder_repetition_penalty': 1,
'top_k': 50,
'num_beams': 1,
'penalty_alpha': 0,
'min_length': 0,
'length_penalty': 1,
'no_repeat_ngram_size': 0,
'early_stopping': False,
}
params = json.loads(string)
for k in params[1]:
generate_params[k] = params[1][k]
for i in generate_reply(params[0], generate_params):
yield i
def create_apis():
t1 = gr.Textbox(visible=False)
t2 = gr.Textbox(visible=False)
dummy = gr.Button(visible=False)
input_params = [t1]
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')
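
create_apis relies on a Gradio trick: the textboxes and the button are created with visible=False, so nothing is rendered, but wiring the button with api_name='textgen' still registers a /run/textgen route whose single input is the JSON-encoded [prompt, params] string handled by generate_reply_wrapper. A self-contained sketch of the same pattern outside this codebase; all names here are illustrative, not from the commit:

# --- illustrative sketch of the hidden-component API pattern, not part of the commit ---
import json
import gradio as gr

def echo_wrapper(string):
    # Same contract as generate_reply_wrapper: one JSON-encoded [prompt, params] input.
    prompt, params = json.loads(string)
    return f"{prompt} (params received: {params})"

with gr.Blocks() as demo:
    t_in = gr.Textbox(visible=False)
    t_out = gr.Textbox(visible=False)
    dummy = gr.Button(visible=False)
    # api_name exposes this click handler as POST /run/echo even though the UI shows nothing.
    dummy.click(echo_wrapper, t_in, t_out, api_name="echo")

demo.launch()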


@@ -18,7 +18,12 @@ from modules.text_generation import (encode, generate_reply,
get_max_prompt_length)
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn="", impersonate=False, also_return_rows=False):
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
user_input = fix_newlines(user_input)
rows = [f"{context.strip()}\n"]
@@ -91,9 +96,9 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
reply = fix_newlines(reply)
return reply, next_character_found
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False, mode="cai-chat", end_of_turn=""):
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
just_started = True
eos_token = '\n' if stop_at_newline else None
eos_token = '\n' if generate_state['stop_at_newline'] else None
name1_original = name1
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
@@ -112,11 +117,11 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
visible_text = text
text = apply_extensions(text, "input")
is_instruct = mode == 'instruct'
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn)
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
else:
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn)
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
# Yield *Is typing...*
if not regenerate:
@@ -124,13 +129,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
# Generate
cumulative_reply = ''
for i in range(chat_generation_attempts):
for i in range(generate_state['chat_generation_attempts']):
reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
reply = cumulative_reply + reply
# Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
visible_reply = apply_extensions(visible_reply, "output")
@@ -155,23 +160,23 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
yield shared.history['visible']
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""):
eos_token = '\n' if stop_at_newline else None
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
eos_token = '\n' if generate_state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True, end_of_turn=end_of_turn)
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
# Yield *Is typing...*
yield shared.processing_message
cumulative_reply = ''
for i in range(chat_generation_attempts):
for i in range(generate_state['chat_generation_attempts']):
reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
yield reply
if next_character_found:
break
@@ -181,11 +186,11 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
yield reply
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""):
for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=False, mode=mode, end_of_turn=end_of_turn):
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
yield chat_html_wrapper(history, name1, name2, mode)
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""):
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
else:
@@ -193,7 +198,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*'
yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True, mode=mode, end_of_turn=end_of_turn):
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)


@@ -102,10 +102,11 @@ def set_manual_seed(seed):
def stop_everything_event():
shared.stop_everything = True
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
clear_torch_cache()
set_manual_seed(seed)
set_manual_seed(generate_state['seed'])
shared.stop_everything = False
generate_params = {}
t0 = time.time()
original_question = question
@@ -117,9 +118,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier
if any((shared.is_RWKV, shared.is_llamacpp)):
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = generate_state[k]
generate_params["token_count"] = generate_state["max_new_tokens"]
try:
if shared.args.no_stream:
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
reply = shared.model.generate(context=question, **generate_params)
output = original_question+reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output")
@@ -130,7 +134,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question+reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, "output")
@@ -145,7 +149,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
return
input_ids = encode(question, max_new_tokens)
input_ids = encode(question, generate_state['max_new_tokens'])
original_input_ids = input_ids
output = input_ids[0]
@@ -158,33 +162,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
generate_params = {}
generate_params["max_new_tokens"] = generate_state['max_new_tokens']
if not shared.args.flexgen:
generate_params.update({
"max_new_tokens": max_new_tokens,
"eos_token_id": eos_token_ids,
"stopping_criteria": stopping_criteria_list,
"do_sample": do_sample,
"temperature": temperature,
"top_p": top_p,
"typical_p": typical_p,
"repetition_penalty": repetition_penalty,
"encoder_repetition_penalty": encoder_repetition_penalty,
"top_k": top_k,
"min_length": min_length if shared.args.no_stream else 0,
"no_repeat_ngram_size": no_repeat_ngram_size,
"num_beams": num_beams,
"penalty_alpha": penalty_alpha,
"length_penalty": length_penalty,
"early_stopping": early_stopping,
})
for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
generate_params[k] = generate_state[k]
generate_params["eos_token_id"] = eos_token_ids
generate_params["stopping_criteria"] = stopping_criteria_list
if shared.args.no_stream:
generate_params["min_length"] = 0
else:
generate_params.update({
"max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
"do_sample": do_sample,
"temperature": temperature,
"stop": eos_token_ids[-1],
})
for k in ["do_sample", "temperature"]:
generate_params[k] = generate_state[k]
generate_params["stop"] = generate_state["eos_token_ids"][-1]
if not shared.args.no_stream:
generate_params["max_new_tokens"] = 8
if shared.args.no_cache:
generate_params.update({"use_cache": False})
if shared.args.deepspeed:
@@ -244,7 +236,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(max_new_tokens//8+1):
for i in range(generate_state['max_new_tokens']//8+1):
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
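
Taken together, generate_reply no longer unpacks a long list of positional sampling arguments; it reads everything from the generate_state dict and copies only the keys each backend understands into generate_params. A condensed, illustrative sketch of that selection, with the shared.args flags turned into plain parameters (eos_token_id and stopping_criteria handling omitted for brevity):

# --- condensed illustration of the selection logic above, not part of the commit ---
def build_generate_params(generate_state, flexgen=False, no_stream=False):
    params = {"max_new_tokens": generate_state["max_new_tokens"]}
    if not flexgen:
        keys = ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty",
                "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size",
                "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]
        params.update({k: generate_state[k] for k in keys})
        if no_stream:
            params["min_length"] = 0
    else:
        # FlexGen understands far fewer options.
        params.update({k: generate_state[k] for k in ["do_sample", "temperature"]})
        if not no_stream:
            params["max_new_tokens"] = 8  # stream in chunks of 8 tokens
    return params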


@@ -15,7 +15,7 @@ import gradio as gr
from PIL import Image
import modules.extensions as extensions_module
from modules import chat, shared, training, ui
from modules import chat, shared, training, ui, api
from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt
@@ -85,7 +85,7 @@ def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora)
return selected_lora
def load_preset_values(preset_menu, return_dict=False):
def load_preset_values(preset_menu, state, return_dict=False):
generate_params = {
'do_sample': True,
'temperature': 1,
@@ -107,13 +107,13 @@ def load_preset_values(preset_menu, return_dict=False):
i = i.rstrip(',').strip().split('=')
if len(i) == 2 and i[0].strip() != 'tokens':
generate_params[i[0].strip()] = eval(i[1].strip())
generate_params['temperature'] = min(1.99, generate_params['temperature'])
if return_dict:
return generate_params
else:
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -170,7 +170,10 @@ def create_prompt_menus():
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
generate_params[k] = shared.settings[k]
shared.gradio['generate_state'] = gr.State(generate_params)
with gr.Row():
with gr.Column():
@@ -221,17 +224,16 @@ def create_settings_menus(default_preset):
with gr.Row():
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
def set_interface_arguments(interface_mode, extensions, bool_active):
modes = ["default", "notebook", "chat", "cai_chat"]
cmd_list = vars(shared.args)
bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
#int_list = [k for k in cmd_list if type(k) is int]
shared.args.extensions = extensions
for k in modes[1:]:
@@ -372,11 +374,11 @@ def create_interface():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts', 'Chat mode', 'end_of_turn']]
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
def set_chat_input(textbox):
return textbox, ""
@@ -456,9 +458,9 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
@@ -489,9 +491,9 @@ def create_interface():
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
@@ -524,6 +526,21 @@ def create_interface():
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
def change_dict_value(d, key, value):
d[key] = value
return d
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
if k not in shared.gradio:
continue
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
else:
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
if not shared.is_chat():
api.create_apis()
# Authentication
auth = None
if shared.args.gradio_auth_path is not None:
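
The final hunk keeps generate_state in sync with the UI: every relevant slider and checkbox gets a change (or, for sliders, release) handler that writes its own key into the shared dict through change_dict_value, with copy=k binding the key at loop time. A self-contained sketch of the same pattern; the component names here are illustrative, not from the commit:

# --- illustrative sketch of the state-synchronization pattern, not part of the commit ---
import gradio as gr

def change_dict_value(d, key, value):
    # Same helper as in the commit: update one key and return the dict so gr.State stores it.
    d[key] = value
    return d

with gr.Blocks() as demo:
    state = gr.State({'temperature': 1.0, 'do_sample': True})
    temperature = gr.Slider(0.01, 1.99, value=1.0, label='temperature')
    do_sample = gr.Checkbox(value=True, label='do_sample')

    for k, comp in [('temperature', temperature), ('do_sample', do_sample)]:
        # copy=k binds the key at definition time, so each handler updates its own entry.
        handler = lambda state, value, copy=k: change_dict_value(state, copy, value)
        event = comp.change if isinstance(comp, gr.Checkbox) else comp.release
        event(handler, inputs=[state, comp], outputs=state)

demo.launch()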