From 39099663a06abf7e19b6448f964f204096f5888b Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 16 Apr 2023 23:26:52 -0300
Subject: [PATCH] Add 4-bit LoRA support (#1200)

---
 README.md                         |  1 +
 modules/GPTQ_loader.py            | 70 ++++++++++++++++++-------------
 modules/LoRA.py                   |  3 +-
 modules/models.py                 | 15 ++++++-
 modules/monkey_patch_gptq_lora.py | 41 ++++++++++++++++++
 modules/shared.py                 |  1 +
 requirements.txt                  |  3 +-
 7 files changed, 100 insertions(+), 34 deletions(-)
 create mode 100644 modules/monkey_patch_gptq_lora.py

diff --git a/README.md b/README.md
index fbf88023..1e680913 100644
--- a/README.md
+++ b/README.md
@@ -237,6 +237,7 @@ Optionally, you can use the following command-line flags:
 | `--groupsize GROUPSIZE` | GPTQ: Group size. |
 | `--pre_layer PRE_LAYER` | GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. |
 | `--no-warmup_autotune` | GPTQ: Disable warmup autotune for triton. |
+| `--monkey-patch` | GPTQ: Apply the monkey patch for using LoRAs with quantized models. |
 
 #### FlexGen
 
diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 8e0066a9..344e34dd 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -16,6 +16,8 @@ from modelutils import find_layers
 from quant import make_quant
 
 
+# This function is a replacement for the load_quant function in the
+# GPTQ-for-LLaMa repository. It supports more models and branches.
 def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
 
     def noop(*args, **kwargs):
@@ -64,6 +66,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
 
     try:
         from quant import autotune_warmup, make_quant_attn
+
         # triton branch
         make_quant_attn(model)
         if not shared.args.no_warmup_autotune:
@@ -77,6 +80,41 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     return model
 
 
+# Used to locate the .pt/.safetensors quantized file
+def find_quantized_model_file(model_name):
+    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+    pt_path = None
+    priority_name_list = [
+        Path(f'{shared.args.model_dir}/{model_name}{hyphen}{shared.args.wbits}bit{group}{ext}')
+        for group in ([f'-{shared.args.groupsize}g', ''] if shared.args.groupsize > 0 else [''])
+        for ext in ['.safetensors', '.pt']
+        for hyphen in ['-', f'/{model_name}-', '/']
+    ]
+    for path in priority_name_list:
+        if path.exists():
+            pt_path = path
+            break
+
+    # If the model hasn't been found with a well-behaved name, pick the last .pt
+    # or the last .safetensors found in its folder as a last resort
+    if not pt_path:
+        found_pts = list(path_to_model.glob("*.pt"))
+        found_safetensors = list(path_to_model.glob("*.safetensors"))
+        pt_path = None
+
+        if len(found_pts) > 0:
+            if len(found_pts) > 1:
+                print('Warning: more than one .pt model has been found. The last one will be selected. It could be wrong.')
+            pt_path = found_pts[-1]
+        elif len(found_safetensors) > 0:
+            if len(found_safetensors) > 1:
+                print('Warning: more than one .safetensors model has been found. The last one will be selected. It could be wrong.')
+            pt_path = found_safetensors[-1]
+
+    return pt_path
+
+
+# The function that loads the model in modules/models.py
 def load_quantized(model_name):
 
     # Find the model type
@@ -106,37 +144,9 @@ def load_quantized(model_name):
         print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
         exit()
 
-    # Locate the quantized model file
+    # Find the quantized model weights file (.pt/.safetensors)
     path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
-    pt_path = None
-    priority_name_list = [
-        Path(f'{shared.args.model_dir}/{model_name}{hyphen}{shared.args.wbits}bit{group}{ext}')
-        for group in ([f'-{shared.args.groupsize}g', ''] if shared.args.groupsize > 0 else [''])
-        for ext in ['.safetensors', '.pt']
-        for hyphen in ['-', f'/{model_name}-', '/']
-    ]
-    for path in priority_name_list:
-        if path.exists():
-            pt_path = path
-            break
-
-    # If the model hasn't been found with a well-behaved name, pick the last .pt
-    # or the last .safetensors found in its folder as a last resort
-    if not pt_path:
-        path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
-        found_pts = list(path_to_model.glob("*.pt"))
-        found_safetensors = list(path_to_model.glob("*.safetensors"))
-        pt_path = None
-
-        if len(found_pts) > 0:
-            if len(found_pts) > 1:
-                print('Warning: more than one .pt model has been found. The last one will be selected. It could be wrong.')
-            pt_path = found_pts[-1]
-        elif len(found_safetensors) > 0:
-            if len(found_pts) > 1:
-                print('Warning: more than one .safetensors model has been found. The last one will be selected. It could be wrong.')
-            pt_path = found_safetensors[-1]
-
+    pt_path = find_quantized_model_file(model_name)
     if not pt_path:
         print("Could not find the quantized model in .pt or .safetensors format, exiting...")
         exit()
diff --git a/modules/LoRA.py b/modules/LoRA.py
index 8b54ef69..ef1e88aa 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -43,7 +43,8 @@ def add_lora_to_model(lora_names):
             shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
 
         if not shared.args.load_in_8bit and not shared.args.cpu:
-            shared.model.half()
+            if not shared.args.monkey_patch:
+                shared.model.half()
             if not hasattr(shared.model, "hf_device_map"):
                 if torch.has_mps:
                     device = torch.device('mps')
diff --git a/modules/models.py b/modules/models.py
index ca9498d2..2d3ce2ad 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -101,9 +101,20 @@ def load_model(model_name):
 
     # Quantized model
     elif shared.args.wbits > 0:
-        from modules.GPTQ_loader import load_quantized
 
-        model = load_quantized(model_name)
+        # Monkey patch
+        if shared.args.monkey_patch:
+            print("Warning: applying the monkey patch for using LoRAs in 4-bit mode.\nIt may cause undefined behavior outside its intended scope.")
+            from modules.monkey_patch_gptq_lora import load_model_llama
+
+            model, tokenizer = load_model_llama(model_name)
+            return model, tokenizer
+
+        # No monkey patch
+        else:
+            from modules.GPTQ_loader import load_quantized
+
+            model = load_quantized(model_name)
 
     # llamacpp model
     elif shared.is_llamacpp:
diff --git a/modules/monkey_patch_gptq_lora.py b/modules/monkey_patch_gptq_lora.py
new file mode 100644
index 00000000..3e591b52
--- /dev/null
+++ b/modules/monkey_patch_gptq_lora.py
@@ -0,0 +1,41 @@
+# Copied from https://github.com/johnsmith0031/alpaca_lora_4bit
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path("repositories/alpaca_lora_4bit")))
+
+import autograd_4bit
+from autograd_4bit import (Autograd4bitQuantLinear,
+                           load_llama_model_4bit_low_ram)
+from monkeypatch.peft_tuners_lora_monkey_patch import (
+    Linear4bitLt, replace_peft_model_with_gptq_lora_model)
+
+from modules import shared
+from modules.GPTQ_loader import find_quantized_model_file
+
+replace_peft_model_with_gptq_lora_model()
+
+def load_model_llama(model_name):
+
+    config_path = str(Path(f'{shared.args.model_dir}/{model_name}'))
+    model_path = str(find_quantized_model_file(model_name))
+    model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=shared.args.groupsize, is_v1_model=False)
+
+    for n, m in model.named_modules():
+        if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
+            if m.is_v1_model:
+                m.zeros = m.zeros.half()
+            m.scales = m.scales.half()
+            m.bias = m.bias.half()
+    autograd_4bit.use_new = True
+    autograd_4bit.auto_switch = True
+
+    try:
+        tokenizer.eos_token_id = 2
+        tokenizer.bos_token_id = 1
+        tokenizer.pad_token_id = 0
+    except:
+        pass
+
+    return model, tokenizer
diff --git a/modules/shared.py b/modules/shared.py
index 0294073c..92ac1dd2 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -124,6 +124,7 @@ parser.add_argument('--model_type', type=str, help='GPTQ: Model type of pre-quan
 parser.add_argument('--groupsize', type=int, default=-1, help='GPTQ: Group size.')
 parser.add_argument('--pre_layer', type=int, default=0, help='GPTQ: The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models.')
 parser.add_argument('--no-warmup_autotune', action='store_true', help='GPTQ: Disable warmup autotune for triton.')
+parser.add_argument('--monkey-patch', action='store_true', help='GPTQ: Apply the monkey patch for using LoRAs with quantized models.')
 
 # FlexGen
 parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
diff --git a/requirements.txt b/requirements.txt
index 64036d97..6c7e22ec 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,8 @@
 accelerate==0.18.0
+colorama
 datasets
 flexgen==0.1.7
-gradio==3.25
+gradio==3.25.0
 markdown
 numpy
 Pillow>=9.5.0
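
The patch factors the checkpoint lookup out of load_quantized() into find_quantized_model_file() so that the monkey patch in modules/monkey_patch_gptq_lora.py can reuse it when loading a 4-bit model with LoRA support. As a rough, standalone illustration of the search order that helper implements, the sketch below rebuilds its priority list with hypothetical stand-in values for shared.args.model_dir, --model, --wbits and --groupsize; it is not part of the patch itself.

# Standalone sketch of the candidate order checked by find_quantized_model_file().
# The four values below are hypothetical examples, not defaults from the project.
from pathlib import Path

model_dir = 'models'        # stand-in for shared.args.model_dir
model_name = 'llama-13b'    # hypothetical model folder under model_dir
wbits = 4                   # stand-in for --wbits
groupsize = 128             # stand-in for --groupsize

# Same nested comprehension as in the patch: the group-size suffix is tried
# before the plain name, .safetensors before .pt, and three naming/location
# patterns per extension.
priority_name_list = [
    Path(f'{model_dir}/{model_name}{hyphen}{wbits}bit{group}{ext}')
    for group in ([f'-{groupsize}g', ''] if groupsize > 0 else [''])
    for ext in ['.safetensors', '.pt']
    for hyphen in ['-', f'/{model_name}-', '/']
]

for path in priority_name_list:
    print(path)
# First candidates printed:
#   models/llama-13b-4bit-128g.safetensors
#   models/llama-13b/llama-13b-4bit-128g.safetensors
#   models/llama-13b/4bit-128g.safetensors
#   models/llama-13b-4bit-128g.pt
#   ... then the remaining .pt patterns, and finally the same six without '-128g'.

The first candidate that exists on disk wins; only if none of them match does the loader fall back to globbing the model folder and picking the last *.pt (or, failing that, the last *.safetensors) file it finds, as implemented in the fallback branch of find_quantized_model_file().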