From 8aafb1f7960fab586fd9645d162574112d91b590 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 5 May 2023 18:53:03 -0300
Subject: [PATCH] Refactor text_generation.py, add support for custom generation functions (#1817)

---
 docs/Extensions.md                   |  72 +++++--
 extensions/api/blocking_api.py       |   3 +-
 extensions/api/streaming_api.py      |   1 +
 extensions/elevenlabs_tts/script.py  |  16 +-
 extensions/openai/script.py          |   7 +-
 extensions/sd_api_pictures/script.py |  11 +-
 extensions/send_pictures/script.py   |   2 +-
 extensions/silero_tts/script.py      |   8 +-
 modules/extensions.py                |  22 +-
 modules/text_generation.py           | 312 +++++++++++++++------------
 modules/ui.py                        |   2 +-
 server.py                            |  28 +--
 12 files changed, 289 insertions(+), 195 deletions(-)

diff --git a/docs/Extensions.md b/docs/Extensions.md
index dd4af96d..bc78dd0d 100644
--- a/docs/Extensions.md
+++ b/docs/Extensions.md
@@ -45,7 +45,9 @@ Most of these have been created by the extremely talented contributors that you
 | `def ui()` | Creates custom gradio elements when the UI is launched. |
 | `def input_modifier(string)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. |
 | `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
+| `def state_modifier(state)` | Modifies the dictionary containing the input parameters before it is used by the text generation functions. |
 | `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply (more on that below). |
+| `def custom_generate_reply(...)` | Overrides the main text generation function. |
 | `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
 | `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `llava` extension for an example |
 
@@ -104,6 +106,23 @@ python server.py --extensions enthusiasm translate # First apply enthusiasm, the
 python server.py --extensions translate enthusiasm # First apply translate, then enthusiasm
 ```
 
+## `custom_generate_reply` example
+
+Once defined in a `script.py`, this function is executed in place of the main generation functions. You can use it to connect the web UI to an external API, or to load a custom model that is not supported yet.
+
+```python
+import datetime
+
+def custom_generate_reply(question, original_question, seed, state, eos_token, stopping_strings):
+    cumulative = ''
+    for i in range(10):
+        cumulative += f"Counting: {i}...\n"
+        yield cumulative
+
+    cumulative += f"Done! {str(datetime.datetime.now())}"
+    yield cumulative
+```
+
 ## `custom_generate_chat_prompt` example
 
 Below is an extension that just reproduces the default prompt generator in `modules/chat.py`. You can modify it freely to come up with your own prompts in chat mode.
@@ -114,51 +133,64 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): _continue = kwargs['_continue'] if '_continue' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False is_instruct = state['mode'] == 'instruct' - rows = [f"{state['context'].strip()}\n"] + rows = [state['context'] if is_instruct else f"{state['context'].strip()}\n"] + min_rows = 3 # Finding the maximum prompt size chat_prompt_size = state['chat_prompt_size'] if shared.soft_prompt: chat_prompt_size -= shared.soft_prompt_tensor.shape[1] + max_length = min(get_max_prompt_length(state), chat_prompt_size) - if is_instruct: - prefix1 = f"{state['name1']}\n" - prefix2 = f"{state['name2']}\n" + # Building the turn templates + if 'turn_template' not in state or state['turn_template'] == '': + if is_instruct: + template = '<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n' + else: + template = '<|user|>: <|user-message|>\n<|bot|>: <|bot-message|>\n' else: - prefix1 = f"{state['name1']}: " - prefix2 = f"{state['name2']}: " + template = state['turn_template'].replace(r'\n', '\n') + replacements = { + '<|user|>': state['name1'].strip(), + '<|bot|>': state['name2'].strip(), + } + + user_turn = replace_all(template.split('<|bot|>')[0], replacements) + bot_turn = replace_all('<|bot|>' + template.split('<|bot|>')[1], replacements) + user_turn_stripped = replace_all(user_turn.split('<|user-message|>')[0], replacements) + bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements) + + # Building the prompt i = len(shared.history['internal']) - 1 while i >= 0 and len(encode(''.join(rows))[0]) < max_length: if _continue and i == len(shared.history['internal']) - 1: - rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}") + rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip()) else: - rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n") + rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip())) + string = shared.history['internal'][i][0] if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: - rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n") + rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)})) + i -= 1 if impersonate: - rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") - limit = 2 - elif _continue: - limit = 3 - else: + min_rows = 2 + rows.append(user_turn_stripped.rstrip(' ')) + elif not _continue: # Adding the user message - user_input = fix_newlines(user_input) if len(user_input) > 0: - rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n") + rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))})) # Adding the Character prefix - rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix")) - limit = 3 + rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' '))) - while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length: + while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length: rows.pop(1) - prompt = ''.join(rows) + prompt = ''.join(rows) if also_return_rows: return prompt, rows else: diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py index e66a6a50..2c72d789 100644 --- a/extensions/api/blocking_api.py +++ b/extensions/api/blocking_api.py @@ -33,6 +33,7 @@ class 
Handler(BaseHTTPRequestHandler): prompt = body['prompt'] generate_params = build_parameters(body) stopping_strings = generate_params.pop('stopping_strings') + generate_params['stream'] = False generator = generate_reply( prompt, generate_params, stopping_strings=stopping_strings) @@ -66,7 +67,7 @@ class Handler(BaseHTTPRequestHandler): self.send_error(404) -def _run_server(port: int, share: bool=False): +def _run_server(port: int, share: bool = False): address = '0.0.0.0' if shared.args.listen else '127.0.0.1' server = ThreadingHTTPServer((address, port), Handler) diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py index 3b9ac658..42570c94 100644 --- a/extensions/api/streaming_api.py +++ b/extensions/api/streaming_api.py @@ -23,6 +23,7 @@ async def _handle_connection(websocket, path): prompt = message['prompt'] generate_params = build_parameters(message) stopping_strings = generate_params.pop('stopping_strings') + generate_params['stream'] = True generator = generate_reply( prompt, generate_params, stopping_strings=stopping_strings) diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py index 5c727a30..86ea9a54 100644 --- a/extensions/elevenlabs_tts/script.py +++ b/extensions/elevenlabs_tts/script.py @@ -18,15 +18,8 @@ wav_idx = 0 user = ElevenLabsUser(params['api_key']) user_info = None -if not shared.args.no_stream: - print("Please add --no-stream. This extension is not meant to be used with streaming.") - raise ValueError - # Check if the API is valid and refresh the UI accordingly. - - def check_valid_api(): - global user, user_info, params user = ElevenLabsUser(params['api_key']) @@ -41,9 +34,8 @@ def check_valid_api(): print('Got an API Key!') return gr.update(value='Connected') + # Once the API is verified, get the available voices and update the dropdown list - - def refresh_voices(): global user, user_info @@ -63,6 +55,11 @@ def remove_surrounded_chars(string): return re.sub('\*[^\*]*?(\*|$)', '', string) +def state_modifier(state): + state['stream'] = False + return state + + def input_modifier(string): """ This function is applied to your text inputs before @@ -109,6 +106,7 @@ def ui(): with gr.Row(): activate = gr.Checkbox(value=params['activate'], label='Activate TTS') connection_status = gr.Textbox(value='Disconnected', label='Connection Status') + voice = gr.Dropdown(value=params['selected_voice'], choices=initial_voice, label='TTS Voice') with gr.Row(): api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key') diff --git a/extensions/openai/script.py b/extensions/openai/script.py index f9373385..c168ec95 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -266,8 +266,7 @@ class Handler(BaseHTTPRequestHandler): stopping_strings += standard_stopping_strings req_params['custom_stopping_strings'] = stopping_strings - shared.args.no_stream = not req_params['stream'] - if not shared.args.no_stream: + if req_params['stream']: shared.args.chat = True # begin streaming chunk = { @@ -337,7 +336,7 @@ class Handler(BaseHTTPRequestHandler): if buffer_and_continue: continue - if not shared.args.no_stream: + if req_params['stream']: # Streaming new_content = answer[len_seen:] @@ -365,7 +364,7 @@ class Handler(BaseHTTPRequestHandler): self.wfile.write(response.encode('utf-8')) completion_token_count += len(encode(new_content)[0]) - if not shared.args.no_stream: + if req_params['stream']: chunk = { "id": cmpl_id, "object": stream_object_type, diff --git a/extensions/sd_api_pictures/script.py 
b/extensions/sd_api_pictures/script.py index 1189a593..2d4e39dc 100644 --- a/extensions/sd_api_pictures/script.py +++ b/extensions/sd_api_pictures/script.py @@ -75,7 +75,6 @@ if params['manage_VRAM']: samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select -streaming_state = shared.args.no_stream # remember if chat streaming was enabled picture_response = False # specifies if the next model response should appear as a picture def remove_surrounded_chars(string): @@ -92,6 +91,13 @@ def triggers_are_in(string): return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string)) +def state_modifier(state): + if picture_response: + state['stream'] = False + + return state + + def input_modifier(string): """ This function is applied to your text inputs before @@ -218,14 +224,13 @@ def bot_prefix_modifier(string): def toggle_generation(*args): - global picture_response, shared, streaming_state + global picture_response, shared if not args: picture_response = not picture_response else: picture_response = args[0] - shared.args.no_stream = True if picture_response else streaming_state # Disable streaming cause otherwise the SD-generated picture would return as a dud shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*" diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index b21423e4..55bd866c 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -43,5 +43,5 @@ def ui(): picture_select.upload( lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then( gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( + chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then( lambda: None, None, picture_select, show_progress=False) diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index 460e76a8..7d762e39 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -29,7 +29,6 @@ current_params = params.copy() voices_by_gender = ['en_99', 'en_45', 'en_18', 'en_117', 'en_49', 'en_51', 'en_68', 'en_0', 'en_26', 'en_56', 'en_74', 'en_5', 'en_38', 'en_53', 'en_21', 'en_37', 'en_107', 'en_10', 'en_82', 'en_16', 'en_41', 'en_12', 'en_67', 'en_61', 'en_14', 'en_11', 'en_39', 'en_52', 'en_24', 'en_97', 'en_28', 'en_72', 'en_94', 'en_36', 'en_4', 'en_43', 'en_88', 'en_25', 'en_65', 'en_6', 'en_44', 'en_75', 'en_91', 'en_60', 'en_109', 'en_85', 'en_101', 'en_108', 'en_50', 'en_96', 'en_64', 'en_92', 'en_76', 'en_33', 'en_116', 'en_48', 'en_98', 'en_86', 'en_62', 'en_54', 'en_95', 'en_55', 'en_111', 'en_3', 'en_83', 'en_8', 'en_47', 'en_59', 'en_1', 'en_2', 'en_7', 'en_9', 'en_13', 'en_15', 'en_17', 'en_19', 'en_20', 'en_22', 'en_23', 'en_27', 'en_29', 'en_30', 'en_31', 'en_32', 'en_34', 'en_35', 'en_40', 'en_42', 'en_46', 'en_57', 'en_58', 'en_63', 'en_66', 'en_69', 'en_70', 'en_71', 'en_73', 'en_77', 'en_78', 'en_79', 'en_80', 'en_81', 'en_84', 'en_87', 'en_89', 'en_90', 
'en_93', 'en_100', 'en_102', 'en_103', 'en_104', 'en_105', 'en_106', 'en_110', 'en_112', 'en_113', 'en_114', 'en_115'] voice_pitches = ['x-low', 'low', 'medium', 'high', 'x-high'] voice_speeds = ['x-slow', 'slow', 'medium', 'fast', 'x-fast'] -streaming_state = shared.args.no_stream # remember if chat streaming was enabled # Used for making text xml compatible, needed for voice pitch and speed control table = str.maketrans({ @@ -76,6 +75,11 @@ def toggle_text_in_history(name1, name2, mode): return chat_html_wrapper(shared.history['visible'], name1, name2, mode) +def state_modifier(state): + state['stream'] = False + return state + + def input_modifier(string): """ This function is applied to your text inputs before @@ -87,7 +91,6 @@ def input_modifier(string): shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>', 'controls>')] shared.processing_message = "*Is recording a voice message...*" - shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated return string @@ -124,7 +127,6 @@ def output_modifier(string): string += f'\n\n{original_string}' shared.processing_message = "*Is typing...*" - shared.args.no_stream = streaming_state # restore the streaming option to the previous value return string diff --git a/modules/extensions.py b/modules/extensions.py index 1bb36d52..7d50055d 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -86,6 +86,15 @@ def _apply_custom_generate_chat_prompt(text, state, **kwargs): return None +# Extension that modifies the input parameters before they are used +def _apply_state_modifier_extensions(state): + for extension, _ in iterator(): + if hasattr(extension, "state_modifier"): + state = getattr(extension, "state_modifier")(state) + + return state + + # Extension functions that override the default tokenizer output def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds): for extension, _ in iterator(): @@ -95,13 +104,24 @@ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_e return prompt, input_ids, input_embeds +# Custom generate reply handling +def _apply_custom_generate_reply(): + for extension, _ in iterator(): + if hasattr(extension, 'custom_generate_reply'): + return getattr(extension, 'custom_generate_reply') + + return None + + EXTENSION_MAP = { "input": partial(_apply_string_extensions, "input_modifier"), "output": partial(_apply_string_extensions, "output_modifier"), + "state": _apply_state_modifier_extensions, "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"), "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"), "input_hijack": _apply_input_hijack, - "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt + "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt, + "custom_generate_reply": _apply_custom_generate_reply } diff --git a/modules/text_generation.py b/modules/text_generation.py index ae6cf8be..ba3f16b9 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -21,6 +21,7 @@ def get_max_prompt_length(state): max_length = state['truncation_length'] - state['max_new_tokens'] if shared.soft_prompt: max_length -= shared.soft_prompt_tensor.shape[1] + return max_length @@ -62,6 +63,36 @@ def decode(output_ids, skip_special_tokens=True): return shared.tokenizer.decode(output_ids, skip_special_tokens) +def 
generate_softprompt_input_tensors(input_ids): + inputs_embeds = shared.model.transformer.wte(input_ids) + inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1) + filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device) + # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens + return inputs_embeds, filler_input_ids + + +# Removes empty replies from gpt4chan outputs +def fix_gpt4chan(s): + for i in range(10): + s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) + s = re.sub("--- [0-9]*\n *\n---", "---", s) + s = re.sub("--- [0-9]*\n\n\n---", "---", s) + + return s + + +# Fix the LaTeX equations in galactica +def fix_galactica(s): + s = s.replace(r'\[', r'$') + s = s.replace(r'\]', r'$') + s = s.replace(r'\(', r'$') + s = s.replace(r'\)', r'$') + s = s.replace(r'$$', r'$') + s = re.sub(r'\n', r'\n\n', s) + s = re.sub(r"\n{3,}", "\n\n", s) + return s + + def get_reply_from_output_ids(output_ids, input_ids, original_question, state): if shared.model_type == 'HF_seq2seq': reply = decode(output_ids, state['skip_special_tokens']) @@ -81,35 +112,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state): return reply -def generate_softprompt_input_tensors(input_ids): - inputs_embeds = shared.model.transformer.wte(input_ids) - inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1) - filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device) - # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens - return inputs_embeds, filler_input_ids - - -# Removes empty replies from gpt4chan outputs -def fix_gpt4chan(s): - for i in range(10): - s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) - s = re.sub("--- [0-9]*\n *\n---", "---", s) - s = re.sub("--- [0-9]*\n\n\n---", "---", s) - return s - - -# Fix the LaTeX equations in galactica -def fix_galactica(s): - s = s.replace(r'\[', r'$') - s = s.replace(r'\]', r'$') - s = s.replace(r'\(', r'$') - s = s.replace(r'\)', r'$') - s = s.replace(r'$$', r'$') - s = re.sub(r'\n', r'\n\n', s) - s = re.sub(r"\n{3,}", "\n\n", s) - return s - - def formatted_outputs(reply, model_name): if not shared.is_chat(): if shared.model_type == 'galactica': @@ -140,51 +142,21 @@ def stop_everything_event(): shared.stop_everything = True -def get_generate_params(state): - generate_params = {} - - # Models that are not on transformers - if shared.model_type in ['rwkv', 'llamacpp']: - generate_params['token_count'] = state['max_new_tokens'] - for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']: - generate_params[k] = state[k] - else: - # FlexGen - if shared.args.flexgen: - for k in ['max_new_tokens', 'do_sample', 'temperature']: - generate_params[k] = state[k] - - if not shared.args.no_stream: - generate_params['max_new_tokens'] = 8 - - # transformers - else: - for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']: - generate_params[k] = state[k] - - if state['ban_eos_token']: - generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id] - - if shared.args.no_cache: - generate_params.update({'use_cache': False}) - - if shared.args.deepspeed: - generate_params.update({'synced_gpus': True}) - - return generate_params - - def 
generate_reply(question, state, eos_token=None, stopping_strings=[]): - if shared.model_name == 'None' or shared.model is None: - logging.error("No model is loaded! Select one in the Model tab.") - yield formatted_outputs(question, shared.model_name) - return + state = apply_extensions('state', state) + generate_func = apply_extensions('custom_generate_reply') + if generate_func is None: + if shared.model_name == 'None' or shared.model is None: + logging.error("No model is loaded! Select one in the Model tab.") + yield formatted_outputs(question, shared.model_name) + return - clear_torch_cache() - seed = set_manual_seed(state['seed']) - shared.stop_everything = False - generate_params = get_generate_params(state) - t0 = time.time() + if shared.model_type in ['rwkv', 'llamacpp']: + generate_func = generate_reply_custom + elif shared.args.flexgen: + generate_func = generate_reply_flexgen + else: + generate_func = generate_reply_HF # Preparing the input original_question = question @@ -194,42 +166,31 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): if shared.args.verbose: print(f'\n\n{question}\n--------------------\n') - # If the model is not on transformers, handle it separately and end this - # function call earlier. - if shared.model_type in ['rwkv', 'llamacpp']: + shared.stop_everything = False + clear_torch_cache() + seed = set_manual_seed(state['seed']) + for reply in generate_func(question, original_question, seed, state, eos_token, stopping_strings): + yield formatted_outputs(reply, shared.model_name) - try: - if shared.args.no_stream: - reply = shared.model.generate(context=question, **generate_params) - output = original_question + reply - if not shared.is_chat(): - reply = original_question + apply_extensions('output', reply) - yield formatted_outputs(reply, shared.model_name) - else: - if not shared.is_chat(): - yield formatted_outputs(question, shared.model_name) +def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=[]): + generate_params = {} + for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']: + generate_params[k] = state[k] - for reply in shared.model.generate_with_streaming(context=question, **generate_params): - output = original_question + reply - if not shared.is_chat(): - reply = original_question + apply_extensions('output', reply) + if state['ban_eos_token']: + generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id] - yield formatted_outputs(reply, shared.model_name) + if shared.args.no_cache: + generate_params.update({'use_cache': False}) - except Exception: - traceback.print_exc() - finally: - t1 = time.time() - original_tokens = len(encode(original_question)[0]) - new_tokens = len(encode(output)[0]) - original_tokens - print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') - return + if shared.args.deepspeed: + generate_params.update({'synced_gpus': True}) # Encode the input input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state)) output = input_ids[0] - cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen)) + cuda = not any((shared.args.cpu, shared.args.deepspeed)) # Find the eos tokens eos_token_ids = 
[shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] @@ -259,15 +220,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): break # Update generate_params with the eos token and the stopping strings - if shared.args.flexgen: - generate_params['stop'] = eos_token_ids[-1] - else: - generate_params['eos_token_id'] = eos_token_ids - generate_params['stopping_criteria'] = stopping_criteria_list + generate_params['eos_token_id'] = eos_token_ids + generate_params['stopping_criteria'] = stopping_criteria_list + t0 = time.time() try: + if not shared.is_chat() and shared.model_type != 'HF_seq2seq': + yield original_question + # Generate the entire reply at once. - if shared.args.no_stream: + if not state['stream']: with torch.no_grad(): output = shared.model.generate(**generate_params)[0] if cuda: @@ -276,12 +238,11 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - reply = get_reply_from_output_ids(output, input_ids, original_question, state) - yield formatted_outputs(reply, shared.model_name) + yield get_reply_from_output_ids(output, input_ids, original_question, state) # Stream the reply 1 token at a time. # This is based on the trick of using 'stopping_criteria' to create an iterator. - elif not shared.args.flexgen: + else: def generate_with_callback(callback=None, **kwargs): kwargs['stopping_criteria'].append(Stream(callback_func=callback)) @@ -292,45 +253,118 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): def generate_with_streaming(**kwargs): return Iteratorize(generate_with_callback, kwargs, callback=None) - if not shared.is_chat() and shared.model_type != 'HF_seq2seq': - yield formatted_outputs(original_question, shared.model_name) - with generate_with_streaming(**generate_params) as generator: for output in generator: if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - reply = get_reply_from_output_ids(output, input_ids, original_question, state) + yield get_reply_from_output_ids(output, input_ids, original_question, state) if output[-1] in eos_token_ids: break - yield formatted_outputs(reply, shared.model_name) - - # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' - else: - for i in range(state['max_new_tokens'] // 8 + 1): - clear_torch_cache() - with torch.no_grad(): - output = shared.model.generate(**generate_params)[0] - - if shared.soft_prompt: - output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - - reply = get_reply_from_output_ids(output, input_ids, original_question, state) - if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)): - break - - yield formatted_outputs(reply, shared.model_name) - input_ids = np.reshape(output, (1, output.shape[0])) - if shared.soft_prompt: - inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) - generate_params.update({'inputs_embeds': inputs_embeds}) - generate_params.update({'inputs': filler_input_ids}) - else: - generate_params.update({'inputs': input_ids}) - - yield formatted_outputs(reply, shared.model_name) - + except Exception: + traceback.print_exc() + finally: + t1 = time.time() + original_tokens = len(original_input_ids[0]) + new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0) + print(f'Output generated in {(t1-t0):.2f} 
seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') + return + + +def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=[]): + seed = set_manual_seed(state['seed']) + generate_params = {'token_count': state['max_new_tokens']} + for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']: + generate_params[k] = state[k] + + t0 = time.time() + try: + if not shared.is_chat(): + yield question + + if not state['stream']: + reply = shared.model.generate(context=question, **generate_params) + output = original_question + reply + if not shared.is_chat(): + reply = original_question + apply_extensions('output', reply) + + yield reply + else: + + for reply in shared.model.generate_with_streaming(context=question, **generate_params): + output = original_question + reply + if not shared.is_chat(): + reply = original_question + apply_extensions('output', reply) + + yield reply + + except Exception: + traceback.print_exc() + finally: + t1 = time.time() + original_tokens = len(encode(original_question)[0]) + new_tokens = len(encode(output)[0]) - original_tokens + print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') + return + + +def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=[]): + generate_params = {} + for k in ['max_new_tokens', 'do_sample', 'temperature']: + generate_params[k] = state[k] + + if state['stream']: + generate_params['max_new_tokens'] = 8 + + # Encode the input + input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state)) + output = input_ids[0] + + # Find the eos tokens + eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] + if eos_token is not None: + eos_token_ids.append(int(encode(eos_token)[0][-1])) + + # Add the encoded tokens to generate_params + question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None) + original_input_ids = input_ids + generate_params.update({'inputs': input_ids}) + if inputs_embeds is not None: + generate_params.update({'inputs_embeds': inputs_embeds}) + + # Update generate_params with the eos token and the stopping strings + generate_params['stop'] = eos_token_ids[-1] + + t0 = time.time() + try: + if not shared.is_chat(): + yield question + + # Generate the entire reply at once. 
+ if not state['stream']: + with torch.no_grad(): + output = shared.model.generate(**generate_params)[0] + + yield get_reply_from_output_ids(output, input_ids, original_question, state) + + # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' + else: + for i in range(state['max_new_tokens'] // 8 + 1): + if shared.stop_everything: + break + + clear_torch_cache() + with torch.no_grad(): + output = shared.model.generate(**generate_params)[0] + + if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)): + break + + yield get_reply_from_output_ids(output, original_input_ids, original_question, state) + input_ids = np.reshape(output, (1, output.shape[0])) + generate_params.update({'inputs': input_ids}) + except Exception: traceback.print_exc() finally: diff --git a/modules/ui.py b/modules/ui.py index 53f6a247..28ff2f53 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -34,7 +34,7 @@ def list_model_elements(): def list_interface_input_elements(chat=False): - elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu'] + elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream'] if chat: elements += ['name1', 'name2', 'greeting', 'context', 'turn_template', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu'] diff --git a/server.py b/server.py index 6aa69c90..603c82fc 100644 --- a/server.py +++ b/server.py @@ -15,6 +15,7 @@ def my_get(url, **kwargs): kwargs.setdefault('allow_redirects', True) return requests.api.request('get', 'http://127.0.0.1/', **kwargs) + original_get = requests.get requests.get = my_get import gradio as gr @@ -454,6 +455,7 @@ def create_settings_menus(default_preset): shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.') shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') + shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming') with gr.Accordion('Soft prompt', open=False): with gr.Row(): @@ -721,46 +723,46 @@ def create_interface(): gen_events.append(shared.gradio['Generate'].click( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then( - chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( + chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) ) 
gen_events.append(shared.gradio['textbox'].submit( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then( - chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( + chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) ) gen_events.append(shared.gradio['Regenerate'].click( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( + chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) ) gen_events.append(shared.gradio['Continue'].click( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( + chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) ) gen_events.append(shared.gradio['Impersonate'].click( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream) + chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False) ) shared.gradio['Replace last reply'].click( - chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then( + chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then( lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) shared.gradio['Send dummy message'].click( - chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then( + chat.send_dummy_message, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then( lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) shared.gradio['Send dummy reply'].click( - chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then( + chat.send_dummy_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=False).then( lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then( chat.save_history, shared.gradio['mode'], None, show_progress=False) @@ -786,7 +788,7 @@ def create_interface(): chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 
'name2']], None).then( chat.redraw_html, reload_inputs, shared.gradio['display']) - shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream) + shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=False) shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) @@ -808,14 +810,14 @@ def create_interface(): gen_events.append(shared.gradio['Generate'].click( lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then( + generate_reply, shared.input_params, output_params, show_progress=False) # .then( # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") ) gen_events.append(shared.gradio['textbox'].submit( lambda x: x, shared.gradio['textbox'], shared.gradio['last_input']).then( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then( + generate_reply, shared.input_params, output_params, show_progress=False) # .then( # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") ) @@ -824,13 +826,13 @@ def create_interface(): gen_events.append(shared.gradio['Regenerate'].click( lambda x: x, shared.gradio['last_input'], shared.gradio['textbox'], show_progress=False).then( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then( + generate_reply, shared.input_params, output_params, show_progress=False) # .then( # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}") ) else: gen_events.append(shared.gradio['Continue'].click( ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then( - generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream) # .then( + generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=False) # .then( # None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}") )
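
For reference, a minimal sketch of an extension using the new `state_modifier` hook introduced by this patch, in the same spirit as the bundled `silero_tts`, `elevenlabs_tts`, and `sd_api_pictures` extensions. The file path `extensions/my_extension/script.py` and the `params['max_tokens']` setting are illustrative assumptions, not part of the patch:

```python
# Hypothetical extensions/my_extension/script.py -- illustrative only.
# state_modifier() receives the dictionary of generation parameters before
# generation starts (via apply_extensions('state', state) in generate_reply)
# and must return it, possibly modified.

params = {
    'max_tokens': 200,  # assumed extension-specific setting, not a web UI parameter
}

def state_modifier(state):
    # Disable streaming for this extension, as the TTS extensions in this patch do.
    state['stream'] = False

    # Cap the number of new tokens using the standard 'max_new_tokens' key.
    state['max_new_tokens'] = min(state['max_new_tokens'], params['max_tokens'])
    return state
```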