From f27e1ba302971df80d2e06cfb3074d5004e769d3 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 19 Apr 2024 00:24:46 -0300
Subject: [PATCH] Add a /v1/internal/chat-prompt endpoint (#5879)

---
 extensions/openai/completions.py | 21 ++++++++++++++-------
 extensions/openai/models.py      |  3 ++-
 extensions/openai/script.py      | 11 +++++++++++
 extensions/openai/typing.py      |  4 ++++
 modules/models_settings.py       |  7 ++++---
 5 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index 5925101a..44c1df86 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -135,6 +135,7 @@ def convert_history(history):
     current_message = ""
     current_reply = ""
     user_input = ""
+    user_input_last = True
     system_message = ""
 
     # Multimodal: convert OpenAI format to multimodal extension format
@@ -188,6 +189,7 @@ def convert_history(history):
 
         if role == "user":
             user_input = content
+            user_input_last = True
             if current_message:
                 chat_dialogue.append([current_message, ''])
                 current_message = ""
@@ -195,6 +197,7 @@ def convert_history(history):
             current_message = content
         elif role == "assistant":
             current_reply = content
+            user_input_last = False
             if current_message:
                 chat_dialogue.append([current_message, current_reply])
                 current_message = ""
@@ -204,13 +207,13 @@ def convert_history(history):
         elif role == "system":
             system_message = content
 
-    # if current_message:
-    #     chat_dialogue.append([current_message, ''])
+    if not user_input_last:
+        user_input = ""
 
     return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
 
 
-def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
+def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
     if body.get('functions', []):
         raise InvalidRequestError(message="functions is not supported.", param='functions')
 
@@ -310,14 +313,18 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -
         #     chunk[resp_list][0]["logprobs"] = None
         return chunk
 
-    if stream:
-        yield chat_streaming_chunk('')
-
     # generate reply #######################################
-    prompt = generate_chat_prompt(user_input, generate_params)
+    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
+    if prompt_only:
+        yield {'prompt': prompt}
+        return
+
     token_count = len(encode(prompt)[0])
     debug_msg({'prompt': prompt, 'generate_params': generate_params})
 
+    if stream:
+        yield chat_streaming_chunk('')
+
     generator = generate_chat_reply(
         user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
 
diff --git a/extensions/openai/models.py b/extensions/openai/models.py
index 01045f90..a7e67df6 100644
--- a/extensions/openai/models.py
+++ b/extensions/openai/models.py
@@ -9,7 +9,8 @@ from modules.utils import get_available_loras, get_available_models
 def get_current_model_info():
     return {
         'model_name': shared.model_name,
-        'lora_names': shared.lora_names
+        'lora_names': shared.lora_names,
+        'loader': shared.args.loader
     }
 
 
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index e8647357..03d99e8d 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -3,6 +3,7 @@ import json
 import logging
 import os
 import traceback
+from collections import deque
 from threading import Thread
 
 import speech_recognition as sr
@@ -31,6 +32,7 @@ from modules.text_generation import stop_everything_event
 from .typing import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatPromptResponse,
     CompletionRequest,
     CompletionResponse,
     DecodeRequest,
@@ -259,6 +261,15 @@ async def handle_logits(request_data: LogitsRequest):
     return JSONResponse(response)
 
 
+@app.post('/v1/internal/chat-prompt', response_model=ChatPromptResponse, dependencies=check_key)
+async def handle_chat_prompt(request: Request, request_data: ChatCompletionRequest):
+    path = request.url.path
+    is_legacy = "/generate" in path
+    generator = OAIcompletions.chat_completions_common(to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
+    response = deque(generator, maxlen=1).pop()
+    return JSONResponse(response)
+
+
 @app.post("/v1/internal/stop-generation", dependencies=check_key)
 async def handle_stop_generation(request: Request):
     stop_everything_event()
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index c3ef0404..2b30ebf2 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -124,6 +124,10 @@ class ChatCompletionResponse(BaseModel):
     usage: dict
 
 
+class ChatPromptResponse(BaseModel):
+    prompt: str
+
+
 class EmbeddingsRequest(BaseModel):
     input: str | List[str] | List[int] | List[List[int]]
     model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
diff --git a/modules/models_settings.py b/modules/models_settings.py
index b7a7d332..5c292431 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -136,9 +136,6 @@ def get_model_metadata(model):
     if 'instruction_template' not in model_settings:
         model_settings['instruction_template'] = 'Alpaca'
 
-    if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
-        model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
-
     # Ignore rope_freq_base if set to the default value
     if 'rope_freq_base' in model_settings and model_settings['rope_freq_base'] == 10000:
         model_settings.pop('rope_freq_base')
@@ -150,6 +147,10 @@ def get_model_metadata(model):
         for k in settings[pat]:
             model_settings[k] = settings[pat][k]
 
+    # Load instruction template if defined by name rather than by value
+    if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
+        model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
+
     return model_settings
 
 
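
Usage note (illustrative, not part of the patch): the new endpoint accepts the
same JSON body as /v1/chat/completions (a ChatCompletionRequest) and returns
only the fully templated prompt string (ChatPromptResponse above). A minimal
client sketch, assuming the API is listening on its default
http://127.0.0.1:5000 and no API key is configured:

    import requests

    # Any ChatCompletionRequest fields are accepted; "mode" and
    # "instruction_template" influence how the prompt is templated.
    payload = {
        "mode": "instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Write a haiku about spring."},
        ],
    }

    response = requests.post("http://127.0.0.1:5000/v1/internal/chat-prompt", json=payload, timeout=30)
    response.raise_for_status()

    # The response body is {"prompt": "..."}.
    print(response.json()["prompt"])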
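
Implementation note: chat_completions_common() is a generator, so the new
handler drains it even in prompt_only mode; collections.deque with maxlen=1
keeps only the final yielded item, which .pop() then unwraps. A standalone
sketch of that pattern (hypothetical names, not webui code):

    from collections import deque

    def prompt_only_generator():
        # In prompt_only mode a single dict is yielded before returning.
        yield {"prompt": "### Instruction:\n..."}

    # deque(iterable, maxlen=1) discards all but the last item; pop() returns it.
    response = deque(prompt_only_generator(), maxlen=1).pop()
    print(response["prompt"])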