From 63de9eb24ff7e90c3d0e34988910f35046d65c28 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 24 Sep 2023 20:23:05 -0700
Subject: [PATCH] Clean up the transformers loader

---
 modules/models.py | 67 ++++++++++++++++++++++-------------------
 1 file changed, 32 insertions(+), 35 deletions(-)

diff --git a/modules/models.py b/modules/models.py
index c052514b..c0d867b7 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -2,6 +2,7 @@ import gc
 import os
 import re
 import time
+import traceback
 from pathlib import Path
 
 import torch
@@ -117,12 +118,17 @@ def load_tokenizer(model_name, model):
 
 def huggingface_loader(model_name):
     path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
-    config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
+    params = {
+        'low_cpu_mem_usage': True,
+        'trust_remote_code': shared.args.trust_remote_code,
+        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16
+    }
 
+    config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code'])
     if 'chatglm' in model_name.lower():
         LoaderClass = AutoModel
     else:
-        if config.to_dict().get("is_encoder_decoder", False):
+        if config.to_dict().get('is_encoder_decoder', False):
             LoaderClass = AutoModelForSeq2SeqLM
             shared.is_seq2seq = True
         else:
@@ -130,7 +136,7 @@ def huggingface_loader(model_name):
 
     # Load the model in simple 16-bit mode by default
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama]):
-        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
+        model = LoaderClass.from_pretrained(path_to_model, **params)
         if torch.backends.mps.is_available():
             device = torch.device('mps')
             model = model.to(device)
@@ -139,28 +145,23 @@ def huggingface_loader(model_name):
 
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
-        model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+        model = LoaderClass.from_pretrained(path_to_model, torch_dtype=params['torch_dtype'])
         model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
         model.module.eval()  # Inference
-        logger.info(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
+        logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
 
-    # Custom
+    # Load with quantization and/or offloading
     else:
-        params = {
-            "low_cpu_mem_usage": True,
-            "trust_remote_code": shared.args.trust_remote_code
-        }
-
         if not any((shared.args.cpu, torch.cuda.is_available(), torch.backends.mps.is_available())):
             logger.warning('torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
             shared.args.cpu = True
 
         if shared.args.cpu:
-            params["torch_dtype"] = torch.float32
+            params['torch_dtype'] = torch.float32
         else:
-            params["device_map"] = 'auto'
+            params['device_map'] = 'auto'
+            params['max_memory'] = get_max_memory_dict()
             if shared.args.load_in_4bit:
-
                 # See https://github.com/huggingface/transformers/pull/23479/files
                 # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
                 quantization_config_params = {
@@ -170,7 +171,7 @@ def huggingface_loader(model_name):
                     'bnb_4bit_use_double_quant': shared.args.use_double_quant,
                 }
 
-                logger.warning("Using the following 4-bit params: " + str(quantization_config_params))
+                logger.info('Using the following 4-bit params: ' + str(quantization_config_params))
                 params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
 
             elif shared.args.load_in_8bit:
@@ -178,14 +179,21 @@ def huggingface_loader(model_name):
                     params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
                 else:
                     params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
-            elif shared.args.bf16:
-                params["torch_dtype"] = torch.bfloat16
-            else:
-                params["torch_dtype"] = torch.float16
-
-            params['max_memory'] = get_max_memory_dict()
+
+                if params['max_memory'] is not None:
+                    with init_empty_weights():
+                        model = LoaderClass.from_config(config, trust_remote_code=params['trust_remote_code'])
+
+                    model.tie_weights()
+                    params['device_map'] = infer_auto_device_map(
+                        model,
+                        dtype=torch.int8,
+                        max_memory=params['max_memory'],
+                        no_split_module_classes=model._no_split_modules
+                    )
+
             if shared.args.disk:
-                params["offload_folder"] = shared.args.disk_cache_dir
+                params['offload_folder'] = shared.args.disk_cache_dir
 
         if shared.args.disable_exllama:
             try:
@@ -193,20 +201,9 @@ def huggingface_loader(model_name):
                 params['quantization_config'] = gptq_config
                 logger.info('Loading with ExLlama kernel disabled.')
             except:
+                exc = traceback.format_exc()
                 logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?')
-
-        if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
-            config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
-            with init_empty_weights():
-                model = LoaderClass.from_config(config, trust_remote_code=shared.args.trust_remote_code)
-
-            model.tie_weights()
-            params['device_map'] = infer_auto_device_map(
-                model,
-                dtype=torch.int8,
-                max_memory=params['max_memory'],
-                no_split_module_classes=model._no_split_modules
-            )
+                print(exc)
 
         if shared.args.compress_pos_emb > 1:
             params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
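
For reference, below is a minimal, self-contained sketch of the loading pattern this patch consolidates: one params dict shared across branches, with an accelerate device map inferred from explicit memory limits before calling from_pretrained. The model path, memory limits, dtype, and the use of AutoModelForCausalLM are illustrative assumptions, not values taken from the patch.

```python
# Illustrative sketch only -- not code from the patch. The model path, memory
# limits, and dtype are placeholder assumptions.
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

path_to_model = 'models/my-model'  # hypothetical local model directory

# One params dict reused by every loading branch, as in the patched loader.
params = {
    'low_cpu_mem_usage': True,
    'trust_remote_code': False,
    'torch_dtype': torch.float16,
    'device_map': 'auto',
    'max_memory': {0: '10GiB', 'cpu': '32GiB'},  # hypothetical limits
}

config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=params['trust_remote_code'])

# Plan the placement on an empty (meta-device) copy of the model, then turn
# the memory limits into an explicit device map before the real load.
if params['max_memory'] is not None:
    with init_empty_weights():
        empty_model = AutoModelForCausalLM.from_config(config, trust_remote_code=params['trust_remote_code'])

    empty_model.tie_weights()
    params['device_map'] = infer_auto_device_map(
        empty_model,
        dtype=params['torch_dtype'],
        max_memory=params['max_memory'],
        no_split_module_classes=empty_model._no_split_modules
    )

model = AutoModelForCausalLM.from_pretrained(path_to_model, **params)
```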