diff --git a/extensions/openai/models.py b/extensions/openai/models.py index 1ff950a2..8a093ebe 100644 --- a/extensions/openai/models.py +++ b/extensions/openai/models.py @@ -1,8 +1,9 @@ from modules import shared from modules.logging_colors import logger +from modules.LoRA import add_lora_to_model from modules.models import load_model, unload_model from modules.models_settings import get_model_metadata, update_model_parameters -from modules.utils import get_available_models +from modules.utils import get_available_loras, get_available_models def get_current_model_info(): @@ -13,12 +14,17 @@ def get_current_model_info(): def list_models(): + return {'model_names': get_available_models()[1:]} + + +def list_dummy_models(): result = { "object": "list", "data": [] } - for model in get_dummy_models() + get_available_models()[1:]: + # these are expected by so much, so include some here as a dummy + for model in ['gpt-3.5-turbo', 'text-embedding-ada-002']: result["data"].append(model_info_dict(model)) return result @@ -33,13 +39,6 @@ def model_info_dict(model_name: str) -> dict: } -def get_dummy_models() -> list: - return [ # these are expected by so much, so include some here as a dummy - 'gpt-3.5-turbo', - 'text-embedding-ada-002', - ] - - def _load_model(data): model_name = data["model_name"] args = data["args"] @@ -67,3 +66,15 @@ def _load_model(data): logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}") elif k == 'instruction_template': logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}") + + +def list_loras(): + return {'lora_names': get_available_loras()[1:]} + + +def load_loras(lora_names): + add_lora_to_model(lora_names) + + +def unload_all_loras(): + add_lora_to_model([]) diff --git a/extensions/openai/script.py b/extensions/openai/script.py index a516b0f7..047c339a 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -38,10 +38,13 @@ from .typing import ( EmbeddingsResponse, EncodeRequest, EncodeResponse, + LoadLorasRequest, LoadModelRequest, LogitsRequest, LogitsResponse, + LoraListResponse, ModelInfoResponse, + ModelListResponse, TokenCountResponse, to_dict ) @@ -141,7 +144,7 @@ async def handle_models(request: Request): is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models' if is_list: - response = OAImodels.list_models() + response = OAImodels.list_dummy_models() else: model_name = path[len('/v1/models/'):] response = OAImodels.model_info_dict(model_name) @@ -267,6 +270,12 @@ async def handle_model_info(): return JSONResponse(content=payload) +@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_admin_key) +async def handle_list_models(): + payload = OAImodels.list_models() + return JSONResponse(content=payload) + + @app.post("/v1/internal/model/load", dependencies=check_admin_key) async def handle_load_model(request_data: LoadModelRequest): ''' @@ -307,6 +316,27 @@ async def handle_load_model(request_data: LoadModelRequest): @app.post("/v1/internal/model/unload", dependencies=check_admin_key) async def handle_unload_model(): unload_model() + + +@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key) +async def handle_list_loras(): + response = OAImodels.list_loras() + return JSONResponse(content=response) + + +@app.post("/v1/internal/lora/load", dependencies=check_admin_key) +async def handle_load_loras(request_data: LoadLorasRequest): + try: + OAImodels.load_loras(request_data.lora_names) + return JSONResponse(content="OK") + except: + traceback.print_exc() + return HTTPException(status_code=400, detail="Failed to apply the LoRA(s).") + + +@app.post("/v1/internal/lora/unload", dependencies=check_admin_key) +async def handle_unload_loras(): + OAImodels.unload_all_loras() return JSONResponse(content="OK") diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py index 05d3f753..5a2d40d5 100644 --- a/extensions/openai/typing.py +++ b/extensions/openai/typing.py @@ -122,6 +122,19 @@ class ChatCompletionResponse(BaseModel): usage: dict +class EmbeddingsRequest(BaseModel): + input: str | List[str] + model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.") + encoding_format: str = Field(default="float", description="Can be float or base64.") + user: str | None = Field(default=None, description="Unused parameter.") + + +class EmbeddingsResponse(BaseModel): + index: int + embedding: List[float] + object: str = "embedding" + + class EncodeRequest(BaseModel): text: str @@ -166,23 +179,22 @@ class ModelInfoResponse(BaseModel): lora_names: List[str] +class ModelListResponse(BaseModel): + model_names: List[str] + + class LoadModelRequest(BaseModel): model_name: str args: dict | None = None settings: dict | None = None -class EmbeddingsRequest(BaseModel): - input: str | List[str] - model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.") - encoding_format: str = Field(default="float", description="Can be float or base64.") - user: str | None = Field(default=None, description="Unused parameter.") +class LoraListResponse(BaseModel): + lora_names: List[str] -class EmbeddingsResponse(BaseModel): - index: int - embedding: List[float] - object: str = "embedding" +class LoadLorasRequest(BaseModel): + lora_names: List[str] def to_json(obj):