Improve several log messages

oobabooga 2023-12-19 20:54:32 -08:00
parent 23818dc098
commit 9992f7d8c0
7 changed files with 37 additions and 28 deletions


@@ -126,7 +126,7 @@ def load_quantized(model_name):
     path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
     pt_path = find_quantized_model_file(model_name)
     if not pt_path:
-        logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
+        logger.error("Could not find the quantized model in .pt or .safetensors format. Exiting.")
         exit()
     else:
         logger.info(f"Found the following quantized model: {pt_path}")


@@ -138,7 +138,7 @@ def add_lora_transformers(lora_names):
 
     # Add a LoRA when another LoRA is already present
     if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys():
-        logger.info(f"Adding the LoRA(s) named {added_set} to the model...")
+        logger.info(f"Adding the LoRA(s) named {added_set} to the model")
         for lora in added_set:
             shared.model.load_adapter(get_lora_path(lora), lora)
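For context, load_adapter here is PEFT's mechanism for attaching an additional adapter to a model that already carries one. A minimal sketch of that flow, with hypothetical model and adapter paths (the repo wraps this in get_lora_path and shared state):

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Hypothetical names and paths, for illustration only.
base = AutoModelForCausalLM.from_pretrained('my-base-model')
model = PeftModel.from_pretrained(base, 'loras/first-lora', adapter_name='first-lora')

# What the loop above does for each newly requested LoRA:
model.load_adapter('loras/second-lora', adapter_name='second-lora')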


@@ -31,7 +31,7 @@ def load_extensions():
     for i, name in enumerate(shared.args.extensions):
         if name in available_extensions:
             if name != 'api':
-                logger.info(f'Loading the extension "{name}"...')
+                logger.info(f'Loading the extension "{name}"')
             try:
                 try:
                     exec(f"import extensions.{name}.script")


@@ -54,7 +54,7 @@ sampler_hijack.hijack_samplers()
 
 
 def load_model(model_name, loader=None):
-    logger.info(f"Loading {model_name}...")
+    logger.info(f"Loading {model_name}")
     t0 = time.time()
 
     shared.is_seq2seq = False


@@ -204,22 +204,26 @@ for arg in sys.argv[1:]:
     if hasattr(args, arg):
         provided_arguments.append(arg)
 
-# Deprecation warnings
 deprecated_args = ['notebook', 'chat', 'no_stream', 'mul_mat_q', 'use_fast']
-for k in deprecated_args:
-    if getattr(args, k):
-        logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
 
-# Security warnings
-if args.trust_remote_code:
-    logger.warning('trust_remote_code is enabled. This is dangerous.')
 
-if 'COLAB_GPU' not in os.environ and not args.nowebui:
-    if args.share:
-        logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
+def do_cmd_flags_warnings():
 
-    if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
-        logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
+    # Deprecation warnings
+    for k in deprecated_args:
+        if getattr(args, k):
+            logger.warning(f'The --{k} flag has been deprecated and will be removed soon. Please remove that flag.')
 
-    if args.multi_user:
-        logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
+    # Security warnings
+    if args.trust_remote_code:
+        logger.warning('trust_remote_code is enabled. This is dangerous.')
+
+    if 'COLAB_GPU' not in os.environ and not args.nowebui:
+        if args.share:
+            logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.")
+
+        if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)):
+            logger.warning("\nYou are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.")
+
+        if args.multi_user:
+            logger.warning('\nThe multi-user mode is highly experimental and should not be shared publicly.')
 
 
 def fix_loader_name(name):
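The substantive change in this file: the deprecation and security warnings previously executed at module level, that is, every time modules.shared was imported. Wrapping them in do_cmd_flags_warnings() defers them until the entry point explicitly asks for them. A minimal sketch of the pattern, with hypothetical names:

import logging

logger = logging.getLogger(__name__)

DEPRECATED_FLAGS = ['notebook', 'chat']  # hypothetical stand-ins for deprecated_args


def warn_about_flags(args):
    # Module-level code runs on import; inside a function, the warnings fire
    # only when the entry point calls this explicitly (cf. server.py below).
    for flag in DEPRECATED_FLAGS:
        if getattr(args, flag, False):
            logger.warning(f'The --{flag} flag has been deprecated. Please remove it.')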


@@ -249,7 +249,7 @@ def backup_adapter(input_folder):
     adapter_file = Path(f"{input_folder}/adapter_model.bin")
     if adapter_file.is_file():
-        logger.info("Backing up existing LoRA adapter...")
+        logger.info("Backing up existing LoRA adapter")
 
         creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
         creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
@@ -406,7 +406,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
     # == Prep the dataset, format, etc ==
     if raw_text_file not in ['None', '']:
         train_template["template_type"] = "raw_text"
-        logger.info("Loading raw text file dataset...")
+        logger.info("Loading raw text file dataset")
         fullpath = clean_path('training/datasets', f'{raw_text_file}')
         fullpath = Path(fullpath)
         if fullpath.is_dir():
@@ -486,7 +486,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
             prompt = generate_prompt(data_point)
             return tokenize(prompt, add_eos_token)
 
-        logger.info("Loading JSON datasets...")
+        logger.info("Loading JSON datasets")
         data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
         train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
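The new_fingerprint idiom on the last context line above deserves a note: the datasets library caches .map() results keyed by a fingerprint, so passing a fresh random one forces re-tokenization instead of reusing a possibly stale cache (my reading of the idiom; the repo does not comment on it):

import random

# A 30-digit random hex string; as a .map(new_fingerprint=...) value it never
# collides with a previously cached fingerprint, so the map is recomputed.
fingerprint = '%030x' % random.randrange(16**30)
assert len(fingerprint) == 30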
@@ -516,13 +516,13 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
 
     # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
-        logger.info("Getting model ready...")
+        logger.info("Getting model ready")
         prepare_model_for_kbit_training(shared.model)
 
     # base model is now frozen and should not be reused for any other LoRA training than this one
     shared.model_dirty_from_training = True
 
-    logger.info("Preparing for training...")
+    logger.info("Preparing for training")
     config = LoraConfig(
         r=lora_rank,
         lora_alpha=lora_alpha,
@@ -540,10 +540,10 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
     model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
 
     try:
-        logger.info("Creating LoRA model...")
+        logger.info("Creating LoRA model")
         lora_model = get_peft_model(shared.model, config)
         if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
-            logger.info("Loading existing LoRA data...")
+            logger.info("Loading existing LoRA data")
             state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
             set_peft_model_state_dict(lora_model, state_dict_peft)
     except:
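One detail in this hunk's context lines: torch.load(..., weights_only=True) restricts unpickling to tensors and primitive containers, which mitigates the arbitrary-code-execution risk of loading an untrusted checkpoint. An illustrative call with a hypothetical path:

import torch

# weights_only=True refuses to unpickle arbitrary Python objects, so a
# malicious adapter_model.bin cannot execute code at load time.
state_dict = torch.load('path/to/adapter_model.bin', weights_only=True)  # hypothetical path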
@@ -648,7 +648,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
         json.dump(train_template, file, indent=2)
 
     # == Main run and monitor loop ==
-    logger.info("Starting training...")
+    logger.info("Starting training")
     yield "Starting..."
 
     lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
@@ -730,7 +730,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
     # Saving in the train thread might fail if an error occurs, so save here if so.
     if not tracked.did_save:
-        logger.info("Training complete, saving...")
+        logger.info("Training complete, saving")
         lora_model.save_pretrained(lora_file_path)
 
     if WANT_INTERRUPT:
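calc_trainable_parameters, called in two of the hunks above, is not shown in this diff. A sketch of what such a (trainable, total) parameter counter typically looks like in PyTorch; the repo's actual implementation may differ in details:

def calc_trainable_parameters(model):
    # Count (trainable, total) parameters, a common PyTorch idiom.
    trainable, total = 0, 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return trainable, total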


@@ -12,6 +12,7 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
 warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
 warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated')
 warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_name" has conflict')
+warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()')
 
 with RequestBlocker():
     import gradio as gr
@@ -54,6 +55,7 @@ from modules.models_settings import (
     get_model_metadata,
     update_model_parameters
 )
+from modules.shared import do_cmd_flags_warnings
 from modules.utils import gradio
 
@@ -170,6 +172,9 @@ def create_interface():
 
 
 if __name__ == "__main__":
+    logger.info("Starting Text generation web UI")
+    do_cmd_flags_warnings()
+
     # Load custom settings
     settings_file = None
     if shared.args.settings is not None and Path(shared.args.settings).exists():
@@ -180,7 +185,7 @@ if __name__ == "__main__":
             settings_file = Path('settings.json')
 
     if settings_file is not None:
-        logger.info(f"Loading settings from {settings_file}...")
+        logger.info(f"Loading settings from {settings_file}")
         file_contents = open(settings_file, 'r', encoding='utf-8').read()
         new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents)
         shared.settings.update(new_settings)
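One pre-existing quirk visible in this hunk's context lines (untouched by the commit): settings_file.suffix == "json" can never be true, because Path.suffix keeps the leading dot. JSON settings therefore fall through to yaml.safe_load, which still parses them in practice since YAML is essentially a superset of JSON:

from pathlib import Path

print(Path('settings.json').suffix)            # '.json', note the leading dot
print(Path('settings.json').suffix == 'json')  # False, so the json.loads branch is dead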