diff --git a/docs/12 - OpenAI API.md b/docs/12 - OpenAI API.md
index 05b4db02..abbd432d 100644
--- a/docs/12 - OpenAI API.md
+++ b/docs/12 - OpenAI API.md
@@ -97,6 +97,29 @@ curl http://127.0.0.1:5000/v1/chat/completions \
   }'
 ```
 
+#### Logits
+
+```
+curl -k http://127.0.0.1:5000/v1/internal/logits \
+  -H "Content-Type: application/json" \
+  -d '{
+    "prompt": "Who is best, Asuka or Rei? Answer:",
+    "use_samplers": false
+  }'
+```
+
+#### Logits after sampling parameters
+
+```
+curl -k http://127.0.0.1:5000/v1/internal/logits \
+  -H "Content-Type: application/json" \
+  -d '{
+    "prompt": "Who is best, Asuka or Rei? Answer:",
+    "use_samplers": true,
+    "top_k": 3
+  }'
+```
+
 #### Python chat example
 
 ```python
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 43d4b261..da56287c 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -16,6 +16,7 @@ from sse_starlette import EventSourceResponse
 import extensions.openai.completions as OAIcompletions
 import extensions.openai.embeddings as OAIembeddings
 import extensions.openai.images as OAIimages
+import extensions.openai.logits as OAIlogits
 import extensions.openai.models as OAImodels
 import extensions.openai.moderations as OAImoderations
 from extensions.openai.errors import ServiceUnavailableError
@@ -38,6 +39,8 @@ from .typing import (
     EncodeRequest,
     EncodeResponse,
     LoadModelRequest,
+    LogitsRequest,
+    LogitsResponse,
     ModelInfoResponse,
     TokenCountResponse,
     to_dict
@@ -242,6 +245,16 @@ async def handle_token_count(request_data: EncodeRequest):
     return JSONResponse(response)
 
 
+@app.post("/v1/internal/logits", response_model=LogitsResponse, dependencies=check_key)
+async def handle_logits(request_data: LogitsRequest):
+    '''
+    Given a prompt, returns the top 50 most likely logits as a dict.
+    The keys are the tokens, and the values are the probabilities.
+    '''
+    response = OAIlogits._get_next_logits(to_dict(request_data))
+    return JSONResponse(response)
+
+
 @app.post("/v1/internal/stop-generation", dependencies=check_key)
 async def handle_stop_generation(request: Request):
     stop_everything_event()
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index ee8f2ac6..05d3f753 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -126,15 +126,15 @@ class EncodeRequest(BaseModel):
     text: str
 
 
-class DecodeRequest(BaseModel):
-    tokens: List[int]
-
-
 class EncodeResponse(BaseModel):
     tokens: List[int]
     length: int
 
 
+class DecodeRequest(BaseModel):
+    tokens: List[int]
+
+
 class DecodeResponse(BaseModel):
     text: str
 
@@ -143,6 +143,24 @@ class TokenCountResponse(BaseModel):
     length: int
 
 
+class LogitsRequestParams(BaseModel):
+    prompt: str
+    use_samplers: bool = False
+    frequency_penalty: float | None = 0
+    max_tokens: int | None = 16
+    presence_penalty: float | None = 0
+    temperature: float | None = 1
+    top_p: float | None = 1
+
+
+class LogitsRequest(GenerationOptions, LogitsRequestParams):
+    pass
+
+
+class LogitsResponse(BaseModel):
+    logits: dict
+
+
 class ModelInfoResponse(BaseModel):
     model_name: str
     lora_names: List[str]
diff --git a/modules/logits.py b/modules/logits.py
index e356a986..383659e0 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -8,7 +8,7 @@ from modules.text_generation import generate_reply
 global_scores = None
 
 
-def get_next_logits(prompt, state, use_samplers, previous):
+def get_next_logits(prompt, state, use_samplers, previous, return_dict=False):
     if shared.model is None:
         logger.error("No model is loaded! Select one in the Model tab.")
         return 'Error: No model is loaded! Select one in the Model tab.', previous
@@ -56,8 +56,16 @@ def get_next_logits(prompt, state, use_samplers, previous):
     topk_indices = [i.expand((1, 1)) for i in topk_indices]
 
     tokens = [shared.tokenizer.decode(i) for i in topk_indices]
-    output = ''
-    for row in list(zip(topk_values, tokens)):
-        output += f"{row[0]} - {repr(row[1])}\n"
 
-    return output, previous
+    if return_dict:
+        output = {}
+        for row in list(zip(topk_values, tokens)):
+            output[row[1]] = row[0]
+
+        return output
+    else:
+        output = ''
+        for row in list(zip(topk_values, tokens)):
+            output += f"{row[0]} - {repr(row[1])}\n"
+
+        return output, previous