Fix settings.json being ignored because of config.yaml

This commit is contained in:
oobabooga 2023-05-12 06:09:45 -03:00
parent a77965e801
commit 5eaa914e1b
3 changed files with 19 additions and 10 deletions

View file

@ -1,11 +1,3 @@
.*:
wbits: 'None'
model_type: 'None'
groupsize: 'None'
pre_layer: 0
mode: 'chat'
skip_special_tokens: true
custom_stopping_strings: ''
.*(llama|alpac|vicuna|guanaco|koala|llava|wizardlm|metharme|pygmalion-7b):
model_type: 'llama'
.*(opt-|opt_|opt1|opt3|optfor|galactica|galpaca|pygmalion-350m):
@ -28,7 +20,7 @@
wbits: 6
.*(-5bit|_5bit|int5-):
wbits: 5
.*(-gr32-|-32g-|groupsize32):
.*(-gr32-|-32g-|groupsize32|-32g$):
groupsize: 32
.*(-gr64-|-64g-|groupsize64):
groupsize: 64

View file

@ -1,5 +1,6 @@
import argparse
import logging
from collections import OrderedDict
from pathlib import Path
import yaml
@ -200,7 +201,7 @@ def is_chat():
return args.chat
# Loading model-specific settings (default)
# Loading model-specific settings
with Path(f'{args.model_dir}/config.yaml') as p:
if p.exists():
model_config = yaml.safe_load(open(p, 'r').read())
@ -216,3 +217,5 @@ with Path(f'{args.model_dir}/config-user.yaml') as p:
model_config[k].update(user_config[k])
else:
model_config[k] = user_config[k]
model_config = OrderedDict(model_config)

View file

@ -878,12 +878,26 @@ if __name__ == "__main__":
settings_file = Path(shared.args.settings)
elif Path('settings.json').exists():
settings_file = Path('settings.json')
if settings_file is not None:
logging.info(f"Loading settings from {settings_file}...")
new_settings = json.loads(open(settings_file, 'r').read())
for item in new_settings:
shared.settings[item] = new_settings[item]
# Set default model settings based on settings.json
shared.model_config['.*'] = {
'wbits': 'None',
'model_type': 'None',
'groupsize': 'None',
'pre_layer': 0,
'mode': shared.settings['mode'],
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
}
shared.model_config.move_to_end('.*', last=False) # Move to the beginning
# Default extensions
extensions_module.available_extensions = utils.get_available_extensions()
if shared.is_chat():