Implement echo/suffix parameters

This commit is contained in:
oobabooga 2023-11-07 08:43:45 -08:00
parent cee099f131
commit 3d59346871
2 changed files with 8 additions and 6 deletions

View file

@ -349,8 +349,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
generate_params['stream'] = stream
requested_model = generate_params.pop('model')
logprob_proc = generate_params.pop('logprob_proc', None)
# generate_params['suffix'] = body.get('suffix', generate_params['suffix'])
generate_params['echo'] = body.get('echo', generate_params['echo'])
suffix = body['suffix'] if body['suffix'] else ''
echo = body['echo']
if not stream:
prompt_arg = body[prompt_str]
@ -373,6 +373,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
except KeyError:
prompt = decode(prompt)[0]
prefix = prompt if echo else ''
token_count = len(encode(prompt)[0])
total_prompt_token_count += token_count
@ -393,7 +394,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
respi = {
"index": idx,
"finish_reason": stop_reason,
"text": answer,
"text": prefix + answer + suffix,
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
}
@ -425,6 +426,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
else:
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
prefix = prompt if echo else ''
token_count = len(encode(prompt)[0])
def text_streaming_chunk(content):
@ -444,7 +446,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
return chunk
yield text_streaming_chunk('')
yield text_streaming_chunk(prefix)
# generate reply #######################################
debug_msg({'prompt': prompt, 'generate_params': generate_params})
@ -472,7 +474,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
chunk = text_streaming_chunk('')
chunk = text_streaming_chunk(suffix)
chunk[resp_list][0]["finish_reason"] = stop_reason
chunk["usage"] = {
"prompt_tokens": token_count,

View file

@ -57,7 +57,7 @@ class CompletionRequestParams(BaseModel):
suffix: str | None = None
temperature: float | None = 1
top_p: float | None = 1
user: str | None = None
user: str | None = Field(default=None, description="Unused parameter.")
class CompletionRequest(GenerationOptions, CompletionRequestParams):