From addcb52c5697fd612cd43bcb731cd806995b817c Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 28 Jul 2024 18:31:40 -0700
Subject: [PATCH] Make --idle-timeout work for API requests

---
 extensions/openai/completions.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index 44c1df86..646dee2d 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -319,7 +319,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             yield {'prompt': prompt}
             return
 
-    token_count = len(encode(prompt)[0])
     debug_msg({'prompt': prompt, 'generate_params': generate_params})
 
     if stream:
@@ -330,7 +329,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
 
     answer = ''
     seen_content = ''
-    completion_token_count = 0
 
     for a in generator:
         answer = a['internal'][-1][1]
@@ -345,6 +343,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             chunk = chat_streaming_chunk(new_content)
             yield chunk
 
+    token_count = len(encode(prompt)[0])
     completion_token_count = len(encode(answer)[0])
     stop_reason = "stop"
     if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
@@ -429,8 +428,6 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             prompt = decode(prompt)[0]
 
         prefix = prompt if echo else ''
-        token_count = len(encode(prompt)[0])
-        total_prompt_token_count += token_count
 
         # generate reply #######################################
         debug_msg({'prompt': prompt, 'generate_params': generate_params})
@@ -440,6 +437,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
         for a in generator:
             answer = a
 
+        token_count = len(encode(prompt)[0])
+        total_prompt_token_count += token_count
         completion_token_count = len(encode(answer)[0])
         total_completion_token_count += completion_token_count
         stop_reason = "stop"