AutoGPTQ: Add --disable_exllamav2 flag (Mixtral CPU offloading needs this)

oobabooga 2023-12-15 06:46:13 -08:00
parent 7de10f4c8e
commit 3bbf6c601d
7 changed files with 16 additions and 4 deletions


@@ -285,6 +285,7 @@ List of command-line flags
 | `--no_use_cuda_fp16` | This can make models faster on some systems. |
 | `--desc_act` | For models that don't have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig. |
 | `--disable_exllama` | Disable ExLlama kernel, which can improve inference speed on some systems. |
+| `--disable_exllamav2` | Disable ExLlamav2 kernel. |
 #### GPTQ-for-LLaMa


@@ -52,6 +52,7 @@ def load_quantized(model_name):
         'quantize_config': quantize_config,
         'use_cuda_fp16': not shared.args.no_use_cuda_fp16,
         'disable_exllama': shared.args.disable_exllama,
+        'disable_exllamav2': shared.args.disable_exllamav2,
     }

     logger.info(f"The AutoGPTQ params are: {params}")


@@ -25,6 +25,7 @@ loaders_and_params = OrderedDict({
         'rope_freq_base',
         'compress_pos_emb',
         'disable_exllama',
+        'disable_exllamav2',
         'transformers_info'
     ],
     'llama.cpp': [
@@ -94,6 +95,7 @@ loaders_and_params = OrderedDict({
         'groupsize',
         'desc_act',
         'disable_exllama',
+        'disable_exllamav2',
         'gpu_memory',
         'cpu_memory',
         'cpu',
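Note (not part of the diff): loaders_and_params maps each loader name to the UI elements shown for it, so adding 'disable_exllamav2' to the Transformers and AutoGPTQ entries is what surfaces the new checkbox for those loaders. A trimmed, hypothetical sketch of that lookup pattern:

```python
from collections import OrderedDict

# Hypothetical, trimmed-down mapping for illustration; the real dict lists many more elements.
loaders_and_params = OrderedDict({
    'Transformers': ['gpu_memory', 'cpu_memory', 'disable_exllama', 'disable_exllamav2'],
    'AutoGPTQ': ['wbits', 'groupsize', 'desc_act', 'disable_exllama', 'disable_exllamav2'],
    'llama.cpp': ['n_gpu_layers', 'n_ctx'],
})

def visible_params(loader: str) -> list:
    # Elements not listed for the selected loader stay hidden in the UI.
    return loaders_and_params.get(loader, [])

print(visible_params('AutoGPTQ'))  # now includes 'disable_exllamav2'
```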


@@ -156,7 +156,7 @@ def huggingface_loader(model_name):
         LoaderClass = AutoModelForCausalLM

     # Load the model in simple 16-bit mode by default
-    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama]):
+    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama, shared.args.disable_exllamav2]):
         model = LoaderClass.from_pretrained(path_to_model, **params)
         if torch.backends.mps.is_available():
             device = torch.device('mps')
@@ -221,11 +221,16 @@ def huggingface_loader(model_name):
         if shared.args.disk:
             params['offload_folder'] = shared.args.disk_cache_dir

-        if shared.args.disable_exllama:
+        if shared.args.disable_exllama or shared.args.disable_exllamav2:
             try:
-                gptq_config = GPTQConfig(bits=config.quantization_config.get('bits', 4), disable_exllama=True)
+                gptq_config = GPTQConfig(
+                    bits=config.quantization_config.get('bits', 4),
+                    disable_exllama=shared.args.disable_exllama,
+                    disable_exllamav2=shared.args.disable_exllamav2,
+                )
+
                 params['quantization_config'] = gptq_config
-                logger.info('Loading with ExLlama kernel disabled.')
+                logger.info(f'Loading with disable_exllama={shared.args.disable_exllama} and disable_exllamav2={shared.args.disable_exllamav2}.')
             except:
                 exc = traceback.format_exc()
                 logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?')
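Note (not part of the diff): with both ExLlama kernels disabled through the GPTQConfig above, loading falls back to the standard transformers/accelerate path, which can offload layers to CPU RAM, the scenario the commit message cites for Mixtral. A minimal sketch of that downstream call; the model path and memory split are placeholders, and whether GPTQConfig forwards disable_exllamav2 depends on the installed transformers/optimum version (the commit assumes it does).

```python
import torch
from transformers import AutoModelForCausalLM, GPTQConfig

# Same construction as in the hunk above; kwarg support is version-dependent.
gptq_config = GPTQConfig(bits=4, disable_exllama=True, disable_exllamav2=True)

model = AutoModelForCausalLM.from_pretrained(
    'path/to/Mixtral-8x7B-GPTQ',              # placeholder path
    quantization_config=gptq_config,
    device_map='auto',                        # let accelerate place layers on GPU and CPU
    max_memory={0: '20GiB', 'cpu': '64GiB'},  # illustrative split
    torch_dtype=torch.float16,
)
```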


@@ -133,6 +133,7 @@ parser.add_argument('--no_inject_fused_mlp', action='store_true', help='Triton m
 parser.add_argument('--no_use_cuda_fp16', action='store_true', help='This can make models faster on some systems.')
 parser.add_argument('--desc_act', action='store_true', help='For models that do not have a quantize_config.json, this parameter is used to define whether to set desc_act or not in BaseQuantizeConfig.')
 parser.add_argument('--disable_exllama', action='store_true', help='Disable ExLlama kernel, which can improve inference speed on some systems.')
+parser.add_argument('--disable_exllamav2', action='store_true', help='Disable ExLlamav2 kernel.')

 # GPTQ-for-LLaMa
 parser.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
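Note (not part of the diff): because the flag uses action='store_true', it defaults to False, so existing command lines keep their behavior. A standalone sketch of the parsing behavior, mirroring just the two flags above:

```python
import argparse

# Standalone illustration; not the webui's actual parser object.
parser = argparse.ArgumentParser()
parser.add_argument('--disable_exllama', action='store_true', help='Disable ExLlama kernel.')
parser.add_argument('--disable_exllamav2', action='store_true', help='Disable ExLlamav2 kernel.')

args = parser.parse_args(['--disable_exllamav2'])
print(args.disable_exllama, args.disable_exllamav2)  # False True
```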


@@ -70,6 +70,7 @@ def list_model_elements():
         'no_inject_fused_mlp',
         'no_use_cuda_fp16',
         'disable_exllama',
+        'disable_exllamav2',
         'cfg_cache',
         'no_flash_attn',
         'cache_8bit',


@@ -125,6 +125,7 @@ def create_ui():
                     shared.gradio['logits_all'] = gr.Checkbox(label="logits_all", value=shared.args.logits_all, info='Needs to be set for perplexity evaluation to work. Otherwise, ignore it, as it makes prompt processing slower.')
                     shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
                     shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.')
+                    shared.gradio['disable_exllamav2'] = gr.Checkbox(label="disable_exllamav2", value=shared.args.disable_exllamav2, info='Disable ExLlamav2 kernel.')
                     shared.gradio['no_flash_attn'] = gr.Checkbox(label="no_flash_attn", value=shared.args.no_flash_attn, info='Force flash-attention to not be used.')
                     shared.gradio['cache_8bit'] = gr.Checkbox(label="cache_8bit", value=shared.args.cache_8bit, info='Use 8-bit cache to save VRAM.')
                     shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.')
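Note (not part of the diff): the checkbox simply mirrors the CLI default in the UI. A minimal gradio sketch of that pattern, with a local default standing in for shared.args.disable_exllamav2:

```python
import gradio as gr

disable_exllamav2_default = False  # stand-in for shared.args.disable_exllamav2

with gr.Blocks() as demo:
    gr.Checkbox(
        label="disable_exllamav2",
        value=disable_exllamav2_default,
        info="Disable ExLlamav2 kernel.",
    )

if __name__ == '__main__':
    demo.launch()
```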