Make --idle-timeout work for API requests

This commit is contained in:
oobabooga 2024-07-28 18:31:40 -07:00
parent 514fb2e451
commit addcb52c56

View file

@ -319,7 +319,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
yield {'prompt': prompt} yield {'prompt': prompt}
return return
token_count = len(encode(prompt)[0])
debug_msg({'prompt': prompt, 'generate_params': generate_params}) debug_msg({'prompt': prompt, 'generate_params': generate_params})
if stream: if stream:
@ -330,7 +329,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
answer = '' answer = ''
seen_content = '' seen_content = ''
completion_token_count = 0
for a in generator: for a in generator:
answer = a['internal'][-1][1] answer = a['internal'][-1][1]
@ -345,6 +343,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
chunk = chat_streaming_chunk(new_content) chunk = chat_streaming_chunk(new_content)
yield chunk yield chunk
token_count = len(encode(prompt)[0])
completion_token_count = len(encode(answer)[0]) completion_token_count = len(encode(answer)[0])
stop_reason = "stop" stop_reason = "stop"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']: if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
@ -429,8 +428,6 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
prompt = decode(prompt)[0] prompt = decode(prompt)[0]
prefix = prompt if echo else '' prefix = prompt if echo else ''
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
# generate reply ####################################### # generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params}) debug_msg({'prompt': prompt, 'generate_params': generate_params})
@ -440,6 +437,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
for a in generator: for a in generator:
answer = a answer = a
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
completion_token_count = len(encode(answer)[0]) completion_token_count = len(encode(answer)[0])
total_completion_token_count += completion_token_count total_completion_token_count += completion_token_count
stop_reason = "stop" stop_reason = "stop"