From e53f99faa0451dcf781939eb86a42255485a7ca7 Mon Sep 17 00:00:00 2001
From: Kim Jaewon <101622378+kimjaewon96@users.noreply.github.com>
Date: Fri, 15 Dec 2023 12:22:43 +0900
Subject: [PATCH] [OpenAI Extension] Add 'max_logits' parameter in logits endpoint (#4916)

---
 extensions/openai/logits.py | 2 +-
 extensions/openai/typing.py | 5 +++--
 modules/logits.py           | 7 ++++---
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/extensions/openai/logits.py b/extensions/openai/logits.py
index 9d2fe41c..357e70fa 100644
--- a/extensions/openai/logits.py
+++ b/extensions/openai/logits.py
@@ -8,4 +8,4 @@ def _get_next_logits(body):
     state = process_parameters(body) if use_samplers else {}
     state['stream'] = True
 
-    return get_next_logits(body['prompt'], state, use_samplers, "", return_dict=True)
+    return get_next_logits(body['prompt'], state, use_samplers, "", top_logits=body['top_logits'], return_dict=True)
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 47ddd789..332a8c28 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -1,6 +1,6 @@
 import json
 import time
-from typing import List
+from typing import Dict, List
 
 from pydantic import BaseModel, Field
 
@@ -156,6 +156,7 @@ class TokenCountResponse(BaseModel):
 class LogitsRequestParams(BaseModel):
     prompt: str
     use_samplers: bool = False
+    top_logits: int | None = 50
     frequency_penalty: float | None = 0
     max_tokens: int | None = 16
     presence_penalty: float | None = 0
@@ -168,7 +169,7 @@ class LogitsRequest(GenerationOptions, LogitsRequestParams):
 
 
 class LogitsResponse(BaseModel):
-    logits: dict
+    logits: Dict[str, float]
 
 
 class ModelInfoResponse(BaseModel):
diff --git a/modules/logits.py b/modules/logits.py
index 5d0d3210..e12cf6e7 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -8,7 +8,7 @@ from modules.text_generation import generate_reply
 global_scores = None
 
 
-def get_next_logits(prompt, state, use_samplers, previous, return_dict=False):
+def get_next_logits(prompt, state, use_samplers, previous, top_logits=50, return_dict=False):
     if shared.model is None:
         logger.error("No model is loaded! Select one in the Model tab.")
         return 'Error: No model is loaded1 Select one in the Model tab.', previous
@@ -50,8 +50,7 @@ def get_next_logits(prompt, state, use_samplers, previous, return_dict=False):
         scores = output['logits'][-1][-1]
 
     probs = torch.softmax(scores, dim=-1, dtype=torch.float)
-    topk_values, topk_indices = torch.topk(probs, k=50, largest=True, sorted=True)
-    topk_values = [f"{float(i):.5f}" for i in topk_values]
+    topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True)
     if is_non_hf_exllamav1 or is_non_hf_llamacpp:
         topk_indices = [i.expand((1, 1)) for i in topk_indices]
 
@@ -61,12 +60,14 @@ def get_next_logits(prompt, state, use_samplers, previous, return_dict=False):
     tokens = [shared.tokenizer.decode(i) for i in topk_indices]
 
     if return_dict:
+        topk_values = [float(i) for i in topk_values]
         output = {}
         for row in list(zip(topk_values, tokens)):
             output[row[1]] = row[0]
 
         return output
     else:
+        topk_values = [f"{float(i):.5f}" for i in topk_values]
         output = ''
         for row in list(zip(topk_values, tokens)):
             output += f"{row[0]} - {repr(row[1])}\n"
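
For reference, a minimal client-side sketch of the new parameter in use. This is not part of the patch: it assumes the OpenAI extension serves this handler at /v1/internal/logits on the default local port 5000 and that the 'requests' package is installed; adjust the URL and payload to your setup.

# Minimal sketch (not part of the patch): query the logits endpoint with the
# new 'top_logits' parameter. Assumes the OpenAI-compatible API is running
# locally on its default port 5000 and exposes this handler at
# /v1/internal/logits; adjust the URL for your setup.
import requests

url = "http://127.0.0.1:5000/v1/internal/logits"
payload = {
    "prompt": "The capital of France is",
    "use_samplers": False,
    "top_logits": 10,  # request the 10 most likely next tokens instead of the default 50
}

response = requests.post(url, json=payload)
response.raise_for_status()

# With return_dict=True on the server side, 'logits' maps token strings to probabilities.
for token, prob in response.json()["logits"].items():
    print(f"{prob:.5f} - {token!r}")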