Add a /v1/internal/chat-prompt endpoint (#5879)

oobabooga 2024-04-19 00:24:46 -03:00 committed by GitHub
parent b30bce3b2f
commit f27e1ba302
5 changed files with 35 additions and 11 deletions

View file

@@ -135,6 +135,7 @@ def convert_history(history):
     current_message = ""
     current_reply = ""
     user_input = ""
+    user_input_last = True
     system_message = ""
 
     # Multimodal: convert OpenAI format to multimodal extension format
@@ -188,6 +189,7 @@ def convert_history(history):
         if role == "user":
             user_input = content
+            user_input_last = True
             if current_message:
                 chat_dialogue.append([current_message, ''])
                 current_message = ""
@@ -195,6 +197,7 @@ def convert_history(history):
             current_message = content
         elif role == "assistant":
             current_reply = content
+            user_input_last = False
             if current_message:
                 chat_dialogue.append([current_message, current_reply])
                 current_message = ""
@@ -204,13 +207,13 @@ def convert_history(history):
         elif role == "system":
             system_message = content
 
-    # if current_message:
-    #     chat_dialogue.append([current_message, ''])
+    if not user_input_last:
+        user_input = ""
 
     return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
 
 
-def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
+def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
     if body.get('functions', []):
         raise InvalidRequestError(message="functions is not supported.", param='functions')
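
For reference, a hypothetical trace of convert_history() after the change above (not part of the commit; it assumes the function receives the OpenAI-style messages list directly and that message contents are plain strings): when the history ends with an assistant message, the last user turn is already stored inside the dialogue pairs, so user_input is now returned empty instead of being repeated as a fresh turn.

# Hypothetical trace (not part of the commit); assumes plain string contents
# and that convert_history() is called with the messages list.
messages = [
    {'role': 'user', 'content': 'Hello'},
    {'role': 'assistant', 'content': 'Hi there!'},  # history ends with the assistant
]

user_input, system_message, history = convert_history(messages)
assert history['internal'] == [['Hello', 'Hi there!']]  # the user turn is already paired up
assert user_input == ""  # cleared, because user_input_last is False at the end of the loop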
@@ -310,14 +313,18 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
         # chunk[resp_list][0]["logprobs"] = None
 
         return chunk
 
-    if stream:
-        yield chat_streaming_chunk('')
-
     # generate reply #######################################
-    prompt = generate_chat_prompt(user_input, generate_params)
+    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
+    if prompt_only:
+        yield {'prompt': prompt}
+        return
+
     token_count = len(encode(prompt)[0])
     debug_msg({'prompt': prompt, 'generate_params': generate_params})
 
+    if stream:
+        yield chat_streaming_chunk('')
+
     generator = generate_chat_reply(
         user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
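
Since chat_completions_common() is a generator, the prompt_only path added above yields a single {'prompt': ...} dict and returns before any text is generated. A minimal in-process sketch, not part of the commit: it assumes stream is left at its default of False and that ChatCompletionRequest and to_dict (both from the extension, see the script.py changes below) fill in the remaining request defaults.

# Minimal sketch (assumption-laden, not part of the commit): fetch only the
# templated prompt, without generating any tokens.
request_data = ChatCompletionRequest(messages=[{'role': 'user', 'content': 'Hello'}])
result = next(chat_completions_common(to_dict(request_data), prompt_only=True))
print(result['prompt'])  # the chat prompt that would be fed to the model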

View file

@@ -9,7 +9,8 @@ from modules.utils import get_available_loras, get_available_models
 def get_current_model_info():
     return {
         'model_name': shared.model_name,
-        'lora_names': shared.lora_names
+        'lora_names': shared.lora_names,
+        'loader': shared.args.loader
     }
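
With this change, the existing GET /v1/internal/model/info route, which is backed by get_current_model_info(), also reports which loader is in use. A rough client sketch, assuming the API server is running on its default address and no API key is set:

# Rough sketch (not part of the commit); the base URL depends on how the server
# was started, and an Authorization header may be required if an API key is set.
import requests

info = requests.get("http://127.0.0.1:5000/v1/internal/model/info").json()
print(info["model_name"], info["lora_names"], info["loader"])  # 'loader' is the new field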

View file

@@ -3,6 +3,7 @@ import json
 import logging
 import os
 import traceback
+from collections import deque
 from threading import Thread
 
 import speech_recognition as sr
@@ -31,6 +32,7 @@ from modules.text_generation import stop_everything_event
 from .typing import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    ChatPromptResponse,
     CompletionRequest,
     CompletionResponse,
     DecodeRequest,
@@ -259,6 +261,15 @@ async def handle_logits(request_data: LogitsRequest):
     return JSONResponse(response)
 
 
+@app.post('/v1/internal/chat-prompt', response_model=ChatPromptResponse, dependencies=check_key)
+async def handle_chat_prompt(request: Request, request_data: ChatCompletionRequest):
+    path = request.url.path
+    is_legacy = "/generate" in path
+    generator = OAIcompletions.chat_completions_common(to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
+    response = deque(generator, maxlen=1).pop()
+    return JSONResponse(response)
+
+
 @app.post("/v1/internal/stop-generation", dependencies=check_key)
 async def handle_stop_generation(request: Request):
     stop_everything_event()
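
A rough client-side sketch of the new endpoint (not part of the commit; the base URL, port, and API key header depend on how the server is launched). The request body follows the usual chat-completions shape, and the response carries only the templated prompt:

# Rough usage sketch (assumptions: default address, no API key configured).
import requests

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ]
}
r = requests.post("http://127.0.0.1:5000/v1/internal/chat-prompt", json=payload)
print(r.json()["prompt"])  # the prompt that /v1/chat/completions would have used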

View file

@@ -124,6 +124,10 @@ class ChatCompletionResponse(BaseModel):
     usage: dict
 
 
+class ChatPromptResponse(BaseModel):
+    prompt: str
+
+
 class EmbeddingsRequest(BaseModel):
     input: str | List[str] | List[int] | List[List[int]]
     model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")

View file

@@ -136,9 +136,6 @@ def get_model_metadata(model):
     if 'instruction_template' not in model_settings:
         model_settings['instruction_template'] = 'Alpaca'
 
-    if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
-        model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
-
     # Ignore rope_freq_base if set to the default value
     if 'rope_freq_base' in model_settings and model_settings['rope_freq_base'] == 10000:
         model_settings.pop('rope_freq_base')
@@ -150,6 +147,10 @@ def get_model_metadata(model):
             for k in settings[pat]:
                 model_settings[k] = settings[pat][k]
 
+    # Load instruction template if defined by name rather than by value
+    if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
+        model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
+
     return model_settings
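
Resolving the template name into a template string now happens after the settings.yaml overrides are applied, so an override of 'instruction_template' is reflected in 'instruction_template_str' as well. A small self-contained illustration of the ordering effect (hypothetical names, not part of the commit; resolve() stands in for chat.load_instruction_template()):

# Hypothetical, self-contained illustration (not part of the commit).
def resolve(name: str) -> str:
    # stand-in for chat.load_instruction_template()
    return f"<template body for {name}>"

override = {'instruction_template': 'Mistral'}  # e.g. from a settings.yaml pattern

# Old order: resolve first, then apply the override -- only the name changes.
old = {'instruction_template': 'Alpaca'}
old['instruction_template_str'] = resolve(old['instruction_template'])
old.update(override)  # old['instruction_template_str'] still refers to 'Alpaca'

# New order: apply the override first, then resolve the (possibly overridden) name.
new = {'instruction_template': 'Alpaca'}
new.update(override)
new['instruction_template_str'] = resolve(new['instruction_template'])  # refers to 'Mistral'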