diff --git a/modules/models.py b/modules/models.py
index 04235b52..e10668cf 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -102,6 +102,10 @@ def load_model(model_name):
         if path_to_model.name.lower().startswith('llama-30b'):
             pt_model = 'llama-30b-4bit.pt'
 
+        if not Path(f"models/{pt_model}").exists():
+            print(f"Could not find models/{pt_model}, exiting...")
+            exit()
+
         model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)
         model = model.to(torch.device('cuda:0'))
 
@@ -178,4 +182,3 @@ def load_soft_prompt(name):
     shared.soft_prompt_tensor = tensor
 
     return name
-
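
For reference, the guard added in the first hunk follows the usual pathlib pattern of verifying that the expected 4-bit checkpoint exists under models/ before handing it to load_quant(). Below is a minimal standalone sketch of that same check, not part of the diff; the pt_model value is just the example taken from the surrounding context lines, and sys.exit() is used here in place of the builtin exit() that the patch itself calls.

    from pathlib import Path
    import sys

    # Example checkpoint name, mirroring the diff's context lines.
    pt_model = 'llama-30b-4bit.pt'

    # Bail out early with a clear message if the checkpoint is missing,
    # instead of letting load_quant() fail with a less obvious error.
    if not Path(f"models/{pt_model}").exists():
        print(f"Could not find models/{pt_model}, exiting...")
        sys.exit(1)

    # The file exists at this point, so loading can proceed.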