diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index 44c1df86..646dee2d 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -319,7 +319,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         yield {'prompt': prompt}
         return

-    token_count = len(encode(prompt)[0])
     debug_msg({'prompt': prompt, 'generate_params': generate_params})

     if stream:
@@ -330,7 +329,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p

     answer = ''
     seen_content = ''
-    completion_token_count = 0

     for a in generator:
         answer = a['internal'][-1][1]
@@ -345,6 +343,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             chunk = chat_streaming_chunk(new_content)
             yield chunk

+    token_count = len(encode(prompt)[0])
     completion_token_count = len(encode(answer)[0])
     stop_reason = "stop"
     if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
@@ -429,8 +428,6 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
                         prompt = decode(prompt)[0]

             prefix = prompt if echo else ''
-            token_count = len(encode(prompt)[0])
-            total_prompt_token_count += token_count

             # generate reply #######################################
             debug_msg({'prompt': prompt, 'generate_params': generate_params})
@@ -440,6 +437,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             for a in generator:
                 answer = a

+            token_count = len(encode(prompt)[0])
+            total_prompt_token_count += token_count
             completion_token_count = len(encode(answer)[0])
             total_completion_token_count += completion_token_count
             stop_reason = "stop"