Add a --tokenizer-dir command-line flag for llamacpp_HF

This commit is contained in:
oobabooga 2024-08-06 19:41:18 -07:00
parent f106e780ba
commit e926c03b3d
2 changed files with 23 additions and 11 deletions

View file

@ -98,7 +98,7 @@ def load_model(model_name, loader=None):
if model is None: if model is None:
return None, None return None, None
else: else:
tokenizer = load_tokenizer(model_name, model) tokenizer = load_tokenizer(model_name)
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings}) shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt'): if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt'):
@ -113,9 +113,13 @@ def load_model(model_name, loader=None):
return model, tokenizer return model, tokenizer
def load_tokenizer(model_name, model): def load_tokenizer(model_name, tokenizer_dir=None):
tokenizer = None if tokenizer_dir:
path_to_model = Path(tokenizer_dir)
else:
path_to_model = Path(f"{shared.args.model_dir}/{model_name}/") path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
tokenizer = None
if path_to_model.exists(): if path_to_model.exists():
if shared.args.no_use_fast: if shared.args.no_use_fast:
logger.info('Loading the tokenizer with use_fast=False.') logger.info('Loading the tokenizer with use_fast=False.')
@ -278,8 +282,10 @@ def llamacpp_loader(model_name):
def llamacpp_HF_loader(model_name): def llamacpp_HF_loader(model_name):
from modules.llamacpp_hf import LlamacppHF from modules.llamacpp_hf import LlamacppHF
if shared.args.tokenizer_dir:
logger.info(f'Using tokenizer from: \"{shared.args.tokenizer_dir}\"')
else:
path = Path(f'{shared.args.model_dir}/{model_name}') path = Path(f'{shared.args.model_dir}/{model_name}')
# Check if a HF tokenizer is available for the model # Check if a HF tokenizer is available for the model
if all((path / file).exists() for file in ['tokenizer_config.json']): if all((path / file).exists() for file in ['tokenizer_config.json']):
logger.info(f'Using tokenizer from: \"{path}\"') logger.info(f'Using tokenizer from: \"{path}\"')
@ -288,6 +294,11 @@ def llamacpp_HF_loader(model_name):
return None, None return None, None
model = LlamacppHF.from_pretrained(model_name) model = LlamacppHF.from_pretrained(model_name)
if shared.args.tokenizer_dir:
tokenizer = load_tokenizer(model_name, tokenizer_dir=shared.args.tokenizer_dir)
return model, tokenizer
else:
return model return model

View file

@ -132,6 +132,7 @@ group.add_argument('--cache-capacity', type=str, help='Maximum cache capacity (l
group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.') group.add_argument('--row_split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.')
group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.') group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
group.add_argument('--attention-sink-size', type=int, default=5, help='StreamingLLM: number of sink tokens. Only used if the trimmed prompt does not share a prefix with the old prompt.') group.add_argument('--attention-sink-size', type=int, default=5, help='StreamingLLM: number of sink tokens. Only used if the trimmed prompt does not share a prefix with the old prompt.')
group.add_argument('--tokenizer-dir', type=str, help='Load the tokenizer from this folder. Meant to be used with llamacpp_HF through the command-line.')
# ExLlamaV2 # ExLlamaV2
group = parser.add_argument_group('ExLlamaV2') group = parser.add_argument_group('ExLlamaV2')