Rename custom_stopping_strings in the api extension

This commit is contained in:
oobabooga 2023-04-20 00:14:19 -03:00
parent 7bb9036ac9
commit 96ba55501c

View file

@ -34,9 +34,7 @@ class Handler(BaseHTTPRequestHandler):
prompt = body['prompt'] prompt = body['prompt']
prompt_lines = [k.strip() for k in prompt.split('\n')] prompt_lines = [k.strip() for k in prompt.split('\n')]
max_context = body.get('max_context_length', 2048) max_context = body.get('max_context_length', 2048)
while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context: while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
prompt_lines.pop(0) prompt_lines.pop(0)
@ -58,17 +56,13 @@ class Handler(BaseHTTPRequestHandler):
'early_stopping': bool(body.get('early_stopping', False)), 'early_stopping': bool(body.get('early_stopping', False)),
'seed': int(body.get('seed', -1)), 'seed': int(body.get('seed', -1)),
'add_bos_token': int(body.get('add_bos_token', True)), 'add_bos_token': int(body.get('add_bos_token', True)),
'custom_stopping_strings': body.get('custom_stopping_strings', []),
'truncation_length': int(body.get('truncation_length', 2048)), 'truncation_length': int(body.get('truncation_length', 2048)),
'ban_eos_token': bool(body.get('ban_eos_token', False)), 'ban_eos_token': bool(body.get('ban_eos_token', False)),
'skip_special_tokens': bool(body.get('skip_special_tokens', True)), 'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
'stopping_strings': body.get('stopping_strings', []),
} }
stopping_strings = generate_params.pop('stopping_strings')
generator = generate_reply( generator = generate_reply(prompt, generate_params, stopping_strings=stopping_strings)
prompt,
generate_params,
)
answer = '' answer = ''
for a in generator: for a in generator:
if isinstance(a, str): if isinstance(a, str):