Refactor text_generation.py, add support for custom generation functions (#1817)

This commit is contained in:
oobabooga 2023-05-05 18:53:03 -03:00 committed by GitHub
parent 876fbb97c0
commit 8aafb1f796
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 289 additions and 195 deletions

View file

@ -45,7 +45,9 @@ Most of these have been created by the extremely talented contributors that you
| `def ui()` | Creates custom gradio elements when the UI is launched. |
| `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
| `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
| `def state_modifier(state)` | Modifies the dictionary containing the input parameters before it is used by the text generation functions. |
| `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply (more on that below). |
| `def custom_generate_reply(...)` | Overrides the main text generation function. |
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `llava` extension for an example |
@ -104,6 +106,23 @@ python server.py --extensions enthusiasm translate # First apply enthusiasm, the
python server.py --extensions translate enthusiasm # First apply translate, then enthusiasm
```
## `custom_generate_reply` example
Once defined in a `script.py`, this function is executed in place of the main generation functions. You can use it to connect the web UI to an external API, or to load a custom model that is not supported yet.
```python
import datetime
def custom_generate_reply(question, original_question, seed, state, eos_token, stopping_strings):
cumulative = ''
for i in range(10):
cumulative += f"Counting: {i}...\n"
yield cumulative
cumulative += f"Done! {str(datetime.datetime.now())}"
yield cumulative
```
## `custom_generate_chat_prompt` example
Below is an extension that just reproduces the default prompt generator in `modules/chat.py`. You can modify it freely to come up with your own prompts in chat mode.
@ -114,51 +133,64 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
is_instruct = state['mode'] == 'instruct'
rows = [f"{state['context'].strip()}\n"]
rows = [state['context'] if is_instruct else f"{state['context'].strip()}\n"]
min_rows = 3
# Finding the maximum prompt size
chat_prompt_size = state['chat_prompt_size']
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size)
if is_instruct:
prefix1 = f"{state['name1']}\n"
prefix2 = f"{state['name2']}\n"
# Building the turn templates
if 'turn_template' not in state or state['turn_template'] == '':
if is_instruct:
template = '<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n'
else:
template = '<|user|>: <|user-message|>\n<|bot|>: <|bot-message|>\n'
else:
prefix1 = f"{state['name1']}: "
prefix2 = f"{state['name2']}: "
template = state['turn_template'].replace(r'\n', '\n')
replacements = {
'<|user|>': state['name1'].strip(),
'<|bot|>': state['name2'].strip(),
}
user_turn = replace_all(template.split('<|bot|>')[0], replacements)
bot_turn = replace_all('<|bot|>' + template.split('<|bot|>')[1], replacements)
user_turn_stripped = replace_all(user_turn.split('<|user-message|>')[0], replacements)
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
# Building the prompt
i = len(shared.history['internal']) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip()))
string = shared.history['internal'][i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
i -= 1
if impersonate:
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
limit = 2
elif _continue:
limit = 3
else:
min_rows = 2
rows.append(user_turn_stripped.rstrip(' '))
elif not _continue:
# Adding the user message
user_input = fix_newlines(user_input)
if len(user_input) > 0:
rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))}))
# Adding the Character prefix
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
limit = 3
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length:
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
rows.pop(1)
prompt = ''.join(rows)
prompt = ''.join(rows)
if also_return_rows:
return prompt, rows
else:

View file

@ -33,6 +33,7 @@ class Handler(BaseHTTPRequestHandler):
prompt = body['prompt']
generate_params = build_parameters(body)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = False
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings)
@ -66,7 +67,7 @@ class Handler(BaseHTTPRequestHandler):
self.send_error(404)
def _run_server(port: int, share: bool=False):
def _run_server(port: int, share: bool = False):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
server = ThreadingHTTPServer((address, port), Handler)

View file

@ -23,6 +23,7 @@ async def _handle_connection(websocket, path):
prompt = message['prompt']
generate_params = build_parameters(message)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = True
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings)

View file

@ -18,15 +18,8 @@ wav_idx = 0
user = ElevenLabsUser(params['api_key'])
user_info = None
if not shared.args.no_stream:
print("Please add --no-stream. This extension is not meant to be used with streaming.")
raise ValueError
# Check if the API is valid and refresh the UI accordingly.
def check_valid_api():
global user, user_info, params
user = ElevenLabsUser(params['api_key'])
@ -41,9 +34,8 @@ def check_valid_api():
print('Got an API Key!')
return gr.update(value='Connected')
# Once the API is verified, get the available voices and update the dropdown list
def refresh_voices():
global user, user_info
@ -63,6 +55,11 @@ def remove_surrounded_chars(string):
return re.sub('\*[^\*]*?(\*|$)', '', string)
def state_modifier(state):
state['stream'] = False
return state
def input_modifier(string):
"""
This function is applied to your text inputs before
@ -109,6 +106,7 @@ def ui():
with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
connection_status = gr.Textbox(value='Disconnected', label='Connection Status')
voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice')
with gr.Row():
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')

View file

@ -266,8 +266,7 @@ class Handler(BaseHTTPRequestHandler):
stopping_strings += standard_stopping_strings
req_params['custom_stopping_strings'] = stopping_strings
shared.args.no_stream = not req_params['stream']
if not shared.args.no_stream:
if req_params['stream']:
shared.args.chat = True
# begin streaming
chunk = {
@ -337,7 +336,7 @@ class Handler(BaseHTTPRequestHandler):
if buffer_and_continue:
continue
if not shared.args.no_stream:
if req_params['stream']:
# Streaming
new_content = answer[len_seen:]
@ -365,7 +364,7 @@ class Handler(BaseHTTPRequestHandler):
self.wfile.write(response.encode('utf-8'))
completion_token_count += len(encode(new_content)[0])
if not shared.args.no_stream:
if req_params['stream']:
chunk = {
"id": cmpl_id,
"object": stream_object_type,

View file

@ -75,7 +75,6 @@ if params['manage_VRAM']:
samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture
def remove_surrounded_chars(string):
@ -92,6 +91,13 @@ def triggers_are_in(string):
return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string))
def state_modifier(state):
if picture_response:
state['stream'] = False
return state
def input_modifier(string):
"""
This function is applied to your text inputs before
@ -218,14 +224,13 @@ def bot_prefix_modifier(string):
def toggle_generation(*args):
global picture_response, shared, streaming_state
global picture_response, shared
if not args:
picture_response = not picture_response
else:
picture_response = args[0]
shared.args.no_stream = True if picture_response else streaming_state # Disable streaming cause otherwise the SD-generated picture would return as a dud
shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"

View file

@ -43,5 +43,5 @@ def ui():
picture_select.upload(
lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
lambda: None, None, picture_select, show_progress=False)

View file

@ -29,7 +29,6 @@ current_params = params.copy()
voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115']
voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high']
voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast']
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
# Used for making text xml compatible, needed for voice pitch and speed control
table = str.maketrans({
@ -76,6 +75,11 @@ def toggle_text_in_history(name1, name2, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def state_modifier(state):
state['stream'] = False
return state
def input_modifier(string):
"""
This function is applied to your text inputs before
@ -87,7 +91,6 @@ def input_modifier(string):
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')]
shared.processing_message = "*Is recording a voice message...*"
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
return string
@ -124,7 +127,6 @@ def output_modifier(string):
string += f'\n\n{original_string}'
shared.processing_message = "*Is typing...*"
shared.args.no_stream = streaming_state # restore the streaming option to the previous value
return string

View file

@ -86,6 +86,15 @@ def _apply_custom_generate_chat_prompt(text, state, **kwargs):
return None
# Extension that modifies the input parameters before they are used
def _apply_state_modifier_extensions(state):
for extension, _ in iterator():
if hasattr(extension, "state_modifier"):
state = getattr(extension, "state_modifier")(state)
return state
# Extension functions that override the default tokenizer output
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
for extension, _ in iterator():
@ -95,13 +104,24 @@ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_e
return prompt, input_ids, input_embeds
# Custom generate reply handling
def _apply_custom_generate_reply():
for extension, _ in iterator():
if hasattr(extension, 'custom_generate_reply'):
return getattr(extension, 'custom_generate_reply')
return None
EXTENSION_MAP = {
"input": partial(_apply_string_extensions, "input_modifier"),
"output": partial(_apply_string_extensions, "output_modifier"),
"state": _apply_state_modifier_extensions,
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
"input_hijack": _apply_input_hijack,
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
"custom_generate_reply": _apply_custom_generate_reply
}

View file

@ -21,6 +21,7 @@ def get_max_prompt_length(state):
max_length = state['truncation_length'] - state['max_new_tokens']
if shared.soft_prompt:
max_length -= shared.soft_prompt_tensor.shape[1]
return max_length
@ -62,6 +63,36 @@ def decode(output_ids, skip_special_tokens=True):
return shared.tokenizer.decode(output_ids, skip_special_tokens)
def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
s = re.sub("--- [0-9]*\n *\n---", "---", s)
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s
# Fix the LaTeX equations in galactica
def fix_galactica(s):
s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$')
s = s.replace(r'\(', r'$')
s = s.replace(r'\)', r'$')
s = s.replace(r'$$', r'$')
s = re.sub(r'\n', r'\n\n', s)
s = re.sub(r"\n{3,}", "\n\n", s)
return s
def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
if shared.model_type == 'HF_seq2seq':
reply = decode(output_ids, state['skip_special_tokens'])
@ -81,35 +112,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
return reply
def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
s = re.sub("--- [0-9]*\n *\n---", "---", s)
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s
# Fix the LaTeX equations in galactica
def fix_galactica(s):
s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$')
s = s.replace(r'\(', r'$')
s = s.replace(r'\)', r'$')
s = s.replace(r'$$', r'$')
s = re.sub(r'\n', r'\n\n', s)
s = re.sub(r"\n{3,}", "\n\n", s)
return s
def formatted_outputs(reply, model_name):
if not shared.is_chat():
if shared.model_type == 'galactica':
@ -140,51 +142,21 @@ def stop_everything_event():
shared.stop_everything = True
def get_generate_params(state):
generate_params = {}
# Models that are not on transformers
if shared.model_type in ['rwkv', 'llamacpp']:
generate_params['token_count'] = state['max_new_tokens']
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = state[k]
else:
# FlexGen
if shared.args.flexgen:
for k in ['max_new_tokens', 'do_sample', 'temperature']:
generate_params[k] = state[k]
if not shared.args.no_stream:
generate_params['max_new_tokens'] = 8
# transformers
else:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
generate_params[k] = state[k]
if state['ban_eos_token']:
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
if shared.args.no_cache:
generate_params.update({'use_cache': False})
if shared.args.deepspeed:
generate_params.update({'synced_gpus': True})
return generate_params
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.model_name == 'None' or shared.model is None:
logging.error("No model is loaded! Select one in the Model tab.")
yield formatted_outputs(question, shared.model_name)
return
state = apply_extensions('state', state)
generate_func = apply_extensions('custom_generate_reply')
if generate_func is None:
if shared.model_name == 'None' or shared.model is None:
logging.error("No model is loaded! Select one in the Model tab.")
yield formatted_outputs(question, shared.model_name)
return
clear_torch_cache()
seed = set_manual_seed(state['seed'])
shared.stop_everything = False
generate_params = get_generate_params(state)
t0 = time.time()
if shared.model_type in ['rwkv', 'llamacpp']:
generate_func = generate_reply_custom
elif shared.args.flexgen:
generate_func = generate_reply_flexgen
else:
generate_func = generate_reply_HF
# Preparing the input
original_question = question
@ -194,42 +166,31 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.args.verbose:
print(f'\n\n{question}\n--------------------\n')
# If the model is not on transformers, handle it separately and end this
# function call earlier.
if shared.model_type in ['rwkv', 'llamacpp']:
shared.stop_everything = False
clear_torch_cache()
seed = set_manual_seed(state['seed'])
for reply in generate_func(question, original_question, seed, state, eos_token, stopping_strings):
yield formatted_outputs(reply, shared.model_name)
try:
if shared.args.no_stream:
reply = shared.model.generate(context=question, **generate_params)
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
else:
if not shared.is_chat():
yield formatted_outputs(question, shared.model_name)
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
generate_params[k] = state[k]
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
if state['ban_eos_token']:
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
yield formatted_outputs(reply, shared.model_name)
if shared.args.no_cache:
generate_params.update({'use_cache': False})
except Exception:
traceback.print_exc()
finally:
t1 = time.time()
original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(output)[0]) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
if shared.args.deepspeed:
generate_params.update({'synced_gpus': True})
# Encode the input
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
output = input_ids[0]
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
cuda = not any((shared.args.cpu, shared.args.deepspeed))
# Find the eos tokens
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
@ -259,15 +220,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
break
# Update generate_params with the eos token and the stopping strings
if shared.args.flexgen:
generate_params['stop'] = eos_token_ids[-1]
else:
generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list
generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list
t0 = time.time()
try:
if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
yield original_question
# Generate the entire reply at once.
if shared.args.no_stream:
if not state['stream']:
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if cuda:
@ -276,12 +238,11 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
yield formatted_outputs(reply, shared.model_name)
yield get_reply_from_output_ids(output, input_ids, original_question, state)
# Stream the reply 1 token at a time.
# This is based on the trick of using 'stopping_criteria' to create an iterator.
elif not shared.args.flexgen:
else:
def generate_with_callback(callback=None, **kwargs):
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
@ -292,45 +253,118 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
yield formatted_outputs(original_question, shared.model_name)
with generate_with_streaming(**generate_params) as generator:
for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
yield get_reply_from_output_ids(output, input_ids, original_question, state)
if output[-1] in eos_token_ids:
break
yield formatted_outputs(reply, shared.model_name)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(state['max_new_tokens'] // 8 + 1):
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
yield formatted_outputs(reply, shared.model_name)
input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.update({'inputs_embeds': inputs_embeds})
generate_params.update({'inputs': filler_input_ids})
else:
generate_params.update({'inputs': input_ids})
yield formatted_outputs(reply, shared.model_name)
except Exception:
traceback.print_exc()
finally:
t1 = time.time()
original_tokens = len(original_input_ids[0])
new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
seed = set_manual_seed(state['seed'])
generate_params = {'token_count': state['max_new_tokens']}
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = state[k]
t0 = time.time()
try:
if not shared.is_chat():
yield question
if not state['stream']:
reply = shared.model.generate(context=question, **generate_params)
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
yield reply
else:
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
yield reply
except Exception:
traceback.print_exc()
finally:
t1 = time.time()
original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(output)[0]) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature']:
generate_params[k] = state[k]
if state['stream']:
generate_params['max_new_tokens'] = 8
# Encode the input
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
output = input_ids[0]
# Find the eos tokens
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
if eos_token is not None:
eos_token_ids.append(int(encode(eos_token)[0][-1]))
# Add the encoded tokens to generate_params
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
original_input_ids = input_ids
generate_params.update({'inputs': input_ids})
if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds})
# Update generate_params with the eos token and the stopping strings
generate_params['stop'] = eos_token_ids[-1]
t0 = time.time()
try:
if not shared.is_chat():
yield question
# Generate the entire reply at once.
if not state['stream']:
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
yield get_reply_from_output_ids(output, input_ids, original_question, state)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(state['max_new_tokens'] // 8 + 1):
if shared.stop_everything:
break
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
yield get_reply_from_output_ids(output, original_input_ids, original_question, state)
input_ids = np.reshape(output, (1, output.shape[0]))
generate_params.update({'inputs': input_ids})
except Exception:
traceback.print_exc()
finally:

View file

@ -34,7 +34,7 @@ def list_model_elements():
def list_interface_input_elements(chat=False):
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream']
if chat:
elements += ['name1', 'name2', 'greeting', 'context', 'turn_template', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']

View file

@ -15,6 +15,7 @@ def my_get(url, **kwargs):
kwargs.setdefault('allow_redirects', True)
return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
original_get = requests.get
requests.get = my_get
import gradio as gr
@ -454,6 +455,7 @@ def create_settings_menus(default_preset):
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
with gr.Accordion('Soft prompt', open=False):
with gr.Row():
@ -721,46 +723,46 @@ def create_interface():
gen_events.append(shared.gradio['Generate'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['textbox'].submit(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Regenerate'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Continue'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Impersonate'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False)
)
shared.gradio['Replace last reply'].click(
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Send dummy message'].click(
chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Send dummy reply'].click(
chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
@ -786,7 +788,7 @@ def create_interface():
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=False)
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
@ -808,14 +810,14 @@ def create_interface():
gen_events.append(shared.gradio['Generate'].click(
lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
generate_reply, shared.input_params, output_params, show_progress=False) # .then(
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
)
gen_events.append(shared.gradio['textbox'].submit(
lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
generate_reply, shared.input_params, output_params, show_progress=False) # .then(
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
)
@ -824,13 +826,13 @@ def create_interface():
gen_events.append(shared.gradio['Regenerate'].click(
lambda x: x, shared.gradio['last_input'], shared.gradio['textbox'], show_progress=False).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
generate_reply, shared.input_params, output_params, show_progress=False) # .then(
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
)
else:
gen_events.append(shared.gradio['Continue'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream) # .then(
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False) # .then(
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
)