From b9352edf12da186db55407ca2be1aaab1f6be083 Mon Sep 17 00:00:00 2001 From: anon-contributor-0 <160194672+anon-contributor-0@users.noreply.github.com> Date: Thu, 15 Feb 2024 20:26:11 -0500 Subject: [PATCH] Add an API endpoint to reload the last-used model --- extensions/openai/script.py | 17 ++++++++++++++++- modules/models.py | 6 +++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 03d99e8d..0d294643 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -26,7 +26,7 @@ from extensions.openai.tokens import token_count, token_decode, token_encode from extensions.openai.utils import _start_cloudflared from modules import shared from modules.logging_colors import logger -from modules.models import unload_model +from modules.models import unload_model, load_last_model from modules.text_generation import stop_everything_event from .typing import ( @@ -325,6 +325,21 @@ async def handle_load_model(request_data: LoadModelRequest): return HTTPException(status_code=400, detail="Failed to load the model.") +@app.post("/v1/internal/model/loadlast", dependencies=check_admin_key) +async def handle_load_last_model(): + ''' + This endpoint is experimental and may change in the future. + + Loads the last model used before it was unloaded.
+ ''' + try: + load_last_model() + return JSONResponse(content="OK") + except: + traceback.print_exc() + return HTTPException(status_code=400, detail="Failed to load the last-used model.") + + @app.post("/v1/internal/model/unload", dependencies=check_admin_key) async def handle_unload_model(): unload_model() diff --git a/modules/models.py b/modules/models.py index 07c14308..5c16376e 100644 --- a/modules/models.py +++ b/modules/models.py @@ -396,9 +396,13 @@ def unload_model(): clear_torch_cache() +def load_last_model(): + shared.model, shared.tokenizer = load_model(shared.previous_model_name) + + def reload_model(): unload_model() - shared.model, shared.tokenizer = load_model(shared.model_name) + shared.model, shared.tokenizer = load_model(shared.previous_model_name) def unload_model_if_idle():