From 6247eafcc5537cd82a74dd0dfaedc7db4a444159 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 22 Jan 2024 04:12:23 -0800 Subject: [PATCH] API: better handle temperature = 0 --- extensions/openai/completions.py | 4 ++++ extensions/openai/script.py | 6 ------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index 19a8f893..aea07473 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -92,6 +92,10 @@ def process_parameters(body, is_legacy=False): if generate_params['truncation_length'] == 0: generate_params['truncation_length'] = shared.settings['truncation_length'] + if generate_params['temperature'] == 0: + generate_params['do_sample'] = False + generate_params['top_k'] = 1 + if body['preset'] is not None: preset = load_preset_memoized(body['preset']) generate_params.update(preset) diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 30f14c65..e8647357 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -97,9 +97,6 @@ async def openai_completions(request: Request, request_data: CompletionRequest): path = request.url.path is_legacy = "/generate" in path - if request_data.temperature == 0: - request_data.do_sample = False - if request_data.stream: async def generator(): async with streaming_semaphore: @@ -123,9 +120,6 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion path = request.url.path is_legacy = "/generate" in path - if request_data.temperature == 0: - request_data.do_sample = False - if request_data.stream: async def generator(): async with streaming_semaphore: