Improve the imports

This commit is contained in:
oobabooga 2023-02-23 14:41:42 -03:00
parent 364529d0c7
commit 7224343a70
10 changed files with 30 additions and 29 deletions

View file

@ -3,6 +3,7 @@
Converts a transformers model to a format compatible with flexgen.
'''
import argparse
import os
from pathlib import Path
@ -10,8 +11,7 @@ from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
@ -31,7 +31,6 @@ def disable_torch_init():
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def restore_torch_init():
"""Rollback the change made by disable_torch_init."""
import torch

View file

@ -10,12 +10,12 @@ Based on the original script by 81300:
https://gist.github.com/81300/fe5b08bff1cba45296a829b9d6b0f303
'''
import argparse
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")

View file

@ -1,6 +1,5 @@
import torch
from transformers import BlipForConditionalGeneration
from transformers import BlipProcessor
from transformers import BlipForConditionalGeneration, BlipProcessor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")

View file

@ -7,13 +7,12 @@ from datetime import datetime
from io import BytesIO
from pathlib import Path
from PIL import Image
import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import generate_chat_html
from modules.text_generation import encode
from modules.text_generation import generate_reply
from modules.text_generation import get_max_prompt_length
from PIL import Image
from modules.text_generation import encode, generate_reply, get_max_prompt_length
if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
import modules.bot_picture as bot_picture

View file

@ -1,6 +1,5 @@
import modules.shared as shared
import extensions
import modules.shared as shared
extension_state = {}
available_extensions = []

View file

@ -3,6 +3,7 @@
This is a library for formatting GPT-4chan and chat outputs as nice HTML.
'''
import base64
import os
import re

View file

@ -4,23 +4,27 @@ import time
import zipfile
from pathlib import Path
import modules.shared as shared
import numpy as np
import torch
import transformers
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
import modules.shared as shared
transformers.logging.set_verbosity_error()
local_rank = None
if shared.args.flexgen:
from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, get_opt_config)
from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
TorchDevice, TorchDisk, TorchMixedDevice,
get_opt_config)
if shared.args.deepspeed:
import deepspeed
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
from transformers.deepspeed import (HfDeepSpeedConfig,
is_deepspeed_zero3_enabled)
from modules.deepspeed_parameters import generate_ds_config
# Distributed setup

View file

@ -4,9 +4,11 @@ This code was copied from
https://github.com/PygmalionAI/gradio-ui/
'''
import torch
import transformers
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: torch.LongTensor,

View file

@ -1,16 +1,17 @@
import re
import time
import modules.shared as shared
import numpy as np
import torch
import transformers
from tqdm import tqdm
import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html
from modules.html_generator import generate_basic_html
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
from tqdm import tqdm
def get_max_prompt_length(tokens):
max_length = 2048-tokens

View file

@ -14,12 +14,9 @@ import modules.chat as chat
import modules.extensions as extensions_module
import modules.shared as shared
import modules.ui as ui
from modules.extensions import extension_state
from modules.extensions import load_extensions
from modules.extensions import update_extensions_parameters
from modules.extensions import extension_state, load_extensions, update_extensions_parameters
from modules.html_generator import generate_chat_html
from modules.models import load_model
from modules.models import load_soft_prompt
from modules.models import load_model, load_soft_prompt
from modules.text_generation import generate_reply
if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream: