Automatically set bf16 & use_eager_attention for Gemma-2

This commit is contained in:
oobabooga 2024-07-01 21:46:35 -07:00
parent 8074fba18d
commit 907137a13d

View file

@ -9,6 +9,8 @@ from modules import chat, loaders, metadata_gguf, shared, ui
def get_fallback_settings(): def get_fallback_settings():
return { return {
'bf16': False,
'use_eager_attention': False,
'wbits': 'None', 'wbits': 'None',
'groupsize': 'None', 'groupsize': 'None',
'desc_act': False, 'desc_act': False,
@ -97,10 +99,18 @@ def get_model_metadata(model):
elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']: elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']:
model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta'] model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta']
if 'rope_scaling' in metadata and type(metadata['rope_scaling']) is dict and all(key in metadata['rope_scaling'] for key in ('type', 'factor')): if 'rope_scaling' in metadata and isinstance(metadata['rope_scaling'], dict) and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
if metadata['rope_scaling']['type'] == 'linear': if metadata['rope_scaling']['type'] == 'linear':
model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor'] model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
# For Gemma-2
if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
model_settings['bf16'] = True
# For Gemma-2
if 'architectures' in metadata and isinstance(metadata['architectures'], list) and 'Gemma2ForCausalLM' in metadata['architectures']:
model_settings['use_eager_attention'] = True
# Read GPTQ metadata for old GPTQ loaders # Read GPTQ metadata for old GPTQ loaders
if 'quantization_config' in metadata and metadata['quantization_config'].get('quant_method', '') != 'exl2': if 'quantization_config' in metadata and metadata['quantization_config'].get('quant_method', '') != 'exl2':
if 'bits' in metadata['quantization_config']: if 'bits' in metadata['quantization_config']:
@ -133,7 +143,7 @@ def get_model_metadata(model):
for k in ['eos_token', 'bos_token']: for k in ['eos_token', 'bos_token']:
if k in metadata: if k in metadata:
value = metadata[k] value = metadata[k]
if type(value) is dict: if isinstance(value, dict):
value = value['content'] value = value['content']
template = template.replace(k, "'{}'".format(value)) template = template.replace(k, "'{}'".format(value))
@ -168,7 +178,7 @@ def infer_loader(model_name, model_settings):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}') path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
if not path_to_model.exists(): if not path_to_model.exists():
loader = None loader = None
elif (path_to_model / 'quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0): elif (path_to_model / 'quantize_config.json').exists() or ('wbits' in model_settings and isinstance(model_settings['wbits'], int) and model_settings['wbits'] > 0):
loader = 'ExLlamav2_HF' loader = 'ExLlamav2_HF'
elif (path_to_model / 'quant_config.json').exists() or re.match(r'.*-awq', model_name.lower()): elif (path_to_model / 'quant_config.json').exists() or re.match(r'.*-awq', model_name.lower()):
loader = 'AutoAWQ' loader = 'AutoAWQ'