Read Transformers config.json metadata

oobabooga 2023-09-28 19:19:47 -07:00
parent 9ccaf5eebb
commit 1dd13e4643
2 changed files with 11 additions and 20 deletions
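
Transformers models previously had their truncation lengths hard-coded per model-name pattern in models/config.yaml. This commit deletes those entries and instead reads the sequence length from each model's config.json metadata (its max_position_embeddings field), populating both truncation_length and max_seq_len from it. The defaults in shared.settings remain the fallback, and per-model overrides from models/config-user.yaml still apply afterwards.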

models/config.yaml

@@ -62,7 +62,6 @@ llama-65b-gptq-3bit:
   instruction_template: 'Vicuna-v1.1'
 .*vicuna.*(1.5|1_5):
   instruction_template: 'Vicuna-v1.1'
-  truncation_length: 4096
 .*stable.*vicuna:
   instruction_template: 'StableVicuna'
 (?!.*chat).*chinese-vicuna:
@@ -93,15 +92,10 @@ llama-65b-gptq-3bit:
   custom_stopping_strings: '"\n###"'
 .*raven:
   instruction_template: 'RWKV-Raven'
-.*ctx8192:
-  truncation_length: 8192
 .*moss-moon.*sft:
   instruction_template: 'MOSS'
 .*stablelm-tuned:
   instruction_template: 'StableLM'
-  truncation_length: 4096
-.*stablelm-base:
-  truncation_length: 4096
 .*galactica.*finetuned:
   instruction_template: 'Galactica Finetuned'
 .*galactica.*-v2:
@@ -147,7 +141,6 @@ llama-65b-gptq-3bit:
   instruction_template: 'Manticore Chat'
 .*bluemoonrp-(30|13)b:
   instruction_template: 'Bluemoon'
-  truncation_length: 4096
 .*Nous-Hermes-13b:
   instruction_template: 'Alpaca'
 .*airoboros:
@@ -181,16 +174,8 @@ llama-65b-gptq-3bit:
   custom_stopping_strings: '"<|end|>"'
 .*minotaur:
   instruction_template: 'Minotaur'
-.*minotaur-15b:
-  truncation_length: 8192
 .*orca_mini:
   instruction_template: 'Orca Mini'
-.*landmark:
-  truncation_length: 8192
-.*superhot-8k:
-  truncation_length: 8192
-.*xgen.*-inst:
-  truncation_length: 8192
   instruction_template: 'Vicuna-v0'
 .*(platypus|gplatty|superplatty):
   instruction_template: 'Alpaca'
@@ -200,23 +185,18 @@ llama-65b-gptq-3bit:
   instruction_template: 'Vicuna-v1.1'
 .*redmond-hermes-coder:
   instruction_template: 'Alpaca'
-  truncation_length: 8192
 .*wizardcoder-15b:
   instruction_template: 'Alpaca'
-  truncation_length: 8192
 .*wizardlm:
   instruction_template: 'Vicuna-v1.1'
 .*godzilla:
   instruction_template: 'Alpaca'
-.*llama-(2|v2):
-  truncation_length: 4096
 .*llama(-?)(2|v2).*chat:
   instruction_template: 'Llama-v2'
 .*newhope:
   instruction_template: 'NewHope'
 .*stablebeluga2:
   instruction_template: 'StableBeluga2'
-  truncation_length: 4096
 .*openchat:
   instruction_template: 'OpenChat'
 .*falcon.*-instruct:
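
Every change above is a deletion: the hard-coded truncation_length values go away (along with the name patterns that existed only to carry them), since the Python change in the next file now derives the value from each model's config.json. The keys in this file are regular expressions tested against the model name, and user overrides in models/config-user.yaml keep working through the same mechanism. A minimal sketch of that matching, assuming case-insensitive comparison via lowercasing (the model folder name is hypothetical):

    import re

    pattern = '.*stablebeluga2'   # a key from models/config.yaml
    model = 'StableBeluga2-13B'   # hypothetical model folder name

    # Lowercasing both sides is an assumption about how the comparison is done.
    if re.match(pattern.lower(), model.lower()):
        print('entry matches; its settings would be merged in')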

modules/models_settings.py

@@ -1,3 +1,4 @@
+import json
 import re
 from pathlib import Path
 
@@ -15,6 +16,7 @@ def get_fallback_settings():
         'skip_special_tokens': shared.settings['skip_special_tokens'],
         'custom_stopping_strings': shared.settings['custom_stopping_strings'],
         'truncation_length': shared.settings['truncation_length'],
+        'max_seq_len': 2048,
         'n_ctx': 2048,
         'rope_freq_base': 0,
         'compress_pos_emb': 1,
@@ -54,6 +56,15 @@ def get_model_metadata(model):
         if 'llama.rope.freq_base' in metadata:
             model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
 
+    # Read transformers metadata. In particular, the sequence length for the model.
+    else:
+        path = Path(f'{shared.args.model_dir}/{model}/config.json')
+        if path.exists():
+            metadata = json.loads(open(path, 'r').read())
+            if 'max_position_embeddings' in metadata:
+                model_settings['truncation_length'] = metadata['max_position_embeddings']
+                model_settings['max_seq_len'] = metadata['max_position_embeddings']
+
     # Apply user settings from models/config-user.yaml
     settings = shared.user_config
     for pat in settings:
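
The added else pairs with the existing GGUF-metadata if branch, whose last two lines are visible as context above, so config.json is consulted only for non-GGUF models. Taken on its own, the new branch amounts to the following. A minimal, self-contained sketch of the same logic, with hypothetical names; it swaps the committed open(path, 'r').read() for a context-managed json.load, since the committed form never closes the file handle:

    import json
    from pathlib import Path

    def read_transformers_metadata(model_dir, model):
        """Derive truncation settings from a model's config.json, if present."""
        settings = {}
        path = Path(model_dir) / model / 'config.json'
        if path.exists():
            with open(path) as f:  # closed automatically, unlike open(...).read()
                metadata = json.load(f)
            if 'max_position_embeddings' in metadata:
                settings['truncation_length'] = metadata['max_position_embeddings']
                settings['max_seq_len'] = metadata['max_position_embeddings']
        return settings

    # Example: a Llama-2 config.json contains "max_position_embeddings": 4096, so
    # read_transformers_metadata('models', 'Llama-2-7b-hf')  # hypothetical folder
    # would return {'truncation_length': 4096, 'max_seq_len': 4096}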