Add /v1/internal/logits endpoint (#4650)

oobabooga 2023-11-18 23:19:31 -03:00 committed by GitHub
parent 8f4f4daf8b
commit 0fa1af296c
4 changed files with 71 additions and 9 deletions

docs/12 - OpenAI API.md

@@ -97,6 +97,29 @@ curl http://127.0.0.1:5000/v1/chat/completions \
}'
```
#### Logits
```
curl -k http://127.0.0.1:5000/v1/internal/logits \
-H "Content-Type: application/json" \
-d '{
"prompt": "Who is best, Asuka or Rei? Answer:",
"use_samplers": false
}'
```
#### Logits after sampling parameters
```
curl -k http://127.0.0.1:5000/v1/internal/logits \
-H "Content-Type: application/json" \
-d '{
"prompt": "Who is best, Asuka or Rei? Answer:",
"use_samplers": true,
"top_k": 3
}'
```
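As an illustration (not part of this commit), the same request can be sent from Python. This sketch assumes the `requests` library and the default local server address used above:

```python
import requests

url = "http://127.0.0.1:5000/v1/internal/logits"
payload = {
    "prompt": "Who is best, Asuka or Rei? Answer:",
    "use_samplers": True,
    "top_k": 3
}

response = requests.post(url, json=payload)
# Per the endpoint's docstring, the result is a dict whose keys are tokens
# and whose values are their probabilities (raw scores when use_samplers is false).
print(response.json())
```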
#### Python chat example
```python

extensions/openai/script.py

@@ -16,6 +16,7 @@ from sse_starlette import EventSourceResponse
import extensions.openai.completions as OAIcompletions
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
from extensions.openai.errors import ServiceUnavailableError
@@ -38,6 +39,8 @@ from .typing import (
    EncodeRequest,
    EncodeResponse,
    LoadModelRequest,
    LogitsRequest,
    LogitsResponse,
    ModelInfoResponse,
    TokenCountResponse,
    to_dict
@@ -242,6 +245,16 @@ async def handle_token_count(request_data: EncodeRequest):
    return JSONResponse(response)


@app.post("/v1/internal/logits", response_model=LogitsResponse, dependencies=check_key)
async def handle_logits(request_data: LogitsRequest):
    '''
    Given a prompt, returns the top 50 most likely logits as a dict.
    The keys are the tokens, and the values are the probabilities.
    '''
    response = OAIlogits._get_next_logits(to_dict(request_data))
    return JSONResponse(response)


@app.post("/v1/internal/stop-generation", dependencies=check_key)
async def handle_stop_generation(request: Request):
    stop_everything_event()
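The new handler delegates to `OAIlogits._get_next_logits`, which is not shown in this diff. A minimal sketch of what that glue function plausibly looks like, assuming it builds a sampler state from the request body (the `process_parameters` helper is an assumption) and forwards to `modules.logits.get_next_logits` with the new `return_dict=True` flag:

```python
# Hypothetical sketch of extensions/openai/logits.py (not part of this diff).
from extensions.openai.completions import process_parameters  # assumed helper
from modules.logits import get_next_logits


def _get_next_logits(body: dict) -> dict:
    # Only build a full sampler state when the caller wants sampled probabilities.
    state = process_parameters(body) if body['use_samplers'] else {}
    state['stream'] = True

    # return_dict=True yields {token: value, ...} instead of the UI's formatted string.
    return get_next_logits(body['prompt'], state, body['use_samplers'], '', return_dict=True)
```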

extensions/openai/typing.py

@@ -126,15 +126,15 @@ class EncodeRequest(BaseModel):
    text: str


class DecodeRequest(BaseModel):
    tokens: List[int]


class EncodeResponse(BaseModel):
    tokens: List[int]
    length: int


class DecodeRequest(BaseModel):
    tokens: List[int]


class DecodeResponse(BaseModel):
    text: str
@@ -143,6 +143,24 @@ class TokenCountResponse(BaseModel):
    length: int


class LogitsRequestParams(BaseModel):
    prompt: str
    use_samplers: bool = False
    frequency_penalty: float | None = 0
    max_tokens: int | None = 16
    presence_penalty: float | None = 0
    temperature: float | None = 1
    top_p: float | None = 1


class LogitsRequest(GenerationOptions, LogitsRequestParams):
    pass


class LogitsResponse(BaseModel):
    logits: dict


class ModelInfoResponse(BaseModel):
    model_name: str
    lora_names: List[str]
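For illustration only (not part of this commit), a request body that validates against `LogitsRequest` could be built as below; `top_k` is assumed to be declared on the `GenerationOptions` base class rather than on `LogitsRequestParams`:

```python
# Illustrative only: build the payload the endpoint expects.
from extensions.openai.typing import LogitsRequest, to_dict

req = LogitsRequest(
    prompt="Who is best, Asuka or Rei? Answer:",
    use_samplers=True,
    top_k=3,  # assumed to come from GenerationOptions
)

# to_dict() is the same helper the endpoint uses to pass the request on.
print(to_dict(req))
```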

modules/logits.py

@@ -8,7 +8,7 @@ from modules.text_generation import generate_reply

global_scores = None


def get_next_logits(prompt, state, use_samplers, previous):
def get_next_logits(prompt, state, use_samplers, previous, return_dict=False):
    if shared.model is None:
        logger.error("No model is loaded! Select one in the Model tab.")
        return 'Error: No model is loaded! Select one in the Model tab.', previous

@@ -56,8 +56,16 @@ def get_next_logits(prompt, state, use_samplers, previous):
    topk_indices = [i.expand((1, 1)) for i in topk_indices]

    tokens = [shared.tokenizer.decode(i) for i in topk_indices]

    output = ''
    for row in list(zip(topk_values, tokens)):
        output += f"{row[0]} - {repr(row[1])}\n"

    return output, previous
    if return_dict:
        output = {}
        for row in list(zip(topk_values, tokens)):
            output[row[1]] = row[0]

        return output
    else:
        output = ''
        for row in list(zip(topk_values, tokens)):
            output += f"{row[0]} - {repr(row[1])}\n"

        return output, previous
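As an illustration (not part of this commit), the two return shapes of the modified helper. This sketch assumes a model is already loaded and that `state` is the usual generation-settings dict built elsewhere; it is shown empty here only to keep the example self-contained:

```python
from modules.logits import get_next_logits

# Assumption: shared.model is already loaded; in practice `state` carries the
# full generation settings rather than being empty.
state = {}
prompt = "Who is best, Asuka or Rei? Answer:"

# Dict form consumed by /v1/internal/logits: {decoded_token: value, ...}
logits = get_next_logits(prompt, state, use_samplers=False, previous='', return_dict=True)

# Original tuple form kept for the web UI: a formatted "value - 'token'"
# string plus the carried-over previous output.
text, previous = get_next_logits(prompt, state, use_samplers=False, previous='')
```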