Add disable_exllama to Transformers loader (for GPTQ LoRA training)

This commit is contained in:
oobabooga 2023-09-24 20:03:11 -07:00
parent c0fca23cb9
commit 36c38d7561
3 changed files with 22 additions and 10 deletions

View file

@ -23,6 +23,7 @@ loaders_and_params = OrderedDict({
'alpha_value',
'rope_freq_base',
'compress_pos_emb',
'disable_exllama',
'transformers_info'
],
'ExLlama_HF': [

View file

@ -13,7 +13,8 @@ from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BitsAndBytesConfig
BitsAndBytesConfig,
GPTQConfig
)
import modules.shared as shared
@ -114,11 +115,13 @@ def load_tokenizer(model_name, model):
def huggingface_loader(model_name):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
if 'chatglm' in model_name.lower():
LoaderClass = AutoModel
else:
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
if config.to_dict().get("is_encoder_decoder", False):
LoaderClass = AutoModelForSeq2SeqLM
shared.is_seq2seq = True
@ -126,7 +129,7 @@ def huggingface_loader(model_name):
LoaderClass = AutoModelForCausalLM
# Load the model in simple 16-bit mode by default
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1]):
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama]):
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
if torch.backends.mps.is_available():
device = torch.device('mps')
@ -170,9 +173,10 @@ def huggingface_loader(model_name):
logger.warning("Using the following 4-bit params: " + str(quantization_config_params))
params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
elif shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
elif shared.args.load_in_8bit:
if any((shared.args.auto_devices, shared.args.gpu_memory)):
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
else:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
elif shared.args.bf16:
params["torch_dtype"] = torch.bfloat16
@ -183,9 +187,16 @@ def huggingface_loader(model_name):
if shared.args.disk:
params["offload_folder"] = shared.args.disk_cache_dir
checkpoint = Path(f'{shared.args.model_dir}/{model_name}')
if shared.args.disable_exllama:
try:
gptq_config = GPTQConfig(bits=config.quantization_config.get('bits', 4), disable_exllama=True)
params['quantization_config'] = gptq_config
logger.info('Loading with ExLlama kernel disabled.')
except:
logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?')
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=shared.args.trust_remote_code)
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
with init_empty_weights():
model = LoaderClass.from_config(config, trust_remote_code=shared.args.trust_remote_code)
@ -202,7 +213,7 @@ def huggingface_loader(model_name):
elif shared.args.alpha_value > 1:
params['rope_scaling'] = {'type': 'dynamic', 'factor': RoPE.get_alpha_value(shared.args.alpha_value, shared.args.rope_freq_base)}
model = LoaderClass.from_pretrained(checkpoint, **params)
model = LoaderClass.from_pretrained(path_to_model, **params)
return model

View file

@ -100,7 +100,6 @@ def create_ui():
shared.gradio['no_inject_fused_mlp'] = gr.Checkbox(label="no_inject_fused_mlp", value=shared.args.no_inject_fused_mlp, info='Affects Triton only. Disable fused MLP. Fused MLP improves performance but uses more VRAM. Disable if running low on VRAM.')
shared.gradio['no_use_cuda_fp16'] = gr.Checkbox(label="no_use_cuda_fp16", value=shared.args.no_use_cuda_fp16, info='This can make models faster on some systems.')
shared.gradio['desc_act'] = gr.Checkbox(label="desc_act", value=shared.args.desc_act, info='\'desc_act\', \'wbits\', and \'groupsize\' are used for old models without a quantize_config.json.')
shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel, which can improve inference speed on some systems.')
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu)
shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16)
@ -116,6 +115,7 @@ def create_ui():
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='Split the model across multiple GPUs, comma-separated list of proportions, e.g. 18,17')
shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed)
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
shared.gradio['disable_exllama'] = gr.Checkbox(label="disable_exllama", value=shared.args.disable_exllama, info='Disable ExLlama kernel.')
shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa support is currently only kept for compatibility with older GPUs. AutoGPTQ or ExLlama is preferred when compatible. GPTQ-for-LLaMa is installed by default with the webui on supported systems. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).')
shared.gradio['exllama_info'] = gr.Markdown('For more information, consult the [docs](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).')
shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s a bit slower than the regular ExLlama.')