Style/pep8 improvements

oobabooga 2023-05-02 23:05:38 -03:00
parent ecd79caa68
commit 320fcfde4e


@@ -1,8 +1,9 @@
-import json, time, os
+import json
+import os
+import time
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 from modules import shared
 from modules.text_generation import encode, generate_reply
@@ -25,13 +26,15 @@ embedding_model = None
 standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
 
+
 # little helper to get defaults if arg is present but None and should be the same type as default.
 def default(dic, key, default):
     val = dic.get(key, default)
     if type(val) != type(default):
         # maybe it's just something like 1 instead of 1.0
         try:
             v = type(default)(val)
             if type(val)(v) == val: # if it's the same value passed in, it's ok.
                 return v
         except:
             pass
@@ -39,6 +42,7 @@ def default(dic, key, default):
         val = default
     return val
 
+
 def clamp(value, minvalue, maxvalue):
     return max(minvalue, min(value, maxvalue))
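
As an aside on the two helpers in the hunks above: default() falls back to its default when a key is missing, is None, or has a type that cannot be losslessly coerced, and clamp() bounds a value to a range. A minimal standalone sketch of how they behave (the request body shown is hypothetical, not from the diff):

    # Standalone copies of the helpers above, plus a hypothetical request body.
    def default(dic, key, default):
        val = dic.get(key, default)
        if type(val) != type(default):
            try:
                v = type(default)(val)
                if type(val)(v) == val:  # same value round-trips, so the coercion is safe
                    return v
            except:
                pass
            val = default
        return val


    def clamp(value, minvalue, maxvalue):
        return max(minvalue, min(value, maxvalue))


    body = {'temperature': 1, 'top_p': None}   # hypothetical client request
    print(default(body, 'temperature', 1.0))   # 1.0  (int 1 coerced to float)
    print(default(body, 'top_p', 1.0))         # 1.0  (None replaced by the default)
    print(default(body, 'max_tokens', 16))     # 16   (missing key)
    print(clamp(4096, 1, 2048))                # 2048
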
@@ -54,27 +58,27 @@ class Handler(BaseHTTPRequestHandler):
             # TODO: list all models and allow model changes via API? Lora's?
             # This API should list capabilities, limits and pricing...
             models = [{
                 "id": shared.model_name, # The real chat/completions model
                 "object": "model",
                 "owned_by": "user",
                 "permission": []
             }, {
                 "id": st_model, # The real sentence transformer embeddings model
                 "object": "model",
                 "owned_by": "user",
                 "permission": []
             }, { # these are expected by so much, so include some here as a dummy
                 "id": "gpt-3.5-turbo", # /v1/chat/completions
                 "object": "model",
                 "owned_by": "user",
                 "permission": []
             }, {
                 "id": "text-curie-001", # /v1/completions, 2k context
                 "object": "model",
                 "owned_by": "user",
                 "permission": []
             }, {
                 "id": "text-davinci-002", # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
                 "object": "model",
                 "owned_by": "user",
                 "permission": []
@@ -103,8 +107,10 @@ class Handler(BaseHTTPRequestHandler):
         content_length = int(self.headers['Content-Length'])
         body = json.loads(self.rfile.read(content_length).decode('utf-8'))
 
-        if debug: print(self.headers) # did you know... python-openai sends your linux kernel & python version?
-        if debug: print(body)
+        if debug:
+            print(self.headers) # did you know... python-openai sends your linux kernel & python version?
+        if debug:
+            print(body)
 
         if '/completions' in self.path or '/generate' in self.path:
             is_legacy = '/generate' in self.path
@@ -112,7 +118,7 @@ class Handler(BaseHTTPRequestHandler):
             resp_list = 'data' if is_legacy else 'choices'
 
             # XXX model is ignored for now
-            #model = body.get('model', shared.model_name) # ignored, use existing for now
+            # model = body.get('model', shared.model_name) # ignored, use existing for now
             model = shared.model_name
             created_time = int(time.time())
             cmpl_id = "conv-%d" % (created_time)
@@ -129,11 +135,11 @@ class Handler(BaseHTTPRequestHandler):
             truncation_length = default(shared.settings, 'truncation_length', 2048)
             truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
 
             default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it., the default for chat is "inf"
 
             max_tokens_str = 'length' if is_legacy else 'max_tokens'
             max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
 
             # hard scale this, assuming the given max is for GPT3/4, perhaps inspect the requested model and lookup the context max
             while truncation_length <= max_tokens:
                 max_tokens = max_tokens // 2
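
To make the halving loop above concrete: the requested token budget is cut in half until it fits strictly under the local context window. A quick sketch with illustrative numbers (not taken from the diff):

    truncation_length = 2048   # default(shared.settings, 'truncation_length', 2048)
    max_tokens = 4096          # hypothetical client request sized for a larger model

    # hard scale this, assuming the given max is for GPT3/4
    while truncation_length <= max_tokens:
        max_tokens = max_tokens // 2

    print(max_tokens)  # 1024 -- leaves room for the prompt inside the 2048-token window
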
@@ -143,17 +149,17 @@ class Handler(BaseHTTPRequestHandler):
                 'temperature': default(body, 'temperature', 1.0),
                 'top_p': default(body, 'top_p', 1.0),
                 'top_k': default(body, 'best_of', 1),
-                ### XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
+                # XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
                 # 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
                 'repetition_penalty': 1.18, # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
-                ### XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
-                'encoder_repetition_penalty': 1.0, #(default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
+                # XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
+                'encoder_repetition_penalty': 1.0, # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
                 'suffix': body.get('suffix', None),
                 'stream': default(body, 'stream', False),
                 'echo': default(body, 'echo', False),
                 #####################################################
                 'seed': shared.settings.get('seed', -1),
-                #int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
+                # int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
                 # unofficial, but it needs to get set anyways.
                 'truncation_length': truncation_length,
                 # no more args.
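
The commented-out mapping in the hunk above scales OpenAI's presence/frequency penalties (range -2.0..2.0, default 0) onto the webui's repetition-penalty style range (roughly 0..2, neutral at 1.0). A worked sketch of that formula, with hypothetical inputs:

    def scale_penalty(openai_penalty):
        # (penalty + 2.0) / 2.0 maps -2.0..2.0 onto 0.0..2.0, with 0 landing on 1.0
        return (openai_penalty + 2.0) / 2.0

    print(scale_penalty(-2.0))  # 0.0
    print(scale_penalty(0.0))   # 1.0  (OpenAI's default maps to the neutral value)
    print(scale_penalty(2.0))   # 2.0
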
@@ -178,7 +184,7 @@ class Handler(BaseHTTPRequestHandler):
             if req_params['stream']:
                 self.send_header('Content-Type', 'text/event-stream')
                 self.send_header('Cache-Control', 'no-cache')
-                #self.send_header('Connection', 'keep-alive')
+                # self.send_header('Connection', 'keep-alive')
             else:
                 self.send_header('Content-Type', 'application/json')
             self.end_headers()
@@ -195,8 +201,8 @@ class Handler(BaseHTTPRequestHandler):
                 messages = body['messages']
 
                 system_msg = '' # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
                 if 'prompt' in body: # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
                     system_msg = body['prompt']
 
                 chat_msgs = []
@@ -204,16 +210,16 @@ class Handler(BaseHTTPRequestHandler):
                 for m in messages:
                     role = m['role']
                     content = m['content']
-                    #name = m.get('name', 'user')
+                    # name = m.get('name', 'user')
                     if role == 'system':
                         system_msg += content
                     else:
-                        chat_msgs.extend([f"\n{role}: {content.strip()}"]) ### Strip content? linefeed?
+                        chat_msgs.extend([f"\n{role}: {content.strip()}"]) # Strip content? linefeed?
 
                 system_token_count = len(encode(system_msg)[0])
                 remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
                 chat_msg = ''
 
                 while chat_msgs:
                     new_msg = chat_msgs.pop()
                     new_size = len(encode(new_msg)[0])
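
The packing loop above works backwards from the newest message: it pops entries off the end of chat_msgs and measures their token size against remaining_tokens. Only the pop() and size computation are visible in this hunk, so the budget check inside the loop below is an assumption; the encode() stand-in here just splits on whitespace, whereas the extension uses the model's real tokenizer:

    def encode(text):
        # stand-in tokenizer for illustration only
        return [text.split()]

    remaining_tokens = 8   # hypothetical budget left after the system prompt and max_new_tokens
    chat_msgs = ["\nuser: hello there", "\nassistant: hi", "\nuser: summarize our chat please"]

    chat_msg = ''
    while chat_msgs:
        new_msg = chat_msgs.pop()              # newest message first
        new_size = len(encode(new_msg)[0])
        if new_size <= remaining_tokens:       # assumed check, not shown in the hunk
            chat_msg = new_msg + chat_msg      # prepend to keep chronological order
            remaining_tokens -= new_size
        else:
            break

    print(chat_msg)   # the oldest messages are dropped once the budget runs out
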
@@ -229,7 +235,7 @@ class Handler(BaseHTTPRequestHandler):
                     print(f"truncating chat messages, dropping {len(chat_msgs)} messages.")
 
                 if system_msg:
                     prompt = 'system: ' + system_msg + '\n' + chat_msg + '\nassistant: '
                 else:
                     prompt = chat_msg + '\nassistant: '
@@ -245,16 +251,16 @@ class Handler(BaseHTTPRequestHandler):
                 # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
                 if is_legacy:
                     prompt = body['context'] # Older engines.generate API
                 else:
                     prompt = body['prompt'] # XXX this can be different types
                     if isinstance(prompt, list):
                         prompt = ''.join(prompt) # XXX this is wrong... need to split out to multiple calls?
 
                 token_count = len(encode(prompt)[0])
                 if token_count >= req_params['truncation_length']:
                     new_len = int(len(prompt) * (float(shared.settings['truncation_length']) - req_params['max_new_tokens']) / token_count)
                     prompt = prompt[-new_len:]
                     print(f"truncating prompt to {new_len} characters, was {token_count} tokens. Now: {len(encode(prompt)[0])} tokens.")
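
The character-based truncation above approximates "keep the most recent N tokens" without re-tokenizing repeatedly: it scales the prompt's character length by the ratio of the allowed token budget to the measured token count and keeps the tail. A small worked example with made-up numbers:

    # Hypothetical numbers, just to show the arithmetic used above.
    prompt_chars = 12000        # len(prompt)
    token_count = 3000          # len(encode(prompt)[0])
    truncation_length = 2048.0  # float(shared.settings['truncation_length'])
    max_new_tokens = 200        # req_params['max_new_tokens']

    new_len = int(prompt_chars * (truncation_length - max_new_tokens) / token_count)
    print(new_len)  # 7392 characters kept, taken from the end: prompt[-new_len:]
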
@@ -262,7 +268,6 @@ class Handler(BaseHTTPRequestHandler):
             # some strange cases of "##| Instruction: " sneaking through.
             stopping_strings += standard_stopping_strings
             req_params['custom_stopping_strings'] = stopping_strings
-
 
             shared.args.no_stream = not req_params['stream']
             if not shared.args.no_stream:
@@ -283,22 +288,23 @@ class Handler(BaseHTTPRequestHandler):
                     chunk[resp_list][0]["text"] = ""
                 else:
                     # This is coming back as "system" to the openapi cli, not sure why.
                     # So yeah... do both methods? delta and messages.
                     chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                     chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
-                    #{ "role": "assistant" }
+                    # { "role": "assistant" }
                 response = 'data: ' + json.dumps(chunk) + '\n'
                 self.wfile.write(response.encode('utf-8'))
 
             # generate reply #######################################
-            if debug: print ({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
+            if debug:
+                print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
 
             generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings)
 
             answer = ''
             seen_content = ''
-            longest_stop_len = max([ len(x) for x in stopping_strings ])
+            longest_stop_len = max([len(x) for x in stopping_strings])
             for a in generator:
                 if isinstance(a, str):
                     answer = a
@@ -312,7 +318,7 @@ class Handler(BaseHTTPRequestHandler):
                 for string in stopping_strings:
                     idx = answer.find(string, search_start)
                     if idx != -1:
                         answer = answer[:idx] # clip it.
                         stop_string_found = True
 
                 if stop_string_found:
@@ -338,9 +344,9 @@ class Handler(BaseHTTPRequestHandler):
                     # Streaming
                     new_content = answer[len_seen:]
 
                     if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
                         continue
 
                     seen_content = answer
                     chunk = {
                         "id": cmpl_id,
@@ -355,9 +361,9 @@ class Handler(BaseHTTPRequestHandler):
                     if stream_object_type == 'text_completion.chunk':
                         chunk[resp_list][0]['text'] = new_content
                     else:
                         # So yeah... do both methods? delta and messages.
-                        chunk[resp_list][0]['message'] = { 'content': new_content }
-                        chunk[resp_list][0]['delta'] = { 'content': new_content }
+                        chunk[resp_list][0]['message'] = {'content': new_content}
+                        chunk[resp_list][0]['delta'] = {'content': new_content}
                     response = 'data: ' + json.dumps(chunk) + '\n'
                     self.wfile.write(response.encode('utf-8'))
                     completion_token_count += len(encode(new_content)[0])
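
For reference, the streaming branch above writes Server-Sent-Events-style lines: one JSON chunk per line, prefixed with "data: ". A sketch of what a single chat chunk might look like on the wire; the field values and the chat object-type string are illustrative, only the overall shape follows the code above:

    import json
    import time

    created_time = int(time.time())
    chunk = {
        "id": "conv-%d" % created_time,
        "object": "chat.completions.chunk",   # assumed for the chat case; the text branch uses 'text_completion.chunk'
        "created": created_time,
        "model": "some-local-model",          # placeholder for shared.model_name
        "choices": [{
            "index": 0,
            "finish_reason": None,
            "message": {"content": "Hello"},  # sent both ways, as in the code above
            "delta": {"content": "Hello"},
        }],
    }

    response = 'data: ' + json.dumps(chunk) + '\n'
    print(response, end='')
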
@@ -367,7 +373,7 @@ class Handler(BaseHTTPRequestHandler):
                     "id": cmpl_id,
                     "object": stream_object_type,
                     "created": created_time,
                     "model": model, # TODO: add Lora info?
                     resp_list: [{
                         "index": 0,
                         "finish_reason": "stop",
@@ -381,16 +387,18 @@ class Handler(BaseHTTPRequestHandler):
                 if stream_object_type == 'text_completion.chunk':
                     chunk[resp_list][0]['text'] = ''
                 else:
                     # So yeah... do both methods? delta and messages.
-                    chunk[resp_list][0]['message'] = {'content': '' }
+                    chunk[resp_list][0]['message'] = {'content': ''}
                     chunk[resp_list][0]['delta'] = {}
 
                 response = 'data: ' + json.dumps(chunk) + '\ndata: [DONE]\n'
                 self.wfile.write(response.encode('utf-8'))
-                ###### Finished if streaming.
-                if debug: print({'response': answer})
+                # Finished if streaming.
+                if debug:
+                    print({'response': answer})
                 return
 
-            if debug: print({'response': answer})
+            if debug:
+                print({'response': answer})
 
             completion_token_count = len(encode(answer)[0])
             stop_reason = "stop"
@@ -401,7 +409,7 @@ class Handler(BaseHTTPRequestHandler):
                 "id": cmpl_id,
                 "object": object_type,
                 "created": created_time,
                 "model": model, # TODO: add Lora info?
                 resp_list: [{
                     "index": 0,
                     "finish_reason": stop_reason,
@@ -414,13 +422,13 @@ class Handler(BaseHTTPRequestHandler):
             }
 
             if is_chat:
-                resp[resp_list][0]["message"] = {"role": "assistant", "content": answer }
+                resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
             else:
                 resp[resp_list][0]["text"] = answer
 
             response = json.dumps(resp)
             self.wfile.write(response.encode('utf-8'))
-        elif '/embeddings' in self.path and embedding_model != None:
+        elif '/embeddings' in self.path and embedding_model is not None:
             self.send_response(200)
             self.send_header('Content-Type', 'application/json')
             self.end_headers()
@@ -431,19 +439,20 @@ class Handler(BaseHTTPRequestHandler):
             embeddings = embedding_model.encode(input).tolist()
 
-            data = [ {"object": "embedding", "embedding": emb, "index": n } for n, emb in enumerate(embeddings) ]
+            data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]
 
             response = json.dumps({
                 "object": "list",
                 "data": data,
                 "model": st_model, # return the real model
                 "usage": {
                     "prompt_tokens": 0,
                     "total_tokens": 0,
                 }
            })
 
-            if debug: print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
+            if debug:
+                print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
 
             self.wfile.write(response.encode('utf-8'))
         elif '/moderations' in self.path:
             # for now do nothing, just don't error.
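
The /v1/embeddings branch above returns an OpenAI-style list object. A standalone sketch of the payload it builds, with fake 4-dimensional vectors standing in for the sentence-transformers output and a placeholder model name:

    import json

    # Fake embeddings standing in for embedding_model.encode(input).tolist()
    embeddings = [[0.01, -0.02, 0.03, 0.04], [0.05, 0.06, -0.07, 0.08]]

    data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]

    response = json.dumps({
        "object": "list",
        "data": data,
        "model": "sentence-transformers-model",  # placeholder; the handler returns the real st_model name
        "usage": {
            "prompt_tokens": 0,
            "total_tokens": 0,
        }
    })

    print(response)
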
@@ -521,4 +530,3 @@ def run_server():
 
 def setup():
     Thread(target=run_server, daemon=True).start()
-
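
Since the commit only touches style, the API surface is unchanged. As a hedged usage sketch, here is how a client might call the completions endpoint this handler serves; the host and port are assumptions, not taken from the diff, and only the request/response keys shown above are relied on:

    import json
    import urllib.request

    # Assumed local address -- adjust to wherever the extension's server is listening.
    url = "http://127.0.0.1:5001/v1/completions"

    payload = {
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        "temperature": 1.0,
        "stream": False,
    }

    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )

    with urllib.request.urlopen(req) as resp:
        body = json.loads(resp.read().decode("utf-8"))
        print(body["choices"][0]["text"])
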