Better warning messages

This commit is contained in:
oobabooga 2023-05-03 21:43:17 -03:00
parent 0a48b29cd8
commit 95d04d6a8d
13 changed files with 194 additions and 83 deletions

View file

@ -1,4 +1,5 @@
import inspect
import logging
import re
import sys
from pathlib import Path
@ -71,7 +72,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
del layers
print('Loading model ...')
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint), strict=False)
@ -90,8 +90,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
quant.autotune_warmup_fused(model)
model.seqlen = 2048
print('Done.')
return model
@ -119,11 +117,13 @@ def find_quantized_model_file(model_name):
if len(found_pts) > 0:
if len(found_pts) > 1:
print('Warning: more than one .pt model has been found. The last one will be selected. It could be wrong.')
logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.')
pt_path = found_pts[-1]
elif len(found_safetensors) > 0:
if len(found_safetensors) > 1:
print('Warning: more than one .safetensors model has been found. The last one will be selected. It could be wrong.')
logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.')
pt_path = found_safetensors[-1]
return pt_path
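
The enclosing function is only partially visible in this hunk; it globs the model folder for candidate files and falls back to the last match. A simplified, standalone sketch of that fallback selection (assuming found_pts and found_safetensors are built as their names suggest, and matching the corrected safetensors count check above):

    import logging
    from pathlib import Path

    def pick_quantized_file(path_to_model: Path):
        # Illustrative reconstruction of the fallback branch shown above;
        # the rest of the real function is not visible in this diff.
        found_pts = list(path_to_model.glob('*.pt'))
        found_safetensors = list(path_to_model.glob('*.safetensors'))
        pt_path = None
        if len(found_pts) > 0:
            if len(found_pts) > 1:
                logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.')
            pt_path = found_pts[-1]
        elif len(found_safetensors) > 0:
            if len(found_safetensors) > 1:
                logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.')
            pt_path = found_safetensors[-1]
        return pt_path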
@ -142,8 +142,7 @@ def load_quantized(model_name):
elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
model_type = 'gptj'
else:
print("Can't determine model type from model name. Please specify it manually using --model_type "
"argument")
logging.error("Can't determine model type from model name. Please specify it manually using --model_type argument")
exit()
else:
model_type = shared.args.model_type.lower()
@ -153,20 +152,21 @@ def load_quantized(model_name):
load_quant = llama_inference_offload.load_quant
elif model_type in ('llama', 'opt', 'gptj'):
if shared.args.pre_layer:
print("Warning: ignoring --pre_layer because it only works for llama model type.")
logging.warning("Ignoring --pre_layer because it only works for llama model type.")
load_quant = _load_quant
else:
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
logging.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit()
# Find the quantized model weights file (.pt/.safetensors)
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
pt_path = find_quantized_model_file(model_name)
if not pt_path:
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
logging.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
exit()
else:
print(f"Found the following quantized model: {pt_path}")
logging.info(f"Found the following quantized model: {pt_path}")
# qwopqwop200's offload
if model_type == 'llama' and shared.args.pre_layer:
@ -188,7 +188,7 @@ def load_quantized(model_name):
max_memory = accelerate.utils.get_balanced_memory(model)
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
print("Using the following device map for the quantized model:", device_map)
logging.info("Using the following device map for the quantized model:", device_map)
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
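
A note on the device-map line above: unlike print, logging does not join extra positional arguments onto the message; they are consumed as %-style formatting data, so a print-style comma list loses the mapping. A minimal illustration (not part of the commit):

    import logging

    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
    device_map = {'model.embed_tokens': 0, 'lm_head': 'cpu'}  # made-up example mapping

    # print-style comma arguments do not work here; the dict is swallowed as
    # %-formatting data and never shows up in the output:
    # logging.info("Using the following device map for the quantized model:", device_map)

    # Either lazy %-formatting or an f-string logs the mapping as intended:
    logging.info("Using the following device map for the quantized model: %s", device_map)
    logging.info(f"Using the following device map for the quantized model: {device_map}")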

View file

@ -1,3 +1,4 @@
import logging
from pathlib import Path
import torch
@ -18,7 +19,7 @@ def add_lora_to_model(lora_names):
# Add a LoRA when another LoRA is already present
if len(removed_set) == 0 and len(prior_set) > 0:
print(f"Adding the LoRA(s) named {added_set} to the model...")
logging.info(f"Adding the LoRA(s) named {added_set} to the model...")
for lora in added_set:
shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
@ -29,7 +30,7 @@ def add_lora_to_model(lora_names):
shared.model.disable_adapter()
if len(lora_names) > 0:
print("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
logging.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
params = {}
if not shared.args.cpu:
params['dtype'] = shared.model.dtype
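
The added_set/removed_set bookkeeping above compares the LoRAs already applied with the newly requested list; the construction itself sits outside this hunk. A minimal sketch of that set arithmetic with made-up adapter names:

    prior_set = {'some-lora'}                    # LoRAs currently applied (hypothetical)
    lora_names = ['some-lora', 'another-lora']   # LoRAs requested now (hypothetical)

    added_set = set(lora_names) - prior_set      # new adapters to load on top
    removed_set = prior_set - set(lora_names)    # adapters that would have to go

    print(added_set, removed_set)                # {'another-lora'} set()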

View file

@ -3,6 +3,7 @@ import base64
import copy
import io
import json
import logging
import re
from datetime import datetime
from pathlib import Path
@ -138,7 +139,7 @@ def extract_message_from_reply(reply, state):
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
if shared.model_name == 'None' or shared.model is None:
print("No model is loaded! Select one in the Model tab.")
logging.error("No model is loaded! Select one in the Model tab.")
yield shared.history['visible']
return
@ -216,7 +217,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
def impersonate_wrapper(text, state):
if shared.model_name == 'None' or shared.model is None:
print("No model is loaded! Select one in the Model tab.")
logging.error("No model is loaded! Select one in the Model tab.")
yield ''
return
@ -523,7 +524,7 @@ def upload_character(json_file, img, tavern=False):
img = Image.open(io.BytesIO(img))
img.save(Path(f'characters/{outfile_name}.png'))
print(f'New character saved to "characters/{outfile_name}.json".')
logging.info(f'New character saved to "characters/{outfile_name}.json".')
return outfile_name
@ -547,6 +548,6 @@ def upload_your_profile_picture(img, name1, name2, mode):
else:
img = make_thumbnail(img)
img.save(Path('cache/pfp_me.png'))
print('Profile picture saved to "cache/pfp_me.png"')
logging.info('Profile picture saved to "cache/pfp_me.png"')
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)

View file

@ -1,3 +1,4 @@
import logging
import traceback
from functools import partial
@ -28,7 +29,7 @@ def load_extensions():
for i, name in enumerate(shared.args.extensions):
if name in available_extensions:
if name != 'api':
print(f'Loading the extension "{name}"... ', end='')
logging.info(f'Loading the extension "{name}"...')
try:
exec(f"import extensions.{name}.script")
extension = getattr(extensions, name).script
@ -38,12 +39,8 @@ def load_extensions():
extension.setup()
state[name] = [True, i]
if name != 'api':
print('Ok.')
except:
if name != 'api':
print('Fail.')
logging.error(f'Failed to load the extension "{name}".')
traceback.print_exc()
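
For reference, the dynamic import done with exec() above can be written with importlib; this is just an equivalent sketch of the mechanism, not the repository's code:

    import importlib
    import logging
    import traceback

    def load_extension_module(name):
        # Illustrative equivalent of exec(f"import extensions.{name}.script")
        # followed by getattr(extensions, name).script.
        try:
            return importlib.import_module(f'extensions.{name}.script')
        except Exception:
            logging.error(f'Failed to load the extension "{name}".')
            traceback.print_exc()
            return None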

View file

@ -1,29 +1,28 @@
import logging
import math
import sys
from typing import Optional, Tuple
import torch
import torch.nn as nn
import transformers.models.llama.modeling_llama
from typing import Optional
from typing import Tuple
import modules.shared as shared
if shared.args.xformers:
try:
import xformers.ops
except Exception:
print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
logging.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
def hijack_llama_attention():
if shared.args.xformers:
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
print("Replaced attention with xformers_attention")
logging.info("Replaced attention with xformers_attention")
elif shared.args.sdp_attention:
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
print("Replaced attention with sdp_attention")
logging.info("Replaced attention with sdp_attention")
def xformers_forward(
@ -55,16 +54,14 @@ def xformers_forward(
past_key_value = (key_states, value_states) if use_cache else None
#We only apply xformers optimizations if we don't need to output the whole attention matrix
# We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions:
dtype = query_states.dtype
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
@ -102,9 +99,7 @@ def xformers_forward(
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
@ -137,7 +132,7 @@ def sdp_attention_forward(
past_key_value = (key_states, value_states) if use_cache else None
#We only apply sdp attention if we don't need to output the whole attention matrix
# We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions:
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
attn_weights = None
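
The "nasty hack" comment above relies on the additive mask format used by transformers: 0 where attention is allowed and a large negative value where it is blocked, so an element strictly above the diagonal is non-zero exactly when the mask is the causal one. A small standalone check (illustration only):

    import torch

    q_len = 4
    min_val = torch.finfo(torch.float16).min

    # Additive causal mask: 0 on and below the diagonal, a large negative
    # value strictly above it (future positions are blocked).
    causal_mask = torch.triu(torch.full((q_len, q_len), min_val, dtype=torch.float16), diagonal=1)
    causal_mask = causal_mask[None, None, :, :]              # shape (bsz, 1, q_len, kv_len)

    all_zeros_mask = torch.zeros(1, 1, q_len, q_len, dtype=torch.float16)

    print(causal_mask[0, 0, 0, 1].item())     # large negative -> real causal mask, keep it
    print(all_zeros_mask[0, 0, 0, 1].item())  # 0.0 -> mask adds no constraint, can be dropped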

modules/logging_colors.py (new file, 109 lines added)
View file

@ -0,0 +1,109 @@
#!/usr/bin/env python
# encoding: utf-8
import logging
# now we patch Python code to add color support to logging.StreamHandler
def add_coloring_to_emit_windows(fn):
# add methods we need to the class
def _out_handle(self):
import ctypes
return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
out_handle = property(_out_handle)
def _set_color(self, code):
import ctypes
# Constants from the Windows API
self.STD_OUTPUT_HANDLE = -11
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
setattr(logging.StreamHandler, '_set_color', _set_color)
def new(*args):
FOREGROUND_BLUE = 0x0001 # text color contains blue.
FOREGROUND_GREEN = 0x0002 # text color contains green.
FOREGROUND_RED = 0x0004 # text color contains red.
FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
# winbase.h
# STD_INPUT_HANDLE = -10
# STD_OUTPUT_HANDLE = -11
# STD_ERROR_HANDLE = -12
# wincon.h
# FOREGROUND_BLACK = 0x0000
FOREGROUND_BLUE = 0x0001
FOREGROUND_GREEN = 0x0002
# FOREGROUND_CYAN = 0x0003
FOREGROUND_RED = 0x0004
FOREGROUND_MAGENTA = 0x0005
FOREGROUND_YELLOW = 0x0006
# FOREGROUND_GREY = 0x0007
FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
# BACKGROUND_BLACK = 0x0000
# BACKGROUND_BLUE = 0x0010
# BACKGROUND_GREEN = 0x0020
# BACKGROUND_CYAN = 0x0030
# BACKGROUND_RED = 0x0040
# BACKGROUND_MAGENTA = 0x0050
BACKGROUND_YELLOW = 0x0060
# BACKGROUND_GREY = 0x0070
BACKGROUND_INTENSITY = 0x0080 # background color is intensified.
levelno = args[1].levelno
if (levelno >= 50):
color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
elif (levelno >= 40):
color = FOREGROUND_RED | FOREGROUND_INTENSITY
elif (levelno >= 30):
color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
elif (levelno >= 20):
color = FOREGROUND_GREEN
elif (levelno >= 10):
color = FOREGROUND_MAGENTA
else:
color = FOREGROUND_WHITE
args[0]._set_color(color)
ret = fn(*args)
args[0]._set_color(FOREGROUND_WHITE)
# print "after"
return ret
return new
def add_coloring_to_emit_ansi(fn):
# add methods we need to the class
def new(*args):
levelno = args[1].levelno
if (levelno >= 50):
color = '\x1b[31m' # red
elif (levelno >= 40):
color = '\x1b[31m' # red
elif (levelno >= 30):
color = '\x1b[33m' # yellow
elif (levelno >= 20):
color = '\x1b[32m' # green
elif (levelno >= 10):
color = '\x1b[35m' # pink
else:
color = '\x1b[0m' # normal
args[1].msg = color + args[1].msg + '\x1b[0m' # normal
# print "after"
return fn(*args)
return new
import platform
if platform.system() == 'Windows':
# Windows does not support ANSI escapes and we are using API calls to set the console color
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
else:
# all non-Windows platforms are supporting ANSI escapes so we use them
logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
# log = logging.getLogger()
# log.addFilter(log_filter())
# //hdlr = logging.StreamHandler()
# //hdlr.setFormatter(formatter())
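
A minimal sketch of how this module is picked up (mirroring what server.py does further down): importing it is enough, because the patch replaces logging.StreamHandler.emit at import time, so any handler created afterwards prints level-colored messages.

    import logging

    import modules.logging_colors  # noqa: F401 -- the import itself applies the patch

    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
    logging.info('an info message')       # green on ANSI terminals
    logging.warning('a warning message')  # yellow
    logging.error('an error message')     # red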

View file

@ -1,5 +1,6 @@
import gc
import json
import logging
import os
import re
import time
@ -65,7 +66,7 @@ def find_model_type(model_name):
def load_model(model_name):
print(f"Loading {model_name}...")
logging.info(f"Loading {model_name}...")
t0 = time.time()
shared.model_type = find_model_type(model_name)
@ -116,7 +117,7 @@ def load_model(model_name):
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
model.module.eval() # Inference
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
logging.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
# RMKV model (not on HuggingFace)
elif shared.model_type == 'rwkv':
@ -137,7 +138,7 @@ def load_model(model_name):
else:
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
print(f"llama.cpp weights detected: {model_file}\n")
logging.info(f"llama.cpp weights detected: {model_file}\n")
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
return model, tokenizer
@ -146,7 +147,7 @@ def load_model(model_name):
# Monkey patch
if shared.args.monkey_patch:
print("Warning: applying the monkey patch for using LoRAs in 4-bit mode.\nIt may cause undefined behavior outside its intended scope.")
logging.warning("Warning: applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
from modules.monkey_patch_gptq_lora import load_model_llama
model, _ = load_model_llama(model_name)
@ -161,7 +162,7 @@ def load_model(model_name):
else:
params = {"low_cpu_mem_usage": True}
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
logging.warning("Warning: torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
shared.args.cpu = True
if shared.args.cpu:
@ -184,6 +185,7 @@ def load_model(model_name):
max_memory = {}
for i in range(len(memory_map)):
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
max_memory['cpu'] = max_cpu_memory
params['max_memory'] = max_memory
elif shared.args.auto_devices:
@ -191,9 +193,9 @@ def load_model(model_name):
suggestion = round((total_mem - 1000) / 1000) * 1000
if total_mem - suggestion < 800:
suggestion -= 1000
suggestion = int(round(suggestion / 1000))
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
suggestion = int(round(suggestion / 1000))
logging.warning(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
params['max_memory'] = max_memory
@ -201,11 +203,11 @@ def load_model(model_name):
params["offload_folder"] = shared.args.disk_cache_dir
checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
config = AutoConfig.from_pretrained(checkpoint)
with init_empty_weights():
model = LoaderClass.from_config(config)
model.tie_weights()
params['device_map'] = infer_auto_device_map(
model,
@ -230,7 +232,7 @@ def load_model(model_name):
if shared.model_type != 'llava':
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
if p.exists():
print(f"Loading the universal LLaMA tokenizer from {p}...")
logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
break
@ -247,7 +249,7 @@ def load_model(model_name):
else:
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=trust_remote_code)
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer
@ -276,20 +278,20 @@ def load_soft_prompt(name):
zf.extract('tensor.npy')
zf.extract('meta.json')
j = json.loads(open('meta.json', 'r').read())
print(f"\nLoading the softprompt \"{name}\".")
logging.info(f"\nLoading the softprompt \"{name}\".")
for field in j:
if field != 'name':
if type(j[field]) is list:
print(f"{field}: {', '.join(j[field])}")
logging.info(f"{field}: {', '.join(j[field])}")
else:
print(f"{field}: {j[field]}")
print()
logging.info(f"{field}: {j[field]}")
tensor = np.load('tensor.npy')
Path('tensor.npy').unlink()
Path('meta.json').unlink()
tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
shared.soft_prompt = True
shared.soft_prompt_tensor = tensor
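
One detail from the --gpu-memory handling earlier in this file: values without a unit suffix get 'GiB' appended, while values already ending in 'ib' (for example '3500MiB') pass through unchanged. A standalone illustration of that normalization (not the repository's code):

    import re

    def normalize_gpu_memory(memory_map, max_cpu_memory):
        # memory_map is the list of per-GPU values from --gpu-memory,
        # e.g. ['12', '3500MiB']; max_cpu_memory comes from --cpu-memory.
        max_memory = {}
        for i in range(len(memory_map)):
            max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
        max_memory['cpu'] = max_cpu_memory
        return max_memory

    print(normalize_gpu_memory(['12', '3500MiB'], '64GiB'))
    # {0: '12GiB', 1: '3500MiB', 'cpu': '64GiB'}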

View file

@ -17,6 +17,7 @@ from modules.GPTQ_loader import find_quantized_model_file
replace_peft_model_with_gptq_lora_model()
def load_model_llama(model_name):
config_path = str(Path(f'{shared.args.model_dir}/{model_name}'))
model_path = str(find_quantized_model_file(model_name))

View file

@ -1,4 +1,5 @@
import argparse
import logging
from pathlib import Path
import yaml
@ -170,19 +171,19 @@ args_defaults = parser.parse_args([])
deprecated_dict = {}
for k in deprecated_dict:
if getattr(args, k) != deprecated_dict[k][1]:
print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.\n")
logging.warning(f"--{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
setattr(args, deprecated_dict[k][0], getattr(args, k))
# Deprecation warnings for parameters that have been removed
if args.cai_chat:
print("Warning: --cai-chat is deprecated. Use --chat instead.\n")
logging.warning("--cai-chat is deprecated. Use --chat instead.")
args.chat = True
# Security warnings
if args.trust_remote_code:
print("Warning: trust_remote_code is enabled. This is dangerous.\n")
logging.warning("trust_remote_code is enabled. This is dangerous.")
if args.share:
print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. This is potentially dangerous.\n")
logging.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
# Activating the API extension
if args.api or args.public_api:

View file

@ -1,4 +1,5 @@
import ast
import logging
import random
import re
import time
@ -175,7 +176,7 @@ def get_generate_params(state):
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.model_name == 'None' or shared.model is None:
print("No model is loaded! Select one in the Model tab.")
logging.error("No model is loaded! Select one in the Model tab.")
yield formatted_outputs(question, shared.model_name)
return

View file

@ -1,4 +1,5 @@
import json
import logging
import math
import sys
import threading
@ -40,7 +41,6 @@ WANT_INTERRUPT = False
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer"]
def get_datasets(path: str, ext: str):
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
@ -123,7 +123,7 @@ def create_train_interface():
stop_evaluation = gr.Button("Interrupt")
with gr.Column():
evaluation_log = gr.Markdown(value = '')
evaluation_log = gr.Markdown(value='')
evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
save_comments = gr.Button('Save comments')
@ -220,13 +220,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if model_type == "PeftModelForCausalLM":
if len(shared.args.lora_names) > 0:
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
logging.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
else:
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.")
logging.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
else:
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
logging.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
time.sleep(5)
if shared.args.wbits > 0 and not shared.args.monkey_patch:
@ -235,7 +236,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
logging.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
time.sleep(2) # Give it a moment for the message to show in UI before continuing
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
@ -255,7 +256,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']:
print("Loading raw text file dataset...")
logging.info("Loading raw text file dataset...")
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read()
@ -299,7 +300,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
prompt = generate_prompt(data_point)
return tokenize(prompt)
print("Loading JSON datasets...")
logging.info("Loading JSON datasets...")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
train_data = data['train'].map(generate_and_tokenize_prompt)
@ -311,10 +312,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
print("Getting model ready...")
logging.info("Getting model ready...")
prepare_model_for_int8_training(shared.model)
print("Prepping for training...")
logging.info("Prepping for training...")
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
@ -325,10 +326,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
)
try:
print("Creating LoRA model...")
logging.info("Creating LoRA model...")
lora_model = get_peft_model(shared.model, config)
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
print("Loading existing LoRA data...")
logging.info("Loading existing LoRA data...")
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
set_peft_model_state_dict(lora_model, state_dict_peft)
except:
@ -406,7 +407,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
json.dump({x: vars[x] for x in PARAMETERS}, file)
# == Main run and monitor loop ==
print("Starting training...")
logging.info("Starting training...")
yield "Starting..."
if WANT_INTERRUPT:
yield "Interrupted before start."
@ -416,7 +417,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
trainer.train()
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
lora_model.save_pretrained(lora_file_path)
print("LoRA training run is completed and saved.")
logging.info("LoRA training run is completed and saved.")
tracked.did_save = True
thread = threading.Thread(target=threaded_run)
@ -448,14 +449,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# Saving in the train thread might fail if an error occurs, so save here if so.
if not tracked.did_save:
print("Training complete, saving...")
logging.info("Training complete, saving...")
lora_model.save_pretrained(lora_file_path)
if WANT_INTERRUPT:
print("Training interrupted.")
logging.info("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
else:
print("Training complete!")
logging.info("Training complete!")
yield f"Done! LoRA saved to `{lora_file_path}`"

View file

@ -25,6 +25,7 @@ theme = gr.themes.Default(
background_fill_secondary='#eaeaea'
)
def list_model_elements():
elements = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'wbits', 'groupsize', 'model_type', 'pre_layer']
for i in range(torch.cuda.device_count()):

View file

@ -1,14 +1,17 @@
import logging
import os
import requests
import warnings
import modules.logging_colors
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
# This is a hack to prevent Gradio from phoning home when it gets imported
def my_get(url, **kwargs):
print('Gradio HTTP request redirected to localhost :)')
logging.info('Gradio HTTP request redirected to localhost :)')
kwargs.setdefault('allow_redirects', True)
return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
@ -17,9 +20,8 @@ requests.get = my_get
import gradio as gr
requests.get = original_get
# This fixes LaTeX rendering on some systems
import matplotlib
matplotlib.use('Agg')
matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
import importlib
import io
@ -39,7 +41,6 @@ import psutil
import torch
import yaml
from PIL import Image
import modules.extensions as extensions_module
from modules import chat, shared, training, ui
from modules.html_generator import chat_html_wrapper
@ -860,7 +861,7 @@ if __name__ == "__main__":
elif Path('settings.json').exists():
settings_file = Path('settings.json')
if settings_file is not None:
print(f"Loading settings from {settings_file}...")
logging.info(f"Loading settings from {settings_file}...")
new_settings = json.loads(open(settings_file, 'r').read())
for item in new_settings:
shared.settings[item] = new_settings[item]
@ -891,7 +892,7 @@ if __name__ == "__main__":
# Select the model from a command-line menu
elif shared.args.model_menu:
if len(available_models) == 0:
print('No models are available! Please download at least one.')
logging.error('No models are available! Please download at least one.')
sys.exit(0)
else:
print('The following models are available:\n')
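
A closing note on the requests.get swap near the top of server.py: the pattern is patch, import, restore. A generic sketch of the same idea wrapped in a context manager (illustration only, not the repository's code):

    import contextlib
    import requests

    @contextlib.contextmanager
    def redirect_get(url='http://127.0.0.1/'):
        # Temporarily replace requests.get so any HTTP GET issued while the
        # block runs (e.g. a library's import-time telemetry) hits localhost.
        original_get = requests.get

        def patched_get(_url, **kwargs):
            kwargs.setdefault('allow_redirects', True)
            return requests.api.request('get', url, **kwargs)

        requests.get = patched_get
        try:
            yield
        finally:
            requests.get = original_get

    # Usage (sketch):
    # with redirect_get():
    #     import gradio as gr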