From ea6e77df726bc57d60cd77be399dbfd50c25345e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 7 Apr 2023 00:15:45 -0300 Subject: [PATCH] Make the code more like PEP8 for readability (#862) --- api-example-stream.py | 6 ++- convert-to-flexgen.py | 9 ++-- convert-to-safetensors.py | 2 +- download-model.py | 16 ++++-- extensions/api/script.py | 28 +++++----- extensions/character_bias/script.py | 8 ++- extensions/elevenlabs_tts/script.py | 34 ++++++++----- extensions/gallery/script.py | 4 +- extensions/google_translate/script.py | 6 ++- extensions/llama_prompts/script.py | 2 + extensions/sd_api_pictures/script.py | 38 +++++++++----- extensions/send_pictures/script.py | 5 +- modules/GPTQ_loader.py | 19 ++++--- modules/LoRA.py | 5 +- modules/RWKV.py | 17 ++++--- modules/api.py | 1 + modules/callbacks.py | 9 ++-- modules/chat.py | 57 ++++++++++++++------- modules/extensions.py | 8 ++- modules/html_generator.py | 26 +++++++--- modules/llamacpp_model.py | 4 +- modules/llamacpp_model_alternative.py | 6 +-- modules/models.py | 21 ++++---- modules/shared.py | 5 +- modules/text_generation.py | 26 +++++++--- modules/training.py | 30 +++++++---- modules/ui.py | 2 + server.py | 73 +++++++++++++++++---------- 28 files changed, 302 insertions(+), 165 deletions(-) diff --git a/api-example-stream.py b/api-example-stream.py index 32eefc7e..17de4c28 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -17,6 +17,7 @@ def random_hash(): letters = string.ascii_lowercase + string.digits return ''.join(random.choice(letters) for i in range(9)) + async def run(context): server = "127.0.0.1" params = { @@ -41,7 +42,7 @@ async def run(context): async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: while content := json.loads(await websocket.recv()): - #Python3.10 syntax, replace with if elif on older + # Python3.10 syntax, replace with if elif on older match content["msg"]: case "send_hash": await websocket.send(json.dumps({ @@ -62,13 +63,14 @@ async def run(context): pass case "process_generating" | "process_completed": yield content["output"]["data"][0] - # You can search for your desired end indicator and + # You can search for your desired end indicator and # stop generation by closing the websocket here if (content["msg"] == "process_completed"): break prompt = "What I would like to say is the following: " + async def get_result(): async for response in run(prompt): # Print intermediate steps diff --git a/convert-to-flexgen.py b/convert-to-flexgen.py index 917f023c..7654593b 100644 --- a/convert-to-flexgen.py +++ b/convert-to-flexgen.py @@ -13,10 +13,11 @@ import torch from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer -parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) +parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54)) parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") args = parser.parse_args() + def disable_torch_init(): """ Disable the redundant torch default initialization to accelerate model creation. @@ -31,20 +32,22 @@ def disable_torch_init(): torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + def restore_torch_init(): """Rollback the change made by disable_torch_init.""" import torch setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup) setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup) + if __name__ == '__main__': path = Path(args.MODEL) model_name = path.name print(f"Loading {model_name}...") - #disable_torch_init() + # disable_torch_init() model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True) - #restore_torch_init() + # restore_torch_init() tokenizer = AutoTokenizer.from_pretrained(path) diff --git a/convert-to-safetensors.py b/convert-to-safetensors.py index 63baaa97..3b721e7c 100644 --- a/convert-to-safetensors.py +++ b/convert-to-safetensors.py @@ -17,7 +17,7 @@ from pathlib import Path import torch from transformers import AutoModelForCausalLM, AutoTokenizer -parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) +parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54)) parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).') parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).") diff --git a/download-model.py b/download-model.py index db95c4b5..38e5f452 100644 --- a/download-model.py +++ b/download-model.py @@ -29,6 +29,7 @@ parser.add_argument('--clean', action='store_true', help='Does not resume the pr parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') args = parser.parse_args() + def get_file(url, output_folder): filename = Path(url.rsplit('/', 1)[1]) output_path = output_folder / filename @@ -54,6 +55,7 @@ def get_file(url, output_folder): t.update(len(data)) f.write(data) + def sanitize_branch_name(branch_name): pattern = re.compile(r"^[a-zA-Z0-9._-]+$") if pattern.match(branch_name): @@ -61,6 +63,7 @@ def sanitize_branch_name(branch_name): else: raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") + def select_model_from_default_options(): models = { "OPT 6.7B": ("facebook", "opt-6.7b", "main"), @@ -78,11 +81,11 @@ def select_model_from_default_options(): choices = {} print("Select the model that you want to download:\n") - for i,name in enumerate(models): - char = chr(ord('A')+i) + for i, name in enumerate(models): + char = chr(ord('A') + i) choices[char] = name print(f"{char}) {name}") - char = chr(ord('A')+len(models)) + char = chr(ord('A') + len(models)) print(f"{char}) None of the above") print() @@ -106,6 +109,7 @@ EleutherAI/pythia-1.4b-deduped return model, branch + def get_download_links_from_huggingface(model, branch): base = "https://huggingface.co" page = f"/api/models/{model}/tree/{branch}?cursor=" @@ -166,15 +170,17 @@ def get_download_links_from_huggingface(model, branch): # If both pytorch and safetensors are available, download safetensors only if (has_pytorch or has_pt) and has_safetensors: - for i in range(len(classifications)-1, -1, -1): + for i in range(len(classifications) - 1, -1, -1): if classifications[i] in ['pytorch', 'pt']: links.pop(i) return links, sha256, is_lora + def download_files(file_list, output_folder, num_threads=8): thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True) + if __name__ == '__main__': model = args.MODEL branch = args.branch @@ -224,7 +230,7 @@ if __name__ == '__main__': validated = False else: print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}') - + if validated: print('[+] Validated checksums of all model files!') else: diff --git a/extensions/api/script.py b/extensions/api/script.py index 6726d61d..4981725f 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -9,6 +9,7 @@ params = { 'port': 5000, } + class Handler(BaseHTTPRequestHandler): def do_GET(self): if self.path == '/api/v1/model': @@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler): self.end_headers() prompt = body['prompt'] - prompt_lines = [l.strip() for l in prompt.split('\n')] + prompt_lines = [k.strip() for k in prompt.split('\n')] max_context = body.get('max_context_length', 2048) @@ -40,18 +41,18 @@ class Handler(BaseHTTPRequestHandler): prompt_lines.pop(0) prompt = '\n'.join(prompt_lines) - generate_params = { - 'max_new_tokens': int(body.get('max_length', 200)), + generate_params = { + 'max_new_tokens': int(body.get('max_length', 200)), 'do_sample': bool(body.get('do_sample', True)), - 'temperature': float(body.get('temperature', 0.5)), - 'top_p': float(body.get('top_p', 1)), - 'typical_p': float(body.get('typical', 1)), - 'repetition_penalty': float(body.get('rep_pen', 1.1)), + 'temperature': float(body.get('temperature', 0.5)), + 'top_p': float(body.get('top_p', 1)), + 'typical_p': float(body.get('typical', 1)), + 'repetition_penalty': float(body.get('rep_pen', 1.1)), 'encoder_repetition_penalty': 1, - 'top_k': int(body.get('top_k', 0)), + 'top_k': int(body.get('top_k', 0)), 'min_length': int(body.get('min_length', 0)), - 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)), - 'num_beams': int(body.get('num_beams',1)), + 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)), + 'num_beams': int(body.get('num_beams', 1)), 'penalty_alpha': float(body.get('penalty_alpha', 0)), 'length_penalty': float(body.get('length_penalty', 1)), 'early_stopping': bool(body.get('early_stopping', False)), @@ -59,7 +60,7 @@ class Handler(BaseHTTPRequestHandler): } generator = generate_reply( - prompt, + prompt, generate_params, stopping_strings=body.get('stopping_strings', []), ) @@ -84,9 +85,9 @@ class Handler(BaseHTTPRequestHandler): def run_server(): server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port']) server = ThreadingHTTPServer(server_addr, Handler) - if shared.args.share: + if shared.args.share: try: - from flask_cloudflared import _run_cloudflared + from flask_cloudflared import _run_cloudflared public_url = _run_cloudflared(params['port'], params['port'] + 1) print(f'Starting KoboldAI compatible api at {public_url}/api') except ImportError: @@ -95,5 +96,6 @@ def run_server(): print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api') server.serve_forever() + def setup(): Thread(target=run_server, daemon=True).start() diff --git a/extensions/character_bias/script.py b/extensions/character_bias/script.py index 35b38c0e..a92d0aef 100644 --- a/extensions/character_bias/script.py +++ b/extensions/character_bias/script.py @@ -5,14 +5,16 @@ params = { "bias string": " *I am so happy*", } + def input_modifier(string): """ This function is applied to your text inputs before they are fed into the model. - """ + """ return string + def output_modifier(string): """ This function is applied to the model outputs. @@ -20,6 +22,7 @@ def output_modifier(string): return string + def bot_prefix_modifier(string): """ This function is only applied in chat mode. It modifies @@ -27,11 +30,12 @@ def bot_prefix_modifier(string): behavior. """ - if params['activate'] == True: + if params['activate']: return f'{string} {params["bias string"].strip()} ' else: return string + def ui(): # Gradio elements activate = gr.Checkbox(value=params['activate'], label='Activate character bias') diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py index 2e8b184f..772952e1 100644 --- a/extensions/elevenlabs_tts/script.py +++ b/extensions/elevenlabs_tts/script.py @@ -20,16 +20,18 @@ user_info = None if not shared.args.no_stream: print("Please add --no-stream. This extension is not meant to be used with streaming.") raise ValueError - + # Check if the API is valid and refresh the UI accordingly. + + def check_valid_api(): - + global user, user_info, params user = ElevenLabsUser(params['api_key']) user_info = user._get_subscription_data() print('checking api') - if params['activate'] == False: + if not params['activate']: return gr.update(value='Disconnected') elif user_info is None: print('Incorrect API Key') @@ -37,24 +39,28 @@ def check_valid_api(): else: print('Got an API Key!') return gr.update(value='Connected') - + # Once the API is verified, get the available voices and update the dropdown list + + def refresh_voices(): - + global user, user_info - + your_voices = [None] if user_info is not None: for voice in user.get_available_voices(): your_voices.append(voice.initialName) - return gr.Dropdown.update(choices=your_voices) + return gr.Dropdown.update(choices=your_voices) else: return + def remove_surrounded_chars(string): # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' - return re.sub('\*[^\*]*?(\*|$)','',string) + return re.sub('\*[^\*]*?(\*|$)', '', string) + def input_modifier(string): """ @@ -64,16 +70,17 @@ def input_modifier(string): return string + def output_modifier(string): """ This function is applied to the model outputs. """ global params, wav_idx, user, user_info - - if params['activate'] == False: + + if not params['activate']: return string - elif user_info == None: + elif user_info is None: return string string = remove_surrounded_chars(string) @@ -84,7 +91,7 @@ def output_modifier(string): if string == '': string = 'empty reply, try regenerating' - + output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'.format(wav_idx)) voice = user.get_voices_by_name(params['selected_voice'])[0] audio_data = voice.generate_audio_bytes(string) @@ -94,6 +101,7 @@ def output_modifier(string): wav_idx += 1 return string + def ui(): # Gradio elements @@ -110,4 +118,4 @@ def ui(): voice.change(lambda x: params.update({'selected_voice': x}), voice, None) api_key.change(lambda x: params.update({'api_key': x}), api_key, None) connect.click(check_valid_api, [], connection_status) - connect.click(refresh_voices, [], voice) \ No newline at end of file + connect.click(refresh_voices, [], voice) diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py index 5c47f0f1..8ffe8906 100644 --- a/extensions/gallery/script.py +++ b/extensions/gallery/script.py @@ -85,7 +85,7 @@ def select_character(evt: gr.SelectData): def ui(): with gr.Accordion("Character gallery", open=False): update = gr.Button("Refresh") - gr.HTML(value="") + gr.HTML(value="") gallery = gr.Dataset(components=[gr.HTML(visible=False)], label="", samples=generate_html(), @@ -93,4 +93,4 @@ def ui(): samples_per_page=50 ) update.click(generate_html, [], gallery) - gallery.select(select_character, None, gradio['character_menu']) \ No newline at end of file + gallery.select(select_character, None, gradio['character_menu']) diff --git a/extensions/google_translate/script.py b/extensions/google_translate/script.py index 68bc54b2..63226107 100644 --- a/extensions/google_translate/script.py +++ b/extensions/google_translate/script.py @@ -7,14 +7,16 @@ params = { language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'} + def input_modifier(string): """ This function is applied to your text inputs before they are fed into the model. - """ + """ return GoogleTranslator(source=params['language string'], target='en').translate(string) + def output_modifier(string): """ This function is applied to the model outputs. @@ -22,6 +24,7 @@ def output_modifier(string): return GoogleTranslator(source='en', target=params['language string']).translate(string) + def bot_prefix_modifier(string): """ This function is only applied in chat mode. It modifies @@ -31,6 +34,7 @@ def bot_prefix_modifier(string): return string + def ui(): # Finding the language name from the language code to use as the default value language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])] diff --git a/extensions/llama_prompts/script.py b/extensions/llama_prompts/script.py index e40ac5c0..da2196ed 100644 --- a/extensions/llama_prompts/script.py +++ b/extensions/llama_prompts/script.py @@ -4,12 +4,14 @@ import pandas as pd df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv") + def get_prompt_by_name(name): if name == 'None': return '' else: return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n') + def ui(): if not shared.is_chat(): choices = ['None'] + list(df['Prompt name']) diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py index df07ef2d..0c85c176 100644 --- a/extensions/sd_api_pictures/script.py +++ b/extensions/sd_api_pictures/script.py @@ -12,30 +12,33 @@ from PIL import Image torch._C._jit_set_profiling_mode(False) -# parameters which can be customized in settings.json of webui +# parameters which can be customized in settings.json of webui params = { 'enable_SD_api': False, 'address': 'http://127.0.0.1:7860', 'save_img': False, - 'SD_model': 'NeverEndingDream', # not really used right now + 'SD_model': 'NeverEndingDream', # not really used right now 'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful', 'negative_prompt': '(worst quality, low quality:1.3)', 'side_length': 512, 'restore_faces': False } -SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select +SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select -streaming_state = shared.args.no_stream # remember if chat streaming was enabled -picture_response = False # specifies if the next model response should appear as a picture +streaming_state = shared.args.no_stream # remember if chat streaming was enabled +picture_response = False # specifies if the next model response should appear as a picture pic_id = 0 + def remove_surrounded_chars(string): # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' - return re.sub('\*[^\*]*?(\*|$)','',string) + return re.sub('\*[^\*]*?(\*|$)', '', string) # I don't even need input_hijack for this as visible text will be commited to history as the unmodified string + + def input_modifier(string): """ This function is applied to your text inputs before @@ -51,7 +54,7 @@ def input_modifier(string): lowstr = string.lower() # TODO: refactor out to separate handler and also replace detection with a regexp - if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found + if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found picture_response = True shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud shared.processing_message = "*Is sending a picture...*" @@ -62,6 +65,8 @@ def input_modifier(string): return string # Get and save the Stable Diffusion-generated picture + + def get_SD_pictures(description): global params, pic_id @@ -77,13 +82,13 @@ def get_SD_pictures(description): "restore_faces": params['restore_faces'], "negative_prompt": params['negative_prompt'] } - + response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload) r = response.json() visible_result = "" for img_str in r['images']: - image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",",1)[0]))) + image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0]))) if params['save_img']: output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png') image.save(output_file.as_posix()) @@ -96,11 +101,13 @@ def get_SD_pictures(description): image_bytes = buffered.getvalue() img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode() visible_result = visible_result + f'{description}\n' - + return visible_result # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) # and replace it with 'text' for the purposes of logging? + + def output_modifier(string): """ This function is applied to the model outputs. @@ -130,6 +137,7 @@ def output_modifier(string): shared.args.no_stream = streaming_state return image + "\n" + text + def bot_prefix_modifier(string): """ This function is only applied in chat mode. It modifies @@ -139,10 +147,12 @@ def bot_prefix_modifier(string): return string + def force_pic(): global picture_response picture_response = True + def ui(): # Gradio elements @@ -153,7 +163,7 @@ def ui(): save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir') with gr.Column(): address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address') - + with gr.Row(): force_btn = gr.Button("Force the next response to be a picture") generate_now_btn = gr.Button("Generate an image response to the input") @@ -162,9 +172,9 @@ def ui(): prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)') with gr.Row(): negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt') - dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions') + dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions') # model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model') - + # Event functions to update the parameters in the backend enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None) save_img.change(lambda x: params.update({"save_img": x}), save_img, None) @@ -176,4 +186,4 @@ def ui(): force_btn.click(force_pic) generate_now_btn.click(force_pic) - generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) \ No newline at end of file + generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index d2401dff..678592f5 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -17,11 +17,13 @@ input_hijack = { processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") + def caption_image(raw_image): inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32) out = model.generate(**inputs, max_new_tokens=100) return processor.decode(out[0], skip_special_tokens=True) + def generate_chat_picture(picture, name1, name2): text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*' # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history @@ -32,6 +34,7 @@ def generate_chat_picture(picture, name1, name2): visible_text = f'{text}' return text, visible_text + def ui(): picture_select = gr.Image(label='Send a picture', type='pil') @@ -42,4 +45,4 @@ def ui(): picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) # Clear the picture from the upload field - picture_select.upload(lambda : None, [], [picture_select], show_progress=False) + picture_select.upload(lambda: None, [], [picture_select], show_progress=False) diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index 572947ad..3f42e5c6 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -17,9 +17,11 @@ from quant import make_quant def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128): - config = AutoConfig.from_pretrained(model) + def noop(*args, **kwargs): pass + + config = AutoConfig.from_pretrained(model) torch.nn.init.kaiming_uniform_ = noop torch.nn.init.uniform_ = noop torch.nn.init.normal_ = noop @@ -34,11 +36,11 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc for name in exclude_layers: if name in layers: del layers[name] - + gptq_args = inspect.getfullargspec(make_quant).args make_quant_kwargs = { - 'module': model, + 'module': model, 'names': layers, 'bits': wbits, } @@ -48,7 +50,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc make_quant_kwargs['faster'] = faster_kernel if 'kernel_switch_threshold' in gptq_args: make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold - + make_quant(**make_quant_kwargs) del layers @@ -56,14 +58,15 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc print('Loading model ...') if checkpoint.endswith('.safetensors'): from safetensors.torch import load_file as safe_load - model.load_state_dict(safe_load(checkpoint), strict = False) + model.load_state_dict(safe_load(checkpoint), strict=False) else: - model.load_state_dict(torch.load(checkpoint), strict = False) + model.load_state_dict(torch.load(checkpoint), strict=False) model.seqlen = 2048 print('Done.') return model + def load_quantized(model_name): if not shared.args.model_type: # Try to determine model type from model name @@ -114,7 +117,7 @@ def load_quantized(model_name): pt_model = f'{model_name}-{shared.args.wbits}bit' # Try to find the .safetensors or .pt both in the model dir and in the subfolder - for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]: + for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]: if path.exists(): print(f"Found {path}") pt_path = path @@ -133,7 +136,7 @@ def load_quantized(model_name): # accelerate offload (doesn't work properly) if shared.args.gpu_memory: - memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory)) + memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory)) max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB' max_memory = {} for i in range(len(memory_map)): diff --git a/modules/LoRA.py b/modules/LoRA.py index 8c30e609..17dd7229 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -13,6 +13,7 @@ def reload_model(): clear_torch_cache() shared.model, shared.tokenizer = load_model(shared.model_name) + def add_lora_to_model(lora_name): # If a LoRA had been previously loaded, or if we want @@ -27,10 +28,10 @@ def add_lora_to_model(lora_name): if not shared.args.cpu: params['dtype'] = shared.model.dtype if hasattr(shared.model, "hf_device_map"): - params['device_map'] = {"base_model.model."+k: v for k, v in shared.model.hf_device_map.items()} + params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()} elif shared.args.load_in_8bit: params['device_map'] = {'': 0} - + shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params) if not shared.args.load_in_8bit and not shared.args.cpu: shared.model.half() diff --git a/modules/RWKV.py b/modules/RWKV.py index 10c4c366..0405230e 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -10,7 +10,7 @@ from modules.callbacks import Iteratorize np.set_printoptions(precision=4, suppress=True, linewidth=200) os.environ['RWKV_JIT_ON'] = '1' -os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster) +os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster) from rwkv.model import RWKV from rwkv.utils import PIPELINE, PIPELINE_ARGS @@ -36,13 +36,13 @@ class RWKVModel: def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None): args = PIPELINE_ARGS( - temperature = temperature, - top_p = top_p, - top_k = top_k, - alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3) - alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3) - token_ban = token_ban, # ban the generation of some tokens - token_stop = token_stop + temperature=temperature, + top_p=top_p, + top_k=top_k, + alpha_frequency=alpha_frequency, # Frequency Penalty (as in GPT-3) + alpha_presence=alpha_presence, # Presence Penalty (as in GPT-3) + token_ban=token_ban, # ban the generation of some tokens + token_stop=token_stop ) return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) @@ -54,6 +54,7 @@ class RWKVModel: reply += token yield reply + class RWKVTokenizer: def __init__(self): pass diff --git a/modules/api.py b/modules/api.py index 26249fd7..f18ad4cf 100644 --- a/modules/api.py +++ b/modules/api.py @@ -28,6 +28,7 @@ def generate_reply_wrapper(string): for i in generate_reply(params[0], generate_params): yield i + def create_apis(): t1 = gr.Textbox(visible=False) t2 = gr.Textbox(visible=False) diff --git a/modules/callbacks.py b/modules/callbacks.py index 945b8c37..51ecbdd7 100644 --- a/modules/callbacks.py +++ b/modules/callbacks.py @@ -30,6 +30,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): return True return False + class Stream(transformers.StoppingCriteria): def __init__(self, callback_func=None): self.callback_func = callback_func @@ -39,6 +40,7 @@ class Stream(transformers.StoppingCriteria): self.callback_func(input_ids[0]) return False + class Iteratorize: """ @@ -47,8 +49,8 @@ class Iteratorize: """ def __init__(self, func, kwargs={}, callback=None): - self.mfunc=func - self.c_callback=callback + self.mfunc = func + self.c_callback = callback self.q = Queue() self.sentinel = object() self.kwargs = kwargs @@ -80,7 +82,7 @@ class Iteratorize: return self def __next__(self): - obj = self.q.get(True,None) + obj = self.q.get(True, None) if obj is self.sentinel: raise StopIteration else: @@ -96,6 +98,7 @@ class Iteratorize: self.stop_now = True clear_torch_cache() + def clear_torch_cache(): gc.collect() if not shared.args.cpu: diff --git a/modules/chat.py b/modules/chat.py index 36932641..e1867059 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -23,12 +23,11 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else '' impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False - rows = [f"{context.strip()}\n"] # Finding the maximum prompt size if shared.soft_prompt: - chat_prompt_size -= shared.soft_prompt_tensor.shape[1] + chat_prompt_size -= shared.soft_prompt_tensor.shape[1] max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size) if is_instruct: @@ -38,7 +37,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat prefix1 = f"{name1}: " prefix2 = f"{name2}: " - i = len(shared.history['internal'])-1 + i = len(shared.history['internal']) - 1 while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length: rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n") string = shared.history['internal'][i][0] @@ -68,6 +67,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat else: return prompt + def extract_message_from_reply(reply, name1, name2, stop_at_newline): next_character_found = False @@ -87,7 +87,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline): # is completed, trim it if not next_character_found: for string in [f"\n{name1}:", f"\n{name2}:"]: - for j in range(len(string)-1, 0, -1): + for j in range(len(string) - 1, 0, -1): if reply[-j:] == string[:j]: reply = reply[:-j] break @@ -98,12 +98,13 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline): reply = fix_newlines(reply) return reply, next_character_found + def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): if mode == 'instruct': stopping_strings = [f"\n{name1}", f"\n{name2}"] else: stopping_strings = [f"\n{name1}:", f"\n{name2}:"] - + eos_token = '\n' if generate_state['stop_at_newline'] else None name1_original = name1 if 'pygmalion' in shared.model_name.lower(): @@ -113,7 +114,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu visible_text = None custom_generate_chat_prompt = None for extension, _ in extensions_module.iterator(): - if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True: + if hasattr(extension, 'input_hijack') and extension.input_hijack['state']: extension.input_hijack['state'] = False text, visible_text = extension.input_hijack['value'] if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'): @@ -131,7 +132,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu # Yield *Is typing...* if not regenerate: - yield shared.history['visible']+[[visible_text, shared.processing_message]] + yield shared.history['visible'] + [[visible_text, shared.processing_message]] # Generate cumulative_reply = '' @@ -167,12 +168,13 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu yield shared.history['visible'] + def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): if mode == 'instruct': stopping_strings = [f"\n{name1}", f"\n{name2}"] else: stopping_strings = [f"\n{name1}:", f"\n{name2}:"] - + eos_token = '\n' if generate_state['stop_at_newline'] else None if 'pygmalion' in shared.model_name.lower(): name1 = "You" @@ -197,10 +199,12 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o yield reply + def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): yield chat_html_wrapper(history, name1, name2, mode) + def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) @@ -208,11 +212,12 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() # Yield '*Is typing...*' - yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode) + yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode) for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True): shared.history['visible'][-1] = [last_visible[0], history[-1][1]] yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) + def remove_last_message(name1, name2, mode): if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': last = shared.history['visible'].pop() @@ -222,12 +227,14 @@ def remove_last_message(name1, name2, mode): return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0] + def send_last_reply_to_input(): if len(shared.history['internal']) > 0: return shared.history['internal'][-1][1] else: return '' + def replace_last_reply(text, name1, name2, mode): if len(shared.history['visible']) > 0: shared.history['visible'][-1][1] = text @@ -235,9 +242,11 @@ def replace_last_reply(text, name1, name2, mode): return chat_html_wrapper(shared.history['visible'], name1, name2, mode) + def clear_html(): return chat_html_wrapper([], "", "") + def clear_chat_log(name1, name2, greeting, mode): shared.history['visible'] = [] shared.history['internal'] = [] @@ -248,9 +257,11 @@ def clear_chat_log(name1, name2, greeting, mode): return chat_html_wrapper(shared.history['visible'], name1, name2, mode) + def redraw_html(name1, name2, mode): return chat_html_wrapper(shared.history['visible'], name1, name2, mode) + def tokenize_dialogue(dialogue, name1, name2, mode): history = [] @@ -263,8 +274,8 @@ def tokenize_dialogue(dialogue, name1, name2, mode): return history messages = [] - for i in range(len(idx)-1): - messages.append(dialogue[idx[i]:idx[i+1]].strip()) + for i in range(len(idx) - 1): + messages.append(dialogue[idx[i]:idx[i + 1]].strip()) messages.append(dialogue[idx[-1]:].strip()) entry = ['', ''] @@ -282,12 +293,13 @@ def tokenize_dialogue(dialogue, name1, name2, mode): for column in row: print("\n") for line in column.strip().split('\n'): - print("| "+line+"\n") + print("| " + line + "\n") print("|\n") print("------------------------------") return history + def save_history(timestamp=True): if timestamp: fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" @@ -299,6 +311,7 @@ def save_history(timestamp=True): f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) return Path(f'logs/{fname}') + def load_history(file, name1, name2): file = file.decode('utf-8') try: @@ -313,20 +326,22 @@ def load_history(file, name1, name2): elif 'chat' in j: shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']] if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'): - shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)] + shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(1, len(shared.history['internal']) - 1, 2)] shared.history['visible'] = copy.deepcopy(shared.history['internal']) shared.history['visible'][0][0] = '' else: - shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)] + shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(0, len(shared.history['internal']) - 1, 2)] shared.history['visible'] = copy.deepcopy(shared.history['internal']) except: shared.history['internal'] = tokenize_dialogue(file, name1, name2) shared.history['visible'] = copy.deepcopy(shared.history['internal']) + def replace_character_names(text, name1, name2): text = text.replace('{{user}}', name1).replace('{{char}}', name2) return text.replace('', name1).replace('', name2) + def build_pygmalion_style_context(data): context = "" if 'char_persona' in data and data['char_persona'] != '': @@ -336,6 +351,7 @@ def build_pygmalion_style_context(data): context = f"{context.strip()}\n\n" return context + def generate_pfp_cache(character): cache_folder = Path("cache") if not cache_folder.exists(): @@ -348,6 +364,7 @@ def generate_pfp_cache(character): return img return None + def load_character(character, name1, name2, mode): shared.character = character shared.history['internal'] = [] @@ -387,13 +404,13 @@ def load_character(character, name1, name2, mode): if 'example_dialogue' in data: context += f"{data['example_dialogue'].strip()}\n" if greeting_field in data: - greeting = data[greeting_field] + greeting = data[greeting_field] if 'end_of_turn' in data: - end_of_turn = data['end_of_turn'] + end_of_turn = data['end_of_turn'] else: context = shared.settings['context'] name2 = shared.settings['name2'] - greeting = shared.settings['greeting'] + greeting = shared.settings['greeting'] end_of_turn = shared.settings['end_of_turn'] if Path(f'logs/{shared.character}_persistent.json').exists(): @@ -404,9 +421,11 @@ def load_character(character, name1, name2, mode): return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) + def load_default_history(name1, name2): load_character("None", name1, name2, "chat") + def upload_character(json_file, img, tavern=False): json_file = json_file if type(json_file) == str else json_file.decode('utf-8') data = json.loads(json_file) @@ -425,6 +444,7 @@ def upload_character(json_file, img, tavern=False): print(f'New character saved to "characters/{outfile_name}.json".') return outfile_name + def upload_tavern_character(img, name1, name2): _img = Image.open(io.BytesIO(img)) _img.getexif() @@ -433,12 +453,13 @@ def upload_tavern_character(img, name1, name2): _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} return upload_character(json.dumps(_json), img, tavern=True) + def upload_your_profile_picture(img, name1, name2, mode): cache_folder = Path("cache") if not cache_folder.exists(): cache_folder.mkdir() - if img == None: + if img is None: if Path("cache/pfp_me.png").exists(): Path("cache/pfp_me.png").unlink() else: diff --git a/modules/extensions.py b/modules/extensions.py index fe6a3945..8f435802 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -9,6 +9,7 @@ state = {} available_extensions = [] setup_called = set() + def load_extensions(): global state for i, name in enumerate(shared.args.extensions): @@ -23,12 +24,16 @@ def load_extensions(): traceback.print_exc() # This iterator returns the extensions in the order specified in the command-line + + def iterator(): - for name in sorted(state, key=lambda x : state[x][1]): + for name in sorted(state, key=lambda x: state[x][1]): if state[name][0] == True: yield eval(f"extensions.{name}.script"), name # Extension functions that map string -> string + + def apply_extensions(text, typ): for extension, _ in iterator(): if typ == "input" and hasattr(extension, "input_modifier"): @@ -39,6 +44,7 @@ def apply_extensions(text, typ): text = extension.bot_prefix_modifier(text) return text + def create_extensions_block(): global setup_called diff --git a/modules/html_generator.py b/modules/html_generator.py index 448c20c2..6e20566c 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -24,6 +24,7 @@ with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f: instruct_css = f.read() + def fix_newlines(string): string = string.replace('\n', '\n\n') string = re.sub(r"\n{3,}", "\n\n", string) @@ -31,6 +32,8 @@ def fix_newlines(string): return string # This could probably be generalized and improved + + def convert_to_markdown(string): string = string.replace('\\begin{code}', '```') string = string.replace('\\end{code}', '```') @@ -38,13 +41,15 @@ def convert_to_markdown(string): string = string.replace('\\end{blockquote}', '') string = re.sub(r"(.)```", r"\1\n```", string) string = fix_newlines(string) - return markdown.markdown(string, extensions=['fenced_code']) + return markdown.markdown(string, extensions=['fenced_code']) + def generate_basic_html(string): string = convert_to_markdown(string) string = f'
{string}
' return string + def process_post(post, c): t = post.split('\n') number = t[0].split(' ')[1] @@ -59,6 +64,7 @@ def process_post(post, c): src = f'Anonymous No.{number}\n{src}' return src + def generate_4chan_html(f): posts = [] post = '' @@ -84,7 +90,7 @@ def generate_4chan_html(f): posts[i] = f'
{posts[i]}
\n' else: posts[i] = f'
{posts[i]}
\n' - + output = '' output += f'
' for post in posts: @@ -98,13 +104,15 @@ def generate_4chan_html(f): return output + def make_thumbnail(image): - image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS) + image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS) if image.size[1] > 470: image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS) return image + def get_image_cache(path): cache_folder = Path("cache") if not cache_folder.exists(): @@ -119,9 +127,10 @@ def get_image_cache(path): return image_cache[path][1] + def generate_instruct_html(history): output = f'
' - for i,_row in enumerate(history[::-1]): + for i, _row in enumerate(history[::-1]): row = [convert_to_markdown(entry) for entry in _row] output += f""" @@ -134,7 +143,7 @@ def generate_instruct_html(history):
""" - if len(row[0]) == 0: # don't display empty user messages + if len(row[0]) == 0: # don't display empty user messages continue output += f""" @@ -151,6 +160,7 @@ def generate_instruct_html(history): return output + def generate_cai_chat_html(history, name1, name2, reset_cache=False): output = f'
' @@ -159,7 +169,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False): img_bot = f'' if Path("cache/pfp_character.png").exists() else '' img_me = f'' if Path("cache/pfp_me.png").exists() else '' - for i,_row in enumerate(history[::-1]): + for i, _row in enumerate(history[::-1]): row = [convert_to_markdown(entry) for entry in _row] output += f""" @@ -178,7 +188,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
""" - if len(row[0]) == 0: # don't display empty user messages + if len(row[0]) == 0: # don't display empty user messages continue output += f""" @@ -200,9 +210,11 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False): output += "
" return output + def generate_chat_html(history, name1, name2): return generate_cai_chat_html(history, name1, name2) + def chat_html_wrapper(history, name1, name2, mode, reset_cache=False): if mode == "cai-chat": return generate_cai_chat_html(history, name1, name2, reset_cache) diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py index 4f491329..9461db10 100644 --- a/modules/llamacpp_model.py +++ b/modules/llamacpp_model.py @@ -50,9 +50,9 @@ class LlamaCppModel: params.top_k = top_k params.temp = temperature params.repeat_penalty = repetition_penalty - #params.repeat_last_n = repeat_last_n + # params.repeat_last_n = repeat_last_n - #self.model.params = params + # self.model.params = params self.model.add_bos() self.model.update_input(context) diff --git a/modules/llamacpp_model_alternative.py b/modules/llamacpp_model_alternative.py index 40576113..8fea2ab4 100644 --- a/modules/llamacpp_model_alternative.py +++ b/modules/llamacpp_model_alternative.py @@ -1,13 +1,11 @@ ''' -Based on +Based on https://github.com/abetlen/llama-cpp-python Documentation: https://abetlen.github.io/llama-cpp-python/ ''' -import multiprocessing - from llama_cpp import Llama from modules import shared @@ -31,7 +29,7 @@ class LlamaCppModel: self.model = Llama(**params) # This is ugly, but the model and the tokenizer are the same object in this library. - return result, result + return result, result def encode(self, string): if type(string) is str: diff --git a/modules/models.py b/modules/models.py index 1bf6fc37..5e2b0989 100644 --- a/modules/models.py +++ b/modules/models.py @@ -34,7 +34,7 @@ if shared.args.deepspeed: torch.cuda.set_device(local_rank) deepspeed.init_distributed() ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir) - dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration + dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration def load_model(model_name): @@ -83,7 +83,7 @@ def load_model(model_name): elif shared.args.deepspeed: model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] - model.module.eval() # Inference + model.module.eval() # Inference print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") # RMKV model (not on HuggingFace) @@ -132,7 +132,7 @@ def load_model(model_name): params["torch_dtype"] = torch.float16 if shared.args.gpu_memory: - memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory)) + memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory)) max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB' max_memory = {} for i in range(len(memory_map)): @@ -140,13 +140,13 @@ def load_model(model_name): max_memory['cpu'] = max_cpu_memory params['max_memory'] = max_memory elif shared.args.auto_devices: - total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024)) - suggestion = round((total_mem-1000) / 1000) * 1000 + total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024)) + suggestion = round((total_mem - 1000) / 1000) * 1000 if total_mem - suggestion < 800: suggestion -= 1000 - suggestion = int(round(suggestion/1000)) + suggestion = int(round(suggestion / 1000)) print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m") - + max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'} params['max_memory'] = max_memory @@ -161,10 +161,10 @@ def load_model(model_name): model = AutoModelForCausalLM.from_config(config) model.tie_weights() params['device_map'] = infer_auto_device_map( - model, - dtype=torch.int8, + model, + dtype=torch.int8, max_memory=params['max_memory'], - no_split_module_classes = model._no_split_modules + no_split_module_classes=model._no_split_modules ) model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) @@ -181,6 +181,7 @@ def load_model(model_name): print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") return model, tokenizer + def load_soft_prompt(name): if name == 'None': shared.soft_prompt = False diff --git a/modules/shared.py b/modules/shared.py index 902d7609..7ff1ca28 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -61,6 +61,7 @@ settings = { } } + def str2bool(v): if isinstance(v, bool): return v @@ -71,7 +72,8 @@ def str2bool(v): else: raise argparse.ArgumentTypeError('Boolean value expected.') -parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) + +parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54)) # Basic settings parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') @@ -145,5 +147,6 @@ if args.cai_chat: print("Warning: --cai-chat is deprecated. Use --chat instead.") args.chat = True + def is_chat(): return args.chat diff --git a/modules/text_generation.py b/modules/text_generation.py index b8885abe..5fead483 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -16,11 +16,12 @@ from modules.models import local_rank def get_max_prompt_length(tokens): - max_length = 2048-tokens + max_length = 2048 - tokens if shared.soft_prompt: max_length -= shared.soft_prompt_tensor.shape[1] return max_length + def encode(prompt, tokens_to_generate=0, add_special_tokens=True): if any((shared.is_RWKV, shared.is_llamacpp)): input_ids = shared.tokenizer.encode(str(prompt)) @@ -30,7 +31,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True): input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871: - input_ids = input_ids[:,1:] + input_ids = input_ids[:, 1:] if shared.args.cpu: return input_ids @@ -44,6 +45,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True): else: return input_ids.cuda() + def decode(output_ids): # Open Assistant relies on special tokens like <|endoftext|> if re.match('.*(oasst|galactica)-*', shared.model_name.lower()): @@ -53,14 +55,17 @@ def decode(output_ids): reply = reply.replace(r'<|endoftext|>', '') return reply + def generate_softprompt_input_tensors(input_ids): inputs_embeds = shared.model.transformer.wte(input_ids) inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1) filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device) - #filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens + # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens return inputs_embeds, filler_input_ids # Removes empty replies from gpt4chan outputs + + def fix_gpt4chan(s): for i in range(10): s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) @@ -69,6 +74,8 @@ def fix_gpt4chan(s): return s # Fix the LaTeX equations in galactica + + def fix_galactica(s): s = s.replace(r'\[', r'$') s = s.replace(r'\]', r'$') @@ -79,6 +86,7 @@ def fix_galactica(s): s = re.sub(r"\n{3,}", "\n\n", s) return s + def formatted_outputs(reply, model_name): if not shared.is_chat(): if 'galactica' in model_name.lower(): @@ -92,20 +100,24 @@ def formatted_outputs(reply, model_name): else: return reply + def clear_torch_cache(): gc.collect() if not shared.args.cpu: torch.cuda.empty_cache() + def set_manual_seed(seed): if seed != -1: torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + def stop_everything_event(): shared.stop_everything = True + def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]): clear_torch_cache() set_manual_seed(generate_state['seed']) @@ -128,7 +140,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[] try: if shared.args.no_stream: reply = shared.model.generate(context=question, **generate_params) - output = original_question+reply + output = original_question + reply if not shared.is_chat(): reply = original_question + apply_extensions(reply, "output") yield formatted_outputs(reply, shared.model_name) @@ -139,7 +151,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[] # RWKV has proper streaming, which is very nice. # No need to generate 8 tokens at a time. for reply in shared.model.generate_with_streaming(context=question, **generate_params): - output = original_question+reply + output = original_question + reply if not shared.is_chat(): reply = original_question + apply_extensions(reply, "output") yield formatted_outputs(reply, shared.model_name) @@ -240,7 +252,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[] # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' else: - for i in range(generate_state['max_new_tokens']//8+1): + for i in range(generate_state['max_new_tokens'] // 8 + 1): clear_torch_cache() with torch.no_grad(): output = shared.model.generate(**generate_params)[0] @@ -271,6 +283,6 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[] finally: t1 = time.time() original_tokens = len(original_input_ids[0]) - new_tokens = len(output)-original_tokens + new_tokens = len(output) - original_tokens print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})") return diff --git a/modules/training.py b/modules/training.py index 220428b1..9880cf00 100644 --- a/modules/training.py +++ b/modules/training.py @@ -19,9 +19,11 @@ CURRENT_STEPS = 0 MAX_STEPS = 0 CURRENT_GRADIENT_ACCUM = 1 + def get_dataset(path: str, ext: str): return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower) + def create_train_interface(): with gr.Tab('Train LoRA', elem_id='lora-train-tab'): lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file") @@ -44,16 +46,16 @@ def create_train_interface(): with gr.Tab(label="Formatted Dataset"): with gr.Row(): dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.') - ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button') + ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button') eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.') - ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button') + ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button') format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.') - ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button') + ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button') with gr.Tab(label="Raw Text File"): with gr.Row(): raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.') - ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button') + ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button') with gr.Row(): overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.') newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.') @@ -67,10 +69,12 @@ def create_train_interface(): cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output]) stop_button.click(do_interrupt, [], [], cancels=[], queue=False) + def do_interrupt(): global WANT_INTERRUPT WANT_INTERRUPT = True + class Callbacks(transformers.TrainerCallback): def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): global CURRENT_STEPS, MAX_STEPS @@ -79,6 +83,7 @@ class Callbacks(transformers.TrainerCallback): if WANT_INTERRUPT: control.should_epoch_stop = True control.should_training_stop = True + def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): global CURRENT_STEPS CURRENT_STEPS += 1 @@ -86,6 +91,7 @@ class Callbacks(transformers.TrainerCallback): control.should_epoch_stop = True control.should_training_stop = True + def clean_path(base_path: str, path: str): """"Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path. @@ -95,6 +101,7 @@ def clean_path(base_path: str, path: str): return path return f'{Path(base_path).absolute()}/{path}' + def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int): global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM @@ -124,7 +131,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int elif not shared.args.load_in_8bit: yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*" print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.") - time.sleep(2) # Give it a moment for the message to show in UI before continuing + time.sleep(2) # Give it a moment for the message to show in UI before continuing if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: yield "Cannot input zeroes." @@ -148,7 +155,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file: raw_text = file.read() tokens = shared.tokenizer.encode(raw_text) - del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM + del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM tokens = list(split_chunks(tokens, cutoff_len - overlap_len)) for i in range(1, len(tokens)): @@ -197,18 +204,18 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int else: eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json')) eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt) - + # == Start prepping the model itself == if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): print("Getting model ready...") prepare_model_for_int8_training(shared.model) - + print("Prepping for training...") config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, # TODO: Should target_modules be configurable? - target_modules=[ "q_proj", "v_proj" ], + target_modules=["q_proj", "v_proj"], lora_dropout=lora_dropout, bias="none", task_type="CAUSAL_LM" @@ -289,7 +296,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int timer_info = f"`{its:.2f}` it/s" else: timer_info = f"`{1.0/its:.2f}` s/it" - total_time_estimate = (1.0/its) * (MAX_STEPS) + total_time_estimate = (1.0 / its) * (MAX_STEPS) yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining" print("Training complete, saving...") @@ -302,10 +309,12 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int print("Training complete!") yield f"Done! LoRA saved to `{lora_name}`" + def split_chunks(arr, step): for i in range(0, len(arr), step): yield arr[i:i + step] + def cut_chunk_for_newline(chunk: str, max_length: int): if '\n' not in chunk: return chunk @@ -319,6 +328,7 @@ def cut_chunk_for_newline(chunk: str, max_length: int): chunk = chunk[:last_newline] return chunk + def format_time(seconds: float): if seconds < 120: return f"`{seconds:.0f}` seconds" diff --git a/modules/ui.py b/modules/ui.py index 80bd7c1c..def1faaf 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -13,6 +13,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f: with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f: chat_js = f.read() + class ToolButton(gr.Button, gr.components.FormComponent): """Small button with single emoji as text, fits inside gradio forms""" @@ -22,6 +23,7 @@ class ToolButton(gr.Button, gr.components.FormComponent): def get_block_name(self): return "button" + def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def refresh(): refresh_method() diff --git a/server.py b/server.py index 4ba5ba82..4fdeb80b 100644 --- a/server.py +++ b/server.py @@ -34,15 +34,18 @@ if settings_file is not None: for item in new_settings: shared.settings[item] = new_settings[item] + def get_available_models(): if shared.args.flexgen: return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower) else: return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + def get_available_presets(): return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower) + def get_available_prompts(): prompts = [] prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True) @@ -50,10 +53,12 @@ def get_available_prompts(): prompts += ['None'] return prompts + def get_available_characters(): paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower) + def get_available_instruction_templates(): path = "characters/instruction-following" paths = [] @@ -61,19 +66,24 @@ def get_available_instruction_templates(): paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower) + def get_available_extensions(): - return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower) + return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower) + def get_available_softprompts(): return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower) + def get_available_loras(): return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + def unload_model(): shared.model = shared.tokenizer = None clear_torch_cache() + def load_model_wrapper(selected_model): if selected_model != shared.model_name: shared.model_name = selected_model @@ -84,10 +94,12 @@ def load_model_wrapper(selected_model): return selected_model + def load_lora_wrapper(selected_lora): add_lora_to_model(selected_lora) return selected_lora + def load_preset_values(preset_menu, state, return_dict=False): generate_params = { 'do_sample': True, @@ -118,6 +130,7 @@ def load_preset_values(preset_menu, state, return_dict=False): state.update(generate_params) return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + def upload_soft_prompt(file): with zipfile.ZipFile(io.BytesIO(file)) as zf: zf.extract('meta.json') @@ -130,12 +143,14 @@ def upload_soft_prompt(file): return name + def save_prompt(text): fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt" with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f: f.write(text) return f"Saved to prompts/{fname}" + def load_prompt(fname): if fname in ['None', '']: return '' @@ -146,12 +161,13 @@ def load_prompt(fname): text = text[:-1] return text + def create_prompt_menus(): with gr.Row(): with gr.Column(): with gr.Row(): shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt') - ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button') + ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button') with gr.Column(): with gr.Column(): @@ -161,20 +177,22 @@ def create_prompt_menus(): shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) + def create_model_menus(): with gr.Row(): with gr.Column(): with gr.Row(): shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model') - ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button') + ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button') with gr.Column(): with gr.Row(): shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') - ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button') + ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button') shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) + def create_settings_menus(default_preset): generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']: @@ -185,7 +203,7 @@ def create_settings_menus(default_preset): with gr.Column(): with gr.Row(): shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') - ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') + ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': get_available_presets()}, 'refresh-button') with gr.Column(): shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)') @@ -196,12 +214,12 @@ def create_settings_menus(default_preset): with gr.Row(): with gr.Column(): shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') - shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p') - shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k') - shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p') + shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p') + shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k') + shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p') with gr.Column(): - shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty') - shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty') + shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty') + shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty') shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream) shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') @@ -209,7 +227,6 @@ def create_settings_menus(default_preset): with gr.Box(): gr.Markdown('Contrastive search') shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') - with gr.Box(): gr.Markdown('Beam search (uses a lot of VRAM)') with gr.Row(): @@ -219,11 +236,10 @@ def create_settings_menus(default_preset): shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') - with gr.Accordion('Soft prompt', open=False): with gr.Row(): shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt') - ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button') + ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': get_available_softprompts()}, 'refresh-button') gr.Markdown('Upload a soft prompt (.zip format):') with gr.Row(): @@ -233,6 +249,7 @@ def create_settings_menus(default_preset): shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu']) + def set_interface_arguments(interface_mode, extensions, bool_active): modes = ["default", "notebook", "chat", "cai_chat"] cmd_list = vars(shared.args) @@ -251,6 +268,7 @@ def set_interface_arguments(interface_mode, extensions, bool_active): shared.need_restart = True + available_models = get_available_models() available_presets = get_available_presets() available_characters = get_available_characters() @@ -284,7 +302,7 @@ else: for i, model in enumerate(available_models): print(f'{i+1}. {model}') print(f'\nWhich one do you want to load? 1-{len(available_models)}\n') - i = int(input())-1 + i = int(input()) - 1 print() shared.model_name = available_models[i] shared.model, shared.tokenizer = load_model(shared.model_name) @@ -297,15 +315,15 @@ if shared.lora_name != "None": default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]) else: default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]) -title ='Text generation web UI' +title = 'Text generation web UI' + def create_interface(): - gen_events = [] if shared.args.extensions is not None and len(shared.args.extensions) > 0: extensions_module.load_extensions() - with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: + with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: if shared.is_chat(): shared.gradio['Chat input'] = gr.State() with gr.Tab("Text generation", elem_id="main"): @@ -342,7 +360,7 @@ def create_interface(): shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None) with gr.Row(): shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') - ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') + ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button') with gr.Row(): with gr.Tab('Chat history'): @@ -399,11 +417,11 @@ def create_interface(): # Clear history with confirmation clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] - shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) - shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) + shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) + shared.gradio['Clear history-confirm'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']) - shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) - shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']) + shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) + shared.gradio['Chat mode'].change(lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']) shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']]) @@ -412,10 +430,10 @@ def create_interface(): # Clearing stuff and saving the history for i in ['Generate', 'Regenerate', 'Replace last reply']: shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) - shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) + shared.gradio[i].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False) + shared.gradio['Clear history-confirm'].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False) shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) - shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) + shared.gradio['textbox'].submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False) shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) @@ -430,7 +448,7 @@ def create_interface(): shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']]) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") - shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None) + shared.gradio['interface'].load(lambda: chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None) shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True) elif shared.args.notebook: @@ -526,7 +544,7 @@ def create_interface(): shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary") shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None) - shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'

Reloading...

\'; setTimeout(function(){location.reload()},2500); return []}') + shared.gradio['reset_interface'].click(lambda: None, None, None, _js='() => {document.body.innerHTML=\'

Reloading...

\'; setTimeout(function(){location.reload()},2500); return []}') if shared.args.extensions is not None: extensions_module.create_extensions_block() @@ -562,6 +580,7 @@ def create_interface(): else: shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth) + create_interface() while True: