Read more metadata (config.json & quantize_config.json)

This commit is contained in:
oobabooga 2023-09-29 06:14:16 -07:00
parent 56b5a4af74
commit 96da2e1c0d
3 changed files with 59 additions and 67 deletions

View file

@ -6,8 +6,6 @@
model_type: 'gptj'
.*(gpt-neox|koalpaca-polyglot|polyglot.*koalpaca|polyglot-ko|polyglot_ko|pythia|stablelm|incite|dolly-v2|polycoder|h2ogpt-oig|h2ogpt-oasst1|h2ogpt-gm):
model_type: 'gptneox'
.*llama:
model_type: 'llama'
.*bloom:
model_type: 'bloom'
.*gpt2:
@ -22,38 +20,32 @@
model_type: 'dollyv2'
.*replit:
model_type: 'replit'
llama-65b-gptq-3bit:
groupsize: 'None'
.*(4bit|int4):
wbits: 4
.*(3bit|int3):
wbits: 3
.*(-2bit|_2bit|int2-):
wbits: 2
.*(-1bit|_1bit|int1-):
wbits: 1
.*(8bit|int8):
wbits: 8
.*(-7bit|_7bit|int7-):
wbits: 7
.*(-6bit|_6bit|int6-):
wbits: 6
.*(-5bit|_5bit|int5-):
wbits: 5
.*(-gr32-|-32g-|groupsize32|-32g$):
groupsize: 32
.*(-gr64-|-64g-|groupsize64|-64g$):
groupsize: 64
.*(gr128|128g|groupsize128):
groupsize: 128
.*(gr1024|1024g|groupsize1024):
groupsize: 1024
.*(oasst|openassistant-|stablelm-7b-sft-v7-epoch-3):
instruction_template: 'Open Assistant'
skip_special_tokens: false
(?!.*galactica)(?!.*reward).*openassistant:
instruction_template: 'Open Assistant'
skip_special_tokens: false
.*galactica:
skip_special_tokens: false
.*dolly-v[0-9]-[0-9]*b:
instruction_template: 'Alpaca'
skip_special_tokens: false
.*alpaca-native-4bit:
instruction_template: 'Alpaca'
custom_stopping_strings: '"### End"'
.*llava:
instruction_template: 'LLaVA'
custom_stopping_strings: '"\n###"'
.*wizard.*mega:
instruction_template: 'Wizard-Mega'
custom_stopping_strings: '"</s>"'
.*starchat-beta:
instruction_template: 'Starchat-Beta'
custom_stopping_strings: '"<|end|>"'
.*(openorca-platypus2):
instruction_template: 'OpenOrca-Platypus2'
custom_stopping_strings: '"### Instruction:", "### Response:"'
(?!.*v0)(?!.*1.1)(?!.*1_1)(?!.*stable)(?!.*chinese).*vicuna:
instruction_template: 'Vicuna-v0'
.*vicuna.*v0:
@ -70,26 +62,12 @@ llama-65b-gptq-3bit:
instruction_template: 'Chinese-Vicuna-Chat'
.*alpaca:
instruction_template: 'Alpaca'
.*alpaca-native-4bit:
instruction_template: 'Alpaca'
wbits: 4
groupsize: 128
.*galactica:
skip_special_tokens: false
.*dolly-v[0-9]-[0-9]*b:
instruction_template: 'Alpaca'
skip_special_tokens: false
custom_stopping_strings: '"### End"'
.*koala:
instruction_template: 'Koala'
.*chatglm:
instruction_template: 'ChatGLM'
.*(metharme|pygmalion|mythalion):
instruction_template: 'Metharme'
.*llava:
model_type: 'llama'
instruction_template: 'LLaVA'
custom_stopping_strings: '"\n###"'
.*raven:
instruction_template: 'RWKV-Raven'
.*moss-moon.*sft:
@ -116,9 +94,6 @@ llama-65b-gptq-3bit:
instruction_template: 'INCITE-Chat'
.*incite.*instruct:
instruction_template: 'INCITE-Instruct'
.*wizard.*mega:
instruction_template: 'Wizard-Mega'
custom_stopping_strings: '"</s>"'
.*ziya-:
instruction_template: 'Ziya'
.*koalpaca:
@ -169,14 +144,10 @@ llama-65b-gptq-3bit:
instruction_template: 'Samantha'
.*wizardcoder:
instruction_template: 'Alpaca'
.*starchat-beta:
instruction_template: 'Starchat-Beta'
custom_stopping_strings: '"<|end|>"'
.*minotaur:
instruction_template: 'Minotaur'
.*orca_mini:
instruction_template: 'Orca Mini'
instruction_template: 'Vicuna-v0'
.*(platypus|gplatty|superplatty):
instruction_template: 'Alpaca'
.*longchat:
@ -199,12 +170,6 @@ llama-65b-gptq-3bit:
instruction_template: 'StableBeluga2'
.*openchat:
instruction_template: 'OpenChat'
.*falcon.*-instruct:
.*(openorca-platypus2):
instruction_template: 'OpenOrca-Platypus2'
custom_stopping_strings: '"### Instruction:", "### Response:"'
.*codellama:
rope_freq_base: 1000000
.*codellama.*instruct:
instruction_template: 'Llama-v2'
.*mistral.*instruct:

View file

@ -10,16 +10,16 @@ from modules import loaders, metadata_gguf, shared, ui
def get_fallback_settings():
return {
'wbits': 'None',
'model_type': 'None',
'groupsize': 'None',
'pre_layer': 0,
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
'truncation_length': shared.settings['truncation_length'],
'desc_act': False,
'model_type': 'None',
'max_seq_len': 2048,
'n_ctx': 2048,
'rope_freq_base': 0,
'compress_pos_emb': 1,
'truncation_length': shared.settings['truncation_length'],
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
}
@ -56,8 +56,8 @@ def get_model_metadata(model):
if 'llama.rope.freq_base' in metadata:
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
# Read transformers metadata. In particular, the sequence length for the model.
else:
# Read transformers metadata
path = Path(f'{shared.args.model_dir}/{model}/config.json')
if path.exists():
metadata = json.loads(open(path, 'r').read())
@ -65,6 +65,32 @@ def get_model_metadata(model):
model_settings['truncation_length'] = metadata['max_position_embeddings']
model_settings['max_seq_len'] = metadata['max_position_embeddings']
if 'rope_theta' in metadata:
model_settings['rope_freq_base'] = metadata['rope_theta']
if 'rope_scaling' in metadata and type(metadata['rope_scaling']) is dict and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
if metadata['rope_scaling']['type'] == 'linear':
model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
if 'quantization_config' in metadata:
if 'bits' in metadata['quantization_config']:
model_settings['wbits'] = metadata['quantization_config']['bits']
if 'group_size' in metadata['quantization_config']:
model_settings['groupsize'] = metadata['quantization_config']['group_size']
if 'desc_act' in metadata['quantization_config']:
model_settings['desc_act'] = metadata['quantization_config']['desc_act']
# Read AutoGPTQ metadata
path = Path(f'{shared.args.model_dir}/{model}/quantize_config.json')
if path.exists():
metadata = json.loads(open(path, 'r').read())
if 'bits' in metadata:
model_settings['wbits'] = metadata['bits']
if 'group_size' in metadata:
model_settings['groupsize'] = metadata['group_size']
if 'desc_act' in metadata:
model_settings['desc_act'] = metadata['desc_act']
# Apply user settings from models/config-user.yaml
settings = shared.user_config
for pat in settings:

View file

@ -258,9 +258,10 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
def update_truncation_length(current_length, state):
if 'loader' in state:
if state['loader'].lower().startswith('exllama'):
return state['max_seq_len']
elif state['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
return state['n_ctx']
else:
return current_length