Make the code more like PEP8 for readability (#862)

This commit is contained in:
oobabooga 2023-04-07 00:15:45 -03:00 committed by GitHub
parent 848c4edfd5
commit ea6e77df72
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 302 additions and 165 deletions

View file

@ -17,6 +17,7 @@ def random_hash():
letters = string.ascii_lowercase + string.digits letters = string.ascii_lowercase + string.digits
return ''.join(random.choice(letters) for i in range(9)) return ''.join(random.choice(letters) for i in range(9))
async def run(context): async def run(context):
server = "127.0.0.1" server = "127.0.0.1"
params = { params = {
@ -69,6 +70,7 @@ async def run(context):
prompt = "What I would like to say is the following: " prompt = "What I would like to say is the following: "
async def get_result(): async def get_result():
async for response in run(prompt): async for response in run(prompt):
# Print intermediate steps # Print intermediate steps

View file

@ -17,6 +17,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args() args = parser.parse_args()
def disable_torch_init(): def disable_torch_init():
""" """
Disable the redundant torch default initialization to accelerate model creation. Disable the redundant torch default initialization to accelerate model creation.
@ -31,12 +32,14 @@ def disable_torch_init():
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def restore_torch_init(): def restore_torch_init():
"""Rollback the change made by disable_torch_init.""" """Rollback the change made by disable_torch_init."""
import torch import torch
setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup) setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup) setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
if __name__ == '__main__': if __name__ == '__main__':
path = Path(args.MODEL) path = Path(args.MODEL)
model_name = path.name model_name = path.name

View file

@ -29,6 +29,7 @@ parser.add_argument('--clean', action='store_true', help='Does not resume the pr
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args() args = parser.parse_args()
def get_file(url, output_folder): def get_file(url, output_folder):
filename = Path(url.rsplit('/', 1)[1]) filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename output_path = output_folder / filename
@ -54,6 +55,7 @@ def get_file(url, output_folder):
t.update(len(data)) t.update(len(data))
f.write(data) f.write(data)
def sanitize_branch_name(branch_name): def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$") pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if pattern.match(branch_name): if pattern.match(branch_name):
@ -61,6 +63,7 @@ def sanitize_branch_name(branch_name):
else: else:
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
def select_model_from_default_options(): def select_model_from_default_options():
models = { models = {
"OPT 6.7B": ("facebook", "opt-6.7b", "main"), "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@ -106,6 +109,7 @@ EleutherAI/pythia-1.4b-deduped
return model, branch return model, branch
def get_download_links_from_huggingface(model, branch): def get_download_links_from_huggingface(model, branch):
base = "https://huggingface.co" base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor=" page = f"/api/models/{model}/tree/{branch}?cursor="
@ -172,9 +176,11 @@ def get_download_links_from_huggingface(model, branch):
return links, sha256, is_lora return links, sha256, is_lora
def download_files(file_list, output_folder, num_threads=8): def download_files(file_list, output_folder, num_threads=8):
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True) thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
if __name__ == '__main__': if __name__ == '__main__':
model = args.MODEL model = args.MODEL
branch = args.branch branch = args.branch

View file

@ -9,6 +9,7 @@ params = {
'port': 5000, 'port': 5000,
} }
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
if self.path == '/api/v1/model': if self.path == '/api/v1/model':
@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
prompt = body['prompt'] prompt = body['prompt']
prompt_lines = [l.strip() for l in prompt.split('\n')] prompt_lines = [k.strip() for k in prompt.split('\n')]
max_context = body.get('max_context_length', 2048) max_context = body.get('max_context_length', 2048)
@ -95,5 +96,6 @@ def run_server():
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api') print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
server.serve_forever() server.serve_forever()
def setup(): def setup():
Thread(target=run_server, daemon=True).start() Thread(target=run_server, daemon=True).start()

View file

@ -5,6 +5,7 @@ params = {
"bias string": " *I am so happy*", "bias string": " *I am so happy*",
} }
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@ -13,6 +14,7 @@ def input_modifier(string):
return string return string
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@ -20,6 +22,7 @@ def output_modifier(string):
return string return string
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@ -27,11 +30,12 @@ def bot_prefix_modifier(string):
behavior. behavior.
""" """
if params['activate'] == True: if params['activate']:
return f'{string} {params["bias string"].strip()} ' return f'{string} {params["bias string"].strip()} '
else: else:
return string return string
def ui(): def ui():
# Gradio elements # Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate character bias') activate = gr.Checkbox(value=params['activate'], label='Activate character bias')

View file

@ -22,6 +22,8 @@ if not shared.args.no_stream:
raise ValueError raise ValueError
# Check if the API is valid and refresh the UI accordingly. # Check if the API is valid and refresh the UI accordingly.
def check_valid_api(): def check_valid_api():
global user, user_info, params global user, user_info, params
@ -29,7 +31,7 @@ def check_valid_api():
user = ElevenLabsUser(params['api_key']) user = ElevenLabsUser(params['api_key'])
user_info = user._get_subscription_data() user_info = user._get_subscription_data()
print('checking api') print('checking api')
if params['activate'] == False: if not params['activate']:
return gr.update(value='Disconnected') return gr.update(value='Disconnected')
elif user_info is None: elif user_info is None:
print('Incorrect API Key') print('Incorrect API Key')
@ -39,6 +41,8 @@ def check_valid_api():
return gr.update(value='Connected') return gr.update(value='Connected')
# Once the API is verified, get the available voices and update the dropdown list # Once the API is verified, get the available voices and update the dropdown list
def refresh_voices(): def refresh_voices():
global user, user_info global user, user_info
@ -51,11 +55,13 @@ def refresh_voices():
else: else:
return return
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)', '', string) return re.sub('\*[^\*]*?(\*|$)', '', string)
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@ -64,6 +70,7 @@ def input_modifier(string):
return string return string
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@ -71,9 +78,9 @@ def output_modifier(string):
global params, wav_idx, user, user_info global params, wav_idx, user, user_info
if params['activate'] == False: if not params['activate']:
return string return string
elif user_info == None: elif user_info is None:
return string return string
string = remove_surrounded_chars(string) string = remove_surrounded_chars(string)
@ -94,6 +101,7 @@ def output_modifier(string):
wav_idx += 1 wav_idx += 1
return string return string
def ui(): def ui():
# Gradio elements # Gradio elements

View file

@ -7,6 +7,7 @@ params = {
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'} language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@ -15,6 +16,7 @@ def input_modifier(string):
return GoogleTranslator(source=params['language string'], target='en').translate(string) return GoogleTranslator(source=params['language string'], target='en').translate(string)
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@ -22,6 +24,7 @@ def output_modifier(string):
return GoogleTranslator(source='en', target=params['language string']).translate(string) return GoogleTranslator(source='en', target=params['language string']).translate(string)
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@ -31,6 +34,7 @@ def bot_prefix_modifier(string):
return string return string
def ui(): def ui():
# Finding the language name from the language code to use as the default value # Finding the language name from the language code to use as the default value
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])] language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]

View file

@ -4,12 +4,14 @@ import pandas as pd
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv") df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
def get_prompt_by_name(name): def get_prompt_by_name(name):
if name == 'None': if name == 'None':
return '' return ''
else: else:
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n') return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
def ui(): def ui():
if not shared.is_chat(): if not shared.is_chat():
choices = ['None'] + list(df['Prompt name']) choices = ['None'] + list(df['Prompt name'])

View file

@ -30,12 +30,15 @@ streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture picture_response = False # specifies if the next model response should appear as a picture
pic_id = 0 pic_id = 0
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)', '', string) return re.sub('\*[^\*]*?(\*|$)', '', string)
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string # I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@ -62,6 +65,8 @@ def input_modifier(string):
return string return string
# Get and save the Stable Diffusion-generated picture # Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description): def get_SD_pictures(description):
global params, pic_id global params, pic_id
@ -101,6 +106,8 @@ def get_SD_pictures(description):
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging? # and replace it with 'text' for the purposes of logging?
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@ -130,6 +137,7 @@ def output_modifier(string):
shared.args.no_stream = streaming_state shared.args.no_stream = streaming_state
return image + "\n" + text return image + "\n" + text
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@ -139,10 +147,12 @@ def bot_prefix_modifier(string):
return string return string
def force_pic(): def force_pic():
global picture_response global picture_response
picture_response = True picture_response = True
def ui(): def ui():
# Gradio elements # Gradio elements

View file

@ -17,11 +17,13 @@ input_hijack = {
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
def caption_image(raw_image): def caption_image(raw_image):
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32) inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
out = model.generate(**inputs, max_new_tokens=100) out = model.generate(**inputs, max_new_tokens=100)
return processor.decode(out[0], skip_special_tokens=True) return processor.decode(out[0], skip_special_tokens=True)
def generate_chat_picture(picture, name1, name2): def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*' text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
@ -32,6 +34,7 @@ def generate_chat_picture(picture, name1, name2):
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">' visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
return text, visible_text return text, visible_text
def ui(): def ui():
picture_select = gr.Image(label='Send a picture', type='pil') picture_select = gr.Image(label='Send a picture', type='pil')

View file

@ -17,9 +17,11 @@ from quant import make_quant
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128): def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
config = AutoConfig.from_pretrained(model)
def noop(*args, **kwargs): def noop(*args, **kwargs):
pass pass
config = AutoConfig.from_pretrained(model)
torch.nn.init.kaiming_uniform_ = noop torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop torch.nn.init.normal_ = noop
@ -64,6 +66,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
return model return model
def load_quantized(model_name): def load_quantized(model_name):
if not shared.args.model_type: if not shared.args.model_type:
# Try to determine model type from model name # Try to determine model type from model name

View file

@ -13,6 +13,7 @@ def reload_model():
clear_torch_cache() clear_torch_cache()
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
def add_lora_to_model(lora_name): def add_lora_to_model(lora_name):
# If a LoRA had been previously loaded, or if we want # If a LoRA had been previously loaded, or if we want

View file

@ -54,6 +54,7 @@ class RWKVModel:
reply += token reply += token
yield reply yield reply
class RWKVTokenizer: class RWKVTokenizer:
def __init__(self): def __init__(self):
pass pass

View file

@ -28,6 +28,7 @@ def generate_reply_wrapper(string):
for i in generate_reply(params[0], generate_params): for i in generate_reply(params[0], generate_params):
yield i yield i
def create_apis(): def create_apis():
t1 = gr.Textbox(visible=False) t1 = gr.Textbox(visible=False)
t2 = gr.Textbox(visible=False) t2 = gr.Textbox(visible=False)

View file

@ -30,6 +30,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
return True return True
return False return False
class Stream(transformers.StoppingCriteria): class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None): def __init__(self, callback_func=None):
self.callback_func = callback_func self.callback_func = callback_func
@ -39,6 +40,7 @@ class Stream(transformers.StoppingCriteria):
self.callback_func(input_ids[0]) self.callback_func(input_ids[0])
return False return False
class Iteratorize: class Iteratorize:
""" """
@ -96,6 +98,7 @@ class Iteratorize:
self.stop_now = True self.stop_now = True
clear_torch_cache() clear_torch_cache()
def clear_torch_cache(): def clear_torch_cache():
gc.collect() gc.collect()
if not shared.args.cpu: if not shared.args.cpu:

View file

@ -23,7 +23,6 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else '' end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
rows = [f"{context.strip()}\n"] rows = [f"{context.strip()}\n"]
# Finding the maximum prompt size # Finding the maximum prompt size
@ -68,6 +67,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
else: else:
return prompt return prompt
def extract_message_from_reply(reply, name1, name2, stop_at_newline): def extract_message_from_reply(reply, name1, name2, stop_at_newline):
next_character_found = False next_character_found = False
@ -98,6 +98,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
reply = fix_newlines(reply) reply = fix_newlines(reply)
return reply, next_character_found return reply, next_character_found
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
if mode == 'instruct': if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"] stopping_strings = [f"\n{name1}", f"\n{name2}"]
@ -113,7 +114,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
visible_text = None visible_text = None
custom_generate_chat_prompt = None custom_generate_chat_prompt = None
for extension, _ in extensions_module.iterator(): for extension, _ in extensions_module.iterator():
if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True: if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
extension.input_hijack['state'] = False extension.input_hijack['state'] = False
text, visible_text = extension.input_hijack['value'] text, visible_text = extension.input_hijack['value']
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'): if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
@ -167,6 +168,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
yield shared.history['visible'] yield shared.history['visible']
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if mode == 'instruct': if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"] stopping_strings = [f"\n{name1}", f"\n{name2}"]
@ -197,10 +199,12 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
yield reply yield reply
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
yield chat_html_wrapper(history, name1, name2, mode) yield chat_html_wrapper(history, name1, name2, mode)
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@ -213,6 +217,7 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
shared.history['visible'][-1] = [last_visible[0], history[-1][1]] shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def remove_last_message(name1, name2, mode): def remove_last_message(name1, name2, mode):
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop() last = shared.history['visible'].pop()
@ -222,12 +227,14 @@ def remove_last_message(name1, name2, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0] return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
def send_last_reply_to_input(): def send_last_reply_to_input():
if len(shared.history['internal']) > 0: if len(shared.history['internal']) > 0:
return shared.history['internal'][-1][1] return shared.history['internal'][-1][1]
else: else:
return '' return ''
def replace_last_reply(text, name1, name2, mode): def replace_last_reply(text, name1, name2, mode):
if len(shared.history['visible']) > 0: if len(shared.history['visible']) > 0:
shared.history['visible'][-1][1] = text shared.history['visible'][-1][1] = text
@ -235,9 +242,11 @@ def replace_last_reply(text, name1, name2, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def clear_html(): def clear_html():
return chat_html_wrapper([], "", "") return chat_html_wrapper([], "", "")
def clear_chat_log(name1, name2, greeting, mode): def clear_chat_log(name1, name2, greeting, mode):
shared.history['visible'] = [] shared.history['visible'] = []
shared.history['internal'] = [] shared.history['internal'] = []
@ -248,9 +257,11 @@ def clear_chat_log(name1, name2, greeting, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def redraw_html(name1, name2, mode): def redraw_html(name1, name2, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def tokenize_dialogue(dialogue, name1, name2, mode): def tokenize_dialogue(dialogue, name1, name2, mode):
history = [] history = []
@ -288,6 +299,7 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
return history return history
def save_history(timestamp=True): def save_history(timestamp=True):
if timestamp: if timestamp:
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
@ -299,6 +311,7 @@ def save_history(timestamp=True):
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}') return Path(f'logs/{fname}')
def load_history(file, name1, name2): def load_history(file, name1, name2):
file = file.decode('utf-8') file = file.decode('utf-8')
try: try:
@ -323,10 +336,12 @@ def load_history(file, name1, name2):
shared.history['internal'] = tokenize_dialogue(file, name1, name2) shared.history['internal'] = tokenize_dialogue(file, name1, name2)
shared.history['visible'] = copy.deepcopy(shared.history['internal']) shared.history['visible'] = copy.deepcopy(shared.history['internal'])
def replace_character_names(text, name1, name2): def replace_character_names(text, name1, name2):
text = text.replace('{{user}}', name1).replace('{{char}}', name2) text = text.replace('{{user}}', name1).replace('{{char}}', name2)
return text.replace('<USER>', name1).replace('<BOT>', name2) return text.replace('<USER>', name1).replace('<BOT>', name2)
def build_pygmalion_style_context(data): def build_pygmalion_style_context(data):
context = "" context = ""
if 'char_persona' in data and data['char_persona'] != '': if 'char_persona' in data and data['char_persona'] != '':
@ -336,6 +351,7 @@ def build_pygmalion_style_context(data):
context = f"{context.strip()}\n<START>\n" context = f"{context.strip()}\n<START>\n"
return context return context
def generate_pfp_cache(character): def generate_pfp_cache(character):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
@ -348,6 +364,7 @@ def generate_pfp_cache(character):
return img return img
return None return None
def load_character(character, name1, name2, mode): def load_character(character, name1, name2, mode):
shared.character = character shared.character = character
shared.history['internal'] = [] shared.history['internal'] = []
@ -404,9 +421,11 @@ def load_character(character, name1, name2, mode):
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
def load_default_history(name1, name2): def load_default_history(name1, name2):
load_character("None", name1, name2, "chat") load_character("None", name1, name2, "chat")
def upload_character(json_file, img, tavern=False): def upload_character(json_file, img, tavern=False):
json_file = json_file if type(json_file) == str else json_file.decode('utf-8') json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
data = json.loads(json_file) data = json.loads(json_file)
@ -425,6 +444,7 @@ def upload_character(json_file, img, tavern=False):
print(f'New character saved to "characters/{outfile_name}.json".') print(f'New character saved to "characters/{outfile_name}.json".')
return outfile_name return outfile_name
def upload_tavern_character(img, name1, name2): def upload_tavern_character(img, name1, name2):
_img = Image.open(io.BytesIO(img)) _img = Image.open(io.BytesIO(img))
_img.getexif() _img.getexif()
@ -433,12 +453,13 @@ def upload_tavern_character(img, name1, name2):
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
return upload_character(json.dumps(_json), img, tavern=True) return upload_character(json.dumps(_json), img, tavern=True)
def upload_your_profile_picture(img, name1, name2, mode): def upload_your_profile_picture(img, name1, name2, mode):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
cache_folder.mkdir() cache_folder.mkdir()
if img == None: if img is None:
if Path("cache/pfp_me.png").exists(): if Path("cache/pfp_me.png").exists():
Path("cache/pfp_me.png").unlink() Path("cache/pfp_me.png").unlink()
else: else:

View file

@ -9,6 +9,7 @@ state = {}
available_extensions = [] available_extensions = []
setup_called = set() setup_called = set()
def load_extensions(): def load_extensions():
global state global state
for i, name in enumerate(shared.args.extensions): for i, name in enumerate(shared.args.extensions):
@ -23,12 +24,16 @@ def load_extensions():
traceback.print_exc() traceback.print_exc()
# This iterator returns the extensions in the order specified in the command-line # This iterator returns the extensions in the order specified in the command-line
def iterator(): def iterator():
for name in sorted(state, key=lambda x: state[x][1]): for name in sorted(state, key=lambda x: state[x][1]):
if state[name][0] == True: if state[name][0] == True:
yield eval(f"extensions.{name}.script"), name yield eval(f"extensions.{name}.script"), name
# Extension functions that map string -> string # Extension functions that map string -> string
def apply_extensions(text, typ): def apply_extensions(text, typ):
for extension, _ in iterator(): for extension, _ in iterator():
if typ == "input" and hasattr(extension, "input_modifier"): if typ == "input" and hasattr(extension, "input_modifier"):
@ -39,6 +44,7 @@ def apply_extensions(text, typ):
text = extension.bot_prefix_modifier(text) text = extension.bot_prefix_modifier(text)
return text return text
def create_extensions_block(): def create_extensions_block():
global setup_called global setup_called

View file

@ -24,6 +24,7 @@ with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f: with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
instruct_css = f.read() instruct_css = f.read()
def fix_newlines(string): def fix_newlines(string):
string = string.replace('\n', '\n\n') string = string.replace('\n', '\n\n')
string = re.sub(r"\n{3,}", "\n\n", string) string = re.sub(r"\n{3,}", "\n\n", string)
@ -31,6 +32,8 @@ def fix_newlines(string):
return string return string
# This could probably be generalized and improved # This could probably be generalized and improved
def convert_to_markdown(string): def convert_to_markdown(string):
string = string.replace('\\begin{code}', '```') string = string.replace('\\begin{code}', '```')
string = string.replace('\\end{code}', '```') string = string.replace('\\end{code}', '```')
@ -40,11 +43,13 @@ def convert_to_markdown(string):
string = fix_newlines(string) string = fix_newlines(string)
return markdown.markdown(string, extensions=['fenced_code']) return markdown.markdown(string, extensions=['fenced_code'])
def generate_basic_html(string): def generate_basic_html(string):
string = convert_to_markdown(string) string = convert_to_markdown(string)
string = f'<style>{readable_css}</style><div class="container">{string}</div>' string = f'<style>{readable_css}</style><div class="container">{string}</div>'
return string return string
def process_post(post, c): def process_post(post, c):
t = post.split('\n') t = post.split('\n')
number = t[0].split(' ')[1] number = t[0].split(' ')[1]
@ -59,6 +64,7 @@ def process_post(post, c):
src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}' src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
return src return src
def generate_4chan_html(f): def generate_4chan_html(f):
posts = [] posts = []
post = '' post = ''
@ -98,6 +104,7 @@ def generate_4chan_html(f):
return output return output
def make_thumbnail(image): def make_thumbnail(image):
image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS) image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
if image.size[1] > 470: if image.size[1] > 470:
@ -105,6 +112,7 @@ def make_thumbnail(image):
return image return image
def get_image_cache(path): def get_image_cache(path):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
@ -119,6 +127,7 @@ def get_image_cache(path):
return image_cache[path][1] return image_cache[path][1]
def generate_instruct_html(history): def generate_instruct_html(history):
output = f'<style>{instruct_css}</style><div class="chat" id="chat">' output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
for i, _row in enumerate(history[::-1]): for i, _row in enumerate(history[::-1]):
@ -151,6 +160,7 @@ def generate_instruct_html(history):
return output return output
def generate_cai_chat_html(history, name1, name2, reset_cache=False): def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output = f'<style>{cai_css}</style><div class="chat" id="chat">' output = f'<style>{cai_css}</style><div class="chat" id="chat">'
@ -200,9 +210,11 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output += "</div>" output += "</div>"
return output return output
def generate_chat_html(history, name1, name2): def generate_chat_html(history, name1, name2):
return generate_cai_chat_html(history, name1, name2) return generate_cai_chat_html(history, name1, name2)
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False): def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
if mode == "cai-chat": if mode == "cai-chat":
return generate_cai_chat_html(history, name1, name2, reset_cache) return generate_cai_chat_html(history, name1, name2, reset_cache)

View file

@ -6,8 +6,6 @@ Documentation:
https://abetlen.github.io/llama-cpp-python/ https://abetlen.github.io/llama-cpp-python/
''' '''
import multiprocessing
from llama_cpp import Llama from llama_cpp import Llama
from modules import shared from modules import shared

View file

@ -181,6 +181,7 @@ def load_model(model_name):
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer return model, tokenizer
def load_soft_prompt(name): def load_soft_prompt(name):
if name == 'None': if name == 'None':
shared.soft_prompt = False shared.soft_prompt = False

View file

@ -61,6 +61,7 @@ settings = {
} }
} }
def str2bool(v): def str2bool(v):
if isinstance(v, bool): if isinstance(v, bool):
return v return v
@ -71,6 +72,7 @@ def str2bool(v):
else: else:
raise argparse.ArgumentTypeError('Boolean value expected.') raise argparse.ArgumentTypeError('Boolean value expected.')
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54)) parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
# Basic settings # Basic settings
@ -145,5 +147,6 @@ if args.cai_chat:
print("Warning: --cai-chat is deprecated. Use --chat instead.") print("Warning: --cai-chat is deprecated. Use --chat instead.")
args.chat = True args.chat = True
def is_chat(): def is_chat():
return args.chat return args.chat

View file

@ -21,6 +21,7 @@ def get_max_prompt_length(tokens):
max_length -= shared.soft_prompt_tensor.shape[1] max_length -= shared.soft_prompt_tensor.shape[1]
return max_length return max_length
def encode(prompt, tokens_to_generate=0, add_special_tokens=True): def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
if any((shared.is_RWKV, shared.is_llamacpp)): if any((shared.is_RWKV, shared.is_llamacpp)):
input_ids = shared.tokenizer.encode(str(prompt)) input_ids = shared.tokenizer.encode(str(prompt))
@ -44,6 +45,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
else: else:
return input_ids.cuda() return input_ids.cuda()
def decode(output_ids): def decode(output_ids):
# Open Assistant relies on special tokens like <|endoftext|> # Open Assistant relies on special tokens like <|endoftext|>
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()): if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
@ -53,6 +55,7 @@ def decode(output_ids):
reply = reply.replace(r'<|endoftext|>', '') reply = reply.replace(r'<|endoftext|>', '')
return reply return reply
def generate_softprompt_input_tensors(input_ids): def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids) inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1) inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
@ -61,6 +64,8 @@ def generate_softprompt_input_tensors(input_ids):
return inputs_embeds, filler_input_ids return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs # Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s): def fix_gpt4chan(s):
for i in range(10): for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
@ -69,6 +74,8 @@ def fix_gpt4chan(s):
return s return s
# Fix the LaTeX equations in galactica # Fix the LaTeX equations in galactica
def fix_galactica(s): def fix_galactica(s):
s = s.replace(r'\[', r'$') s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$') s = s.replace(r'\]', r'$')
@ -79,6 +86,7 @@ def fix_galactica(s):
s = re.sub(r"\n{3,}", "\n\n", s) s = re.sub(r"\n{3,}", "\n\n", s)
return s return s
def formatted_outputs(reply, model_name): def formatted_outputs(reply, model_name):
if not shared.is_chat(): if not shared.is_chat():
if 'galactica' in model_name.lower(): if 'galactica' in model_name.lower():
@ -92,20 +100,24 @@ def formatted_outputs(reply, model_name):
else: else:
return reply return reply
def clear_torch_cache(): def clear_torch_cache():
gc.collect() gc.collect()
if not shared.args.cpu: if not shared.args.cpu:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def set_manual_seed(seed): def set_manual_seed(seed):
if seed != -1: if seed != -1:
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
def stop_everything_event(): def stop_everything_event():
shared.stop_everything = True shared.stop_everything = True
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]): def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
clear_torch_cache() clear_torch_cache()
set_manual_seed(generate_state['seed']) set_manual_seed(generate_state['seed'])

View file

@ -19,9 +19,11 @@ CURRENT_STEPS = 0
MAX_STEPS = 0 MAX_STEPS = 0
CURRENT_GRADIENT_ACCUM = 1 CURRENT_GRADIENT_ACCUM = 1
def get_dataset(path: str, ext: str): def get_dataset(path: str, ext: str):
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower) return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
def create_train_interface(): def create_train_interface():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'): with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file") lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
@ -67,10 +69,12 @@ def create_train_interface():
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output]) cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
stop_button.click(do_interrupt, [], [], cancels=[], queue=False) stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
def do_interrupt(): def do_interrupt():
global WANT_INTERRUPT global WANT_INTERRUPT
WANT_INTERRUPT = True WANT_INTERRUPT = True
class Callbacks(transformers.TrainerCallback): class Callbacks(transformers.TrainerCallback):
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS, MAX_STEPS global CURRENT_STEPS, MAX_STEPS
@ -79,6 +83,7 @@ class Callbacks(transformers.TrainerCallback):
if WANT_INTERRUPT: if WANT_INTERRUPT:
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS global CURRENT_STEPS
CURRENT_STEPS += 1 CURRENT_STEPS += 1
@ -86,6 +91,7 @@ class Callbacks(transformers.TrainerCallback):
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
def clean_path(base_path: str, path: str): def clean_path(base_path: str, path: str):
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path. # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@ -95,6 +101,7 @@ def clean_path(base_path: str, path: str):
return path return path
return f'{Path(base_path).absolute()}/{path}' return f'{Path(base_path).absolute()}/{path}'
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float, def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int): cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
@ -302,10 +309,12 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
print("Training complete!") print("Training complete!")
yield f"Done! LoRA saved to `{lora_name}`" yield f"Done! LoRA saved to `{lora_name}`"
def split_chunks(arr, step): def split_chunks(arr, step):
for i in range(0, len(arr), step): for i in range(0, len(arr), step):
yield arr[i:i + step] yield arr[i:i + step]
def cut_chunk_for_newline(chunk: str, max_length: int): def cut_chunk_for_newline(chunk: str, max_length: int):
if '\n' not in chunk: if '\n' not in chunk:
return chunk return chunk
@ -319,6 +328,7 @@ def cut_chunk_for_newline(chunk: str, max_length: int):
chunk = chunk[:last_newline] chunk = chunk[:last_newline]
return chunk return chunk
def format_time(seconds: float): def format_time(seconds: float):
if seconds < 120: if seconds < 120:
return f"`{seconds:.0f}` seconds" return f"`{seconds:.0f}` seconds"

View file

@ -13,6 +13,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f: with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
chat_js = f.read() chat_js = f.read()
class ToolButton(gr.Button, gr.components.FormComponent): class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms""" """Small button with single emoji as text, fits inside gradio forms"""
@ -22,6 +23,7 @@ class ToolButton(gr.Button, gr.components.FormComponent):
def get_block_name(self): def get_block_name(self):
return "button" return "button"
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh(): def refresh():
refresh_method() refresh_method()

View file

@ -34,15 +34,18 @@ if settings_file is not None:
for item in new_settings: for item in new_settings:
shared.settings[item] = new_settings[item] shared.settings[item] = new_settings[item]
def get_available_models(): def get_available_models():
if shared.args.flexgen: if shared.args.flexgen:
return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower) return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
else: else:
return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def get_available_presets(): def get_available_presets():
return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower) return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
def get_available_prompts(): def get_available_prompts():
prompts = [] prompts = []
prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True) prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
@ -50,10 +53,12 @@ def get_available_prompts():
prompts += ['None'] prompts += ['None']
return prompts return prompts
def get_available_characters(): def get_available_characters():
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower) return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
def get_available_instruction_templates(): def get_available_instruction_templates():
path = "characters/instruction-following" path = "characters/instruction-following"
paths = [] paths = []
@ -61,19 +66,24 @@ def get_available_instruction_templates():
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml')) paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower) return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
def get_available_extensions(): def get_available_extensions():
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower) return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
def get_available_softprompts(): def get_available_softprompts():
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower) return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
def get_available_loras(): def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def unload_model(): def unload_model():
shared.model = shared.tokenizer = None shared.model = shared.tokenizer = None
clear_torch_cache() clear_torch_cache()
def load_model_wrapper(selected_model): def load_model_wrapper(selected_model):
if selected_model != shared.model_name: if selected_model != shared.model_name:
shared.model_name = selected_model shared.model_name = selected_model
@ -84,10 +94,12 @@ def load_model_wrapper(selected_model):
return selected_model return selected_model
def load_lora_wrapper(selected_lora): def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora) add_lora_to_model(selected_lora)
return selected_lora return selected_lora
def load_preset_values(preset_menu, state, return_dict=False): def load_preset_values(preset_menu, state, return_dict=False):
generate_params = { generate_params = {
'do_sample': True, 'do_sample': True,
@ -118,6 +130,7 @@ def load_preset_values(preset_menu, state, return_dict=False):
state.update(generate_params) state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
def upload_soft_prompt(file): def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf: with zipfile.ZipFile(io.BytesIO(file)) as zf:
zf.extract('meta.json') zf.extract('meta.json')
@ -130,12 +143,14 @@ def upload_soft_prompt(file):
return name return name
def save_prompt(text): def save_prompt(text):
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt" fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f: with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
f.write(text) f.write(text)
return f"Saved to prompts/{fname}" return f"Saved to prompts/{fname}"
def load_prompt(fname): def load_prompt(fname):
if fname in ['None', '']: if fname in ['None', '']:
return '' return ''
@ -146,6 +161,7 @@ def load_prompt(fname):
text = text[:-1] text = text[:-1]
return text return text
def create_prompt_menus(): def create_prompt_menus():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -161,6 +177,7 @@ def create_prompt_menus():
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False) shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def create_model_menus(): def create_model_menus():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -175,6 +192,7 @@ def create_model_menus():
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
def create_settings_menus(default_preset): def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']: for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
@ -209,7 +227,6 @@ def create_settings_menus(default_preset):
with gr.Box(): with gr.Box():
gr.Markdown('Contrastive search') gr.Markdown('Contrastive search')
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
with gr.Box(): with gr.Box():
gr.Markdown('Beam search (uses a lot of VRAM)') gr.Markdown('Beam search (uses a lot of VRAM)')
with gr.Row(): with gr.Row():
@ -219,7 +236,6 @@ def create_settings_menus(default_preset):
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
with gr.Accordion('Soft prompt', open=False): with gr.Accordion('Soft prompt', open=False):
with gr.Row(): with gr.Row():
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt') shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
@ -233,6 +249,7 @@ def create_settings_menus(default_preset):
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True) shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu']) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
def set_interface_arguments(interface_mode, extensions, bool_active): def set_interface_arguments(interface_mode, extensions, bool_active):
modes = ["default", "notebook", "chat", "cai_chat"] modes = ["default", "notebook", "chat", "cai_chat"]
cmd_list = vars(shared.args) cmd_list = vars(shared.args)
@ -251,6 +268,7 @@ def set_interface_arguments(interface_mode, extensions, bool_active):
shared.need_restart = True shared.need_restart = True
available_models = get_available_models() available_models = get_available_models()
available_presets = get_available_presets() available_presets = get_available_presets()
available_characters = get_available_characters() available_characters = get_available_characters()
@ -299,8 +317,8 @@ else:
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]) default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
title = 'Text generation web UI' title = 'Text generation web UI'
def create_interface():
def create_interface():
gen_events = [] gen_events = []
if shared.args.extensions is not None and len(shared.args.extensions) > 0: if shared.args.extensions is not None and len(shared.args.extensions) > 0:
extensions_module.load_extensions() extensions_module.load_extensions()
@ -562,6 +580,7 @@ def create_interface():
else: else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
create_interface() create_interface()
while True: while True: