From 1dd13e464305bd38fa33a25ddaa563db693c5adc Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 28 Sep 2023 19:19:47 -0700 Subject: [PATCH] Read Transformers config.json metadata --- models/config.yaml | 20 -------------------- modules/models_settings.py | 11 +++++++++++ 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/models/config.yaml b/models/config.yaml index 704012ac..68a78eb1 100644 --- a/models/config.yaml +++ b/models/config.yaml @@ -62,7 +62,6 @@ llama-65b-gptq-3bit: instruction_template: 'Vicuna-v1.1' .*vicuna.*(1.5|1_5): instruction_template: 'Vicuna-v1.1' - truncation_length: 4096 .*stable.*vicuna: instruction_template: 'StableVicuna' (?!.*chat).*chinese-vicuna: @@ -93,15 +92,10 @@ llama-65b-gptq-3bit: custom_stopping_strings: '"\n###"' .*raven: instruction_template: 'RWKV-Raven' -.*ctx8192: - truncation_length: 8192 .*moss-moon.*sft: instruction_template: 'MOSS' .*stablelm-tuned: instruction_template: 'StableLM' - truncation_length: 4096 -.*stablelm-base: - truncation_length: 4096 .*galactica.*finetuned: instruction_template: 'Galactica Finetuned' .*galactica.*-v2: @@ -147,7 +141,6 @@ llama-65b-gptq-3bit: instruction_template: 'Manticore Chat' .*bluemoonrp-(30|13)b: instruction_template: 'Bluemoon' - truncation_length: 4096 .*Nous-Hermes-13b: instruction_template: 'Alpaca' .*airoboros: @@ -181,16 +174,8 @@ llama-65b-gptq-3bit: custom_stopping_strings: '"<|end|>"' .*minotaur: instruction_template: 'Minotaur' -.*minotaur-15b: - truncation_length: 8192 .*orca_mini: instruction_template: 'Orca Mini' -.*landmark: - truncation_length: 8192 -.*superhot-8k: - truncation_length: 8192 -.*xgen.*-inst: - truncation_length: 8192 instruction_template: 'Vicuna-v0' .*(platypus|gplatty|superplatty): instruction_template: 'Alpaca' @@ -200,23 +185,18 @@ llama-65b-gptq-3bit: instruction_template: 'Vicuna-v1.1' .*redmond-hermes-coder: instruction_template: 'Alpaca' - truncation_length: 8192 .*wizardcoder-15b: instruction_template: 'Alpaca' - truncation_length: 8192 .*wizardlm: instruction_template: 'Vicuna-v1.1' .*godzilla: instruction_template: 'Alpaca' -.*llama-(2|v2): - truncation_length: 4096 .*llama(-?)(2|v2).*chat: instruction_template: 'Llama-v2' .*newhope: instruction_template: 'NewHope' .*stablebeluga2: instruction_template: 'StableBeluga2' - truncation_length: 4096 .*openchat: instruction_template: 'OpenChat' .*falcon.*-instruct: diff --git a/modules/models_settings.py b/modules/models_settings.py index 537bf0ab..b3611a94 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -1,3 +1,4 @@ +import json import re from pathlib import Path @@ -15,6 +16,7 @@ def get_fallback_settings(): 'skip_special_tokens': shared.settings['skip_special_tokens'], 'custom_stopping_strings': shared.settings['custom_stopping_strings'], 'truncation_length': shared.settings['truncation_length'], + 'max_seq_len': 2048, 'n_ctx': 2048, 'rope_freq_base': 0, 'compress_pos_emb': 1, @@ -54,6 +56,15 @@ def get_model_metadata(model): if 'llama.rope.freq_base' in metadata: model_settings['rope_freq_base'] = metadata['llama.rope.freq_base'] + # Read transformers metadata. In particular, the sequence length for the model. + else: + path = Path(f'{shared.args.model_dir}/{model}/config.json') + if path.exists(): + metadata = json.loads(open(path, 'r').read()) + if 'max_position_embeddings' in metadata: + model_settings['truncation_length'] = metadata['max_position_embeddings'] + model_settings['max_seq_len'] = metadata['max_position_embeddings'] + # Apply user settings from models/config-user.yaml settings = shared.user_config for pat in settings: