diff --git a/convert-to-flexgen.py b/convert-to-flexgen.py
index a5b127a6..917f023c 100644
--- a/convert-to-flexgen.py
+++ b/convert-to-flexgen.py
@@ -3,6 +3,7 @@
 Converts a transformers model to a format compatible with flexgen.
 
 '''
+
 import argparse
 import os
 from pathlib import Path
@@ -10,9 +11,8 @@ from pathlib import Path
 import numpy as np
 import torch
 from tqdm import tqdm
-from transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer
-
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
 parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
 parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
 args = parser.parse_args()
@@ -31,7 +31,6 @@ def disable_torch_init():
     torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
-
 def restore_torch_init():
     """Rollback the change made by disable_torch_init."""
     import torch
diff --git a/convert-to-safetensors.py b/convert-to-safetensors.py
index 8c12dec8..63baaa97 100644
--- a/convert-to-safetensors.py
+++ b/convert-to-safetensors.py
@@ -10,13 +10,13 @@ Based on the original script by 81300:
 https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303
 
 '''
+
 import argparse
 from pathlib import Path
 
 import torch
-from transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer
-
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
 parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
 parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
 parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
diff --git a/modules/bot_picture.py b/modules/bot_picture.py
index 72e87c56..dd4d73eb 100644
--- a/modules/bot_picture.py
+++ b/modules/bot_picture.py
@@ -1,6 +1,5 @@
 import torch
-from transformers import BlipForConditionalGeneration
-from transformers import BlipProcessor
+from transformers import BlipForConditionalGeneration, BlipProcessor
 
 processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
diff --git a/modules/chat.py b/modules/chat.py
index 63372813..fa9ee4f3 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -7,13 +7,12 @@ from datetime import datetime
 from io import BytesIO
 from pathlib import Path
 
+from PIL import Image
+
 import modules.shared as shared
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_chat_html
-from modules.text_generation import encode
-from modules.text_generation import generate_reply
-from modules.text_generation import get_max_prompt_length
-from PIL import Image
+from modules.text_generation import encode, generate_reply, get_max_prompt_length
 
 if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
     import modules.bot_picture as bot_picture
diff --git a/modules/extensions.py b/modules/extensions.py
index eeb10bcd..1094fc4d 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,6 +1,5 @@
-import modules.shared as shared
-
 import extensions
+import modules.shared as shared
 
 extension_state = {}
 available_extensions = []
diff --git a/modules/html_generator.py b/modules/html_generator.py
index ed9996fc..6e1fb8ac 100644
--- a/modules/html_generator.py
+++ b/modules/html_generator.py
@@ -3,6 +3,7 @@
 This is a library for formatting GPT-4chan and chat outputs as nice HTML.
 
 '''
+
 import base64
 import os
 import re
diff --git a/modules/models.py b/modules/models.py
index 85e1362c..efa3eb25 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -4,23 +4,27 @@ import time
 import zipfile
 from pathlib import Path
 
-import modules.shared as shared
 import numpy as np
 import torch
 import transformers
-from transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import modules.shared as shared
 
 transformers.logging.set_verbosity_error()
 
 local_rank = None
 
 if shared.args.flexgen:
-    from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, get_opt_config)
+    from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
+                                  TorchDevice, TorchDisk, TorchMixedDevice,
+                                  get_opt_config)
 
 if shared.args.deepspeed:
     import deepspeed
-    from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
+    from transformers.deepspeed import (HfDeepSpeedConfig,
+                                        is_deepspeed_zero3_enabled)
+
     from modules.deepspeed_parameters import generate_ds_config
 
 # Distributed setup
diff --git a/modules/stopping_criteria.py b/modules/stopping_criteria.py
index 3e403ffe..44a631b3 100644
--- a/modules/stopping_criteria.py
+++ b/modules/stopping_criteria.py
@@ -4,9 +4,11 @@ This code was copied from
 https://github.com/PygmalionAI/gradio-ui/
 
 '''
+
 import torch
 import transformers
 
+
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
 
     def __init__(self, sentinel_token_ids: torch.LongTensor,
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 3e6cb543..d0204102 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -1,16 +1,17 @@
 import re
 import time
 
-import modules.shared as shared
 import numpy as np
 import torch
 import transformers
+from tqdm import tqdm
+
+import modules.shared as shared
 from modules.extensions import apply_extensions
-from modules.html_generator import generate_4chan_html
-from modules.html_generator import generate_basic_html
+from modules.html_generator import generate_4chan_html, generate_basic_html
 from modules.models import local_rank
 from modules.stopping_criteria import _SentinelTokenStoppingCriteria
-from tqdm import tqdm
+
 
 def get_max_prompt_length(tokens):
     max_length = 2048-tokens
diff --git a/server.py b/server.py
index 0cd21fcd..d5439710 100644
--- a/server.py
+++ b/server.py
@@ -14,12 +14,9 @@ import modules.chat as chat
 import modules.extensions as extensions_module
 import modules.shared as shared
 import modules.ui as ui
-from modules.extensions import extension_state
-from modules.extensions import load_extensions
-from modules.extensions import update_extensions_parameters
+from modules.extensions import extension_state, load_extensions, update_extensions_parameters
 from modules.html_generator import generate_chat_html
-from modules.models import load_model
-from modules.models import load_soft_prompt
+from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply
 
 if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream: