diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 03d99e8d..0d294643 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -26,7 +26,7 @@ from extensions.openai.tokens import token_count, token_decode, token_encode
 from extensions.openai.utils import _start_cloudflared
 from modules import shared
 from modules.logging_colors import logger
-from modules.models import unload_model
+from modules.models import load_last_model, unload_model
 from modules.text_generation import stop_everything_event
 
 from .typing import (
@@ -325,6 +325,21 @@ async def handle_load_model(request_data: LoadModelRequest):
         return HTTPException(status_code=400, detail="Failed to load the model.")
 
 
+@app.post("/v1/internal/model/loadlast", dependencies=check_admin_key)
+async def handle_load_last_model():
+    '''
+    Loads the last model used before it was unloaded.
+
+    This endpoint is experimental and may change in the future.
+    '''
+    try:
+        load_last_model()
+        return JSONResponse(content="OK")
+    except Exception:
+        traceback.print_exc()
+        return HTTPException(status_code=400, detail="Failed to load the last-used model.")
+
+
 @app.post("/v1/internal/model/unload", dependencies=check_admin_key)
 async def handle_unload_model():
     unload_model()
diff --git a/modules/models.py b/modules/models.py
index 07c14308..5c16376e 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -396,9 +396,14 @@ def unload_model():
     clear_torch_cache()
 
 
+def load_last_model():
+    # Restore the model that was active before the most recent unload.
+    shared.model, shared.tokenizer = load_model(shared.previous_model_name)
+
+
 def reload_model():
     unload_model()
-    shared.model, shared.tokenizer = load_model(shared.model_name)
+    shared.model, shared.tokenizer = load_model(shared.previous_model_name)
 
 
 def unload_model_if_idle():