Minor logging improvements

This commit is contained in:
oobabooga 2024-02-06 08:22:08 -08:00
parent 3add2376cd
commit 4e34ae0587
2 changed files with 6 additions and 6 deletions

View file

@ -54,7 +54,7 @@ sampler_hijack.hijack_samplers()
def load_model(model_name, loader=None): def load_model(model_name, loader=None):
logger.info(f"Loading {model_name}") logger.info(f"Loading \"{model_name}\"")
t0 = time.time() t0 = time.time()
shared.is_seq2seq = False shared.is_seq2seq = False
@ -246,7 +246,7 @@ def llamacpp_loader(model_name):
else: else:
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf'))[0] model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf'))[0]
logger.info(f"llama.cpp weights detected: {model_file}") logger.info(f"llama.cpp weights detected: \"{model_file}\"")
model, tokenizer = LlamaCppModel.from_pretrained(model_file) model, tokenizer = LlamaCppModel.from_pretrained(model_file)
return model, tokenizer return model, tokenizer
@ -257,7 +257,7 @@ def llamacpp_HF_loader(model_name):
for fname in [model_name, "oobabooga_llama-tokenizer", "llama-tokenizer"]: for fname in [model_name, "oobabooga_llama-tokenizer", "llama-tokenizer"]:
path = Path(f'{shared.args.model_dir}/{fname}') path = Path(f'{shared.args.model_dir}/{fname}')
if all((path / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']): if all((path / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']):
logger.info(f'Using tokenizer from: {path}') logger.info(f'Using tokenizer from: \"{path}\"')
break break
else: else:
logger.error("Could not load the model because a tokenizer in transformers format was not found. Please download oobabooga/llama-tokenizer.") logger.error("Could not load the model because a tokenizer in transformers format was not found. Please download oobabooga/llama-tokenizer.")
@ -298,7 +298,7 @@ def ctransformers_loader(model_name):
logger.error("Could not find a model for ctransformers.") logger.error("Could not find a model for ctransformers.")
return None, None return None, None
logger.info(f'ctransformers weights detected: {model_file}') logger.info(f'ctransformers weights detected: \"{model_file}\"')
model, tokenizer = ctrans.from_pretrained(model_file) model, tokenizer = ctrans.from_pretrained(model_file)
return model, tokenizer return model, tokenizer
@ -393,7 +393,7 @@ def HQQ_loader(model_name):
from hqq.core.quantize import HQQBackend, HQQLinear from hqq.core.quantize import HQQBackend, HQQLinear
from hqq.engine.hf import HQQModelForCausalLM from hqq.engine.hf import HQQModelForCausalLM
logger.info(f"Loading HQQ model with backend: {shared.args.hqq_backend}") logger.info(f"Loading HQQ model with backend: \"{shared.args.hqq_backend}\"")
model_dir = Path(f'{shared.args.model_dir}/{model_name}') model_dir = Path(f'{shared.args.model_dir}/{model_name}')
model = HQQModelForCausalLM.from_quantized(str(model_dir)) model = HQQModelForCausalLM.from_quantized(str(model_dir))

View file

@ -187,7 +187,7 @@ if __name__ == "__main__":
settings_file = Path('settings.json') settings_file = Path('settings.json')
if settings_file is not None: if settings_file is not None:
logger.info(f"Loading settings from {settings_file}") logger.info(f"Loading settings from \"{settings_file}\"")
file_contents = open(settings_file, 'r', encoding='utf-8').read() file_contents = open(settings_file, 'r', encoding='utf-8').read()
new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents) new_settings = json.loads(file_contents) if settings_file.suffix == "json" else yaml.safe_load(file_contents)
shared.settings.update(new_settings) shared.settings.update(new_settings)