API: better handle temperature = 0

This commit is contained in:
oobabooga 2024-01-22 04:12:23 -08:00
parent 817866c9cf
commit 6247eafcc5
2 changed files with 4 additions and 6 deletions

View file

@ -92,6 +92,10 @@ def process_parameters(body, is_legacy=False):
if generate_params['truncation_length'] == 0: if generate_params['truncation_length'] == 0:
generate_params['truncation_length'] = shared.settings['truncation_length'] generate_params['truncation_length'] = shared.settings['truncation_length']
if generate_params['temperature'] == 0:
generate_params['do_sample'] = False
generate_params['top_k'] = 1
if body['preset'] is not None: if body['preset'] is not None:
preset = load_preset_memoized(body['preset']) preset = load_preset_memoized(body['preset'])
generate_params.update(preset) generate_params.update(preset)

View file

@ -97,9 +97,6 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
path = request.url.path path = request.url.path
is_legacy = "/generate" in path is_legacy = "/generate" in path
if request_data.temperature == 0:
request_data.do_sample = False
if request_data.stream: if request_data.stream:
async def generator(): async def generator():
async with streaming_semaphore: async with streaming_semaphore:
@ -123,9 +120,6 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
path = request.url.path path = request.url.path
is_legacy = "/generate" in path is_legacy = "/generate" in path
if request_data.temperature == 0:
request_data.do_sample = False
if request_data.stream: if request_data.stream:
async def generator(): async def generator():
async with streaming_semaphore: async with streaming_semaphore: