Style/pep8 improvements

oobabooga 2023-05-02 23:05:38 -03:00
parent ecd79caa68
commit 320fcfde4e

@@ -1,8 +1,9 @@
-import json, time, os
+import json
+import os
+import time
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 from modules import shared
 from modules.text_generation import encode, generate_reply
@@ -25,6 +26,8 @@ embedding_model = None
 standard_stopping_strings = ['\nsystem:', '\nuser:', '\nhuman:', '\nassistant:', '\n###', ]
+
+
 # little helper to get defaults if arg is present but None and should be the same type as default.
 def default(dic, key, default):
     val = dic.get(key, default)
     if type(val) != type(default):
@@ -39,6 +42,7 @@ def default(dic, key, default):
         val = default
     return val
 
+
 def clamp(value, minvalue, maxvalue):
     return max(minvalue, min(value, maxvalue))
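
For context, a standalone restatement of how these two helpers behave (illustration only; the elided middle of `default` in this hunk presumably does more than the fallback path visible here):

```python
# Sketch of the two helpers above; only the visible fallback path is kept.
def default(dic, key, default):
    # Fetch dic[key]; if it is missing, or present but of a different type
    # than the fallback (e.g. an explicit null), use the fallback instead.
    val = dic.get(key, default)
    if type(val) != type(default):
        val = default
    return val


def clamp(value, minvalue, maxvalue):
    # Constrain value to the inclusive range [minvalue, maxvalue].
    return max(minvalue, min(value, maxvalue))


body = {'temperature': None, 'top_p': 0.9}
print(default(body, 'temperature', 1.0))  # 1.0 -- None replaced by default
print(default(body, 'top_p', 1.0))        # 0.9 -- valid value kept
print(clamp(3.5, 0.0, 2.0))               # 2.0
```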
@@ -103,8 +107,10 @@ class Handler(BaseHTTPRequestHandler):
         content_length = int(self.headers['Content-Length'])
         body = json.loads(self.rfile.read(content_length).decode('utf-8'))
-        if debug: print(self.headers) # did you know... python-openai sends your linux kernel & python version?
-        if debug: print(body)
+        if debug:
+            print(self.headers) # did you know... python-openai sends your linux kernel & python version?
+        if debug:
+            print(body)
 
         if '/completions' in self.path or '/generate' in self.path:
             is_legacy = '/generate' in self.path
@@ -112,7 +118,7 @@ class Handler(BaseHTTPRequestHandler):
             resp_list = 'data' if is_legacy else 'choices'
 
             # XXX model is ignored for now
-            #model = body.get('model', shared.model_name) # ignored, use existing for now
+            # model = body.get('model', shared.model_name) # ignored, use existing for now
             model = shared.model_name
             created_time = int(time.time())
             cmpl_id = "conv-%d" % (created_time)
@@ -143,17 +149,17 @@ class Handler(BaseHTTPRequestHandler):
                 'temperature': default(body, 'temperature', 1.0),
                 'top_p': default(body, 'top_p', 1.0),
                 'top_k': default(body, 'best_of', 1),
-                ### XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
+                # XXX not sure about this one, seems to be the right mapping, but the range is different (-2..2.0) vs 0..2
                 # 0 is default in openai, but 1.0 is default in other places. Maybe it's scaled? scale it.
                 'repetition_penalty': 1.18, # (default(body, 'presence_penalty', 0) + 2.0 ) / 2.0, # 0 the real default, 1.2 is the model default, but 1.18 works better.
-                ### XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
-                'encoder_repetition_penalty': 1.0, #(default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
+                # XXX not sure about this one either, same questions. (-2..2.0), 0 is default not 1.0, scale it.
+                'encoder_repetition_penalty': 1.0, # (default(body, 'frequency_penalty', 0) + 2.0) / 2.0,
                 'suffix': body.get('suffix', None),
                 'stream': default(body, 'stream', False),
                 'echo': default(body, 'echo', False),
                 #####################################################
                 'seed': shared.settings.get('seed', -1),
-                #int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
+                # int(body.get('n', 1)) # perhaps this should be num_beams or chat_generation_attempts? 'n' doesn't have a direct map
                 # unofficial, but it needs to get set anyways.
                 'truncation_length': truncation_length,
                 # no more args.
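
The scaling idea sketched in the comments above maps OpenAI's additive penalty range (-2.0..2.0, neutral at 0) onto the multiplicative repetition-penalty range (neutral at 1.0). A worked restatement of that commented-out formula (illustrative only; the committed code pins 1.18 instead of applying it):

```python
def scale_openai_penalty(penalty: float) -> float:
    # OpenAI's presence/frequency penalties are additive in [-2.0, 2.0]
    # with 0.0 as the neutral default; repetition_penalty is multiplicative
    # with 1.0 as neutral. The commented-out mapping shifts and halves:
    return (penalty + 2.0) / 2.0


print(scale_openai_penalty(0.0))   # 1.0 -- neutral maps to neutral
print(scale_openai_penalty(-2.0))  # 0.0 -- bottom of the range
print(scale_openai_penalty(2.0))   # 2.0 -- top of the range
```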
@@ -178,7 +184,7 @@ class Handler(BaseHTTPRequestHandler):
             if req_params['stream']:
                 self.send_header('Content-Type', 'text/event-stream')
                 self.send_header('Cache-Control', 'no-cache')
-                #self.send_header('Connection', 'keep-alive')
+                # self.send_header('Connection', 'keep-alive')
             else:
                 self.send_header('Content-Type', 'application/json')
             self.end_headers()
@@ -204,11 +210,11 @@ class Handler(BaseHTTPRequestHandler):
                 for m in messages:
                     role = m['role']
                     content = m['content']
-                    #name = m.get('name', 'user')
+                    # name = m.get('name', 'user')
                     if role == 'system':
                         system_msg += content
                     else:
-                        chat_msgs.extend([f"\n{role}: {content.strip()}"]) ### Strip content? linefeed?
+                        chat_msgs.extend([f"\n{role}: {content.strip()}"]) # Strip content? linefeed?
 
                 system_token_count = len(encode(system_msg)[0])
                 remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
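
The budget line above is plain subtraction; with made-up numbers (a 2048-token context window, 200 tokens reserved for the reply, a 50-token system message):

```python
# Hypothetical numbers, purely to illustrate the arithmetic above.
truncation_length = 2048  # total context window
max_new_tokens = 200      # reserved for the generated reply
system_token_count = 50   # tokens used by the system message

remaining_tokens = truncation_length - max_new_tokens - system_token_count
print(remaining_tokens)   # 1798 tokens left for the chat history
```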
@@ -263,7 +269,6 @@ class Handler(BaseHTTPRequestHandler):
             stopping_strings += standard_stopping_strings
             req_params['custom_stopping_strings'] = stopping_strings
 
             shared.args.no_stream = not req_params['stream']
-
             if not shared.args.no_stream:
                 shared.args.chat = True
@@ -286,18 +291,19 @@ class Handler(BaseHTTPRequestHandler):
                     # So yeah... do both methods? delta and messages.
                     chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
                     chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
-                    #{ "role": "assistant" }
+                    # { "role": "assistant" }
 
                 response = 'data: ' + json.dumps(chunk) + '\n'
                 self.wfile.write(response.encode('utf-8'))
 
             # generate reply #######################################
-            if debug: print ({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
+            if debug:
+                print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
 
             generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings)
 
             answer = ''
             seen_content = ''
-            longest_stop_len = max([ len(x) for x in stopping_strings ])
+            longest_stop_len = max([len(x) for x in stopping_strings])
 
             for a in generator:
                 if isinstance(a, str):
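
The `seen_content` / `longest_stop_len` pair suggests this loop emits only the delta of each partial reply while holding back enough of the tail that a stop string split across two updates can still be caught. A rough sketch of that pattern, assuming `generate_reply` yields the full reply-so-far on each iteration (an inference, not the literal loop body):

```python
def stream_with_stops(generator, stopping_strings):
    # Assumption: each `answer` from the generator is the full reply so
    # far, extending the previous one (as generate_reply appears to do).
    longest_stop_len = max(len(s) for s in stopping_strings)
    seen_content = ''
    answer = ''
    for answer in generator:
        hits = [answer.find(s) for s in stopping_strings if s in answer]
        if hits:
            # A stop string appeared: emit up to it and finish.
            yield answer[len(seen_content):min(hits)]
            return
        # Hold back up to longest_stop_len characters, in case a stop
        # string is only partially generated and completes next update.
        safe_end = max(len(seen_content), len(answer) - longest_stop_len)
        yield answer[len(seen_content):safe_end]
        seen_content = answer[:safe_end]
    # Generation ended with no stop string: flush the held-back tail.
    yield answer[len(seen_content):]
```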
@@ -356,8 +362,8 @@ class Handler(BaseHTTPRequestHandler):
                        chunk[resp_list][0]['text'] = new_content
                    else:
                        # So yeah... do both methods? delta and messages.
-                       chunk[resp_list][0]['message'] = { 'content': new_content }
-                       chunk[resp_list][0]['delta'] = { 'content': new_content }
+                       chunk[resp_list][0]['message'] = {'content': new_content}
+                       chunk[resp_list][0]['delta'] = {'content': new_content}
                    response = 'data: ' + json.dumps(chunk) + '\n'
                    self.wfile.write(response.encode('utf-8'))
                    completion_token_count += len(encode(new_content)[0])
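
On the wire, each streamed update is a single SSE `data:` line carrying the JSON chunk. An illustrative payload (field values are placeholders, and fields beyond `message`/`delta` are assumptions about the elided parts of the file; `conv-<timestamp>` follows the `cmpl_id` scheme above):

```python
import json

# Placeholder chunk, invented for illustration only.
chunk = {
    "id": "conv-1683072338",
    "created": 1683072338,
    "choices": [{
        "index": 0,
        "finish_reason": None,
        # The diff emits both styles ("do both methods? delta and messages"):
        "message": {"content": "Hello"},
        "delta": {"content": "Hello"},
    }],
}
print('data: ' + json.dumps(chunk) + '\n')
# -> data: {"id": "conv-1683072338", "created": 1683072338, ...}
```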
@@ -382,15 +388,17 @@ class Handler(BaseHTTPRequestHandler):
                    chunk[resp_list][0]['text'] = ''
                else:
                    # So yeah... do both methods? delta and messages.
-                   chunk[resp_list][0]['message'] = {'content': '' }
+                   chunk[resp_list][0]['message'] = {'content': ''}
                    chunk[resp_list][0]['delta'] = {}
                response = 'data: ' + json.dumps(chunk) + '\ndata: [DONE]\n'
                self.wfile.write(response.encode('utf-8'))
-                ###### Finished if streaming.
-                if debug: print({'response': answer})
+                # Finished if streaming.
+                if debug:
+                    print({'response': answer})
                return
 
-            if debug: print({'response': answer})
+            if debug:
+                print({'response': answer})
 
             completion_token_count = len(encode(answer)[0])
             stop_reason = "stop"
@@ -414,13 +422,13 @@ class Handler(BaseHTTPRequestHandler):
             }
 
             if is_chat:
-                resp[resp_list][0]["message"] = {"role": "assistant", "content": answer }
+                resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
             else:
                 resp[resp_list][0]["text"] = answer
 
             response = json.dumps(resp)
             self.wfile.write(response.encode('utf-8'))
-        elif '/embeddings' in self.path and embedding_model != None:
+        elif '/embeddings' in self.path and embedding_model is not None:
             self.send_response(200)
             self.send_header('Content-Type', 'application/json')
             self.end_headers()
@@ -431,7 +439,7 @@ class Handler(BaseHTTPRequestHandler):
             embeddings = embedding_model.encode(input).tolist()
-            data = [ {"object": "embedding", "embedding": emb, "index": n } for n, emb in enumerate(embeddings) ]
+            data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]
 
             response = json.dumps({
                 "object": "list",
@@ -443,7 +451,8 @@ class Handler(BaseHTTPRequestHandler):
                 }
             })
 
-            if debug: print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
+            if debug:
+                print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
             self.wfile.write(response.encode('utf-8'))
         elif '/moderations' in self.path:
             # for now do nothing, just don't error.
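
For reference, the response produced by the /embeddings branch has roughly this shape (the `"data"` nesting and sibling keys outside the visible lines are assumptions based on the OpenAI format):

```python
import json

embeddings = [[0.01, -0.02, 0.03]]  # stand-in for embedding_model.encode(...).tolist()

data = [{"object": "embedding", "embedding": emb, "index": n}
        for n, emb in enumerate(embeddings)]

response = json.dumps({
    "object": "list",
    "data": data,  # one entry per input string, index preserving order
})
print(response)
# {"object": "list", "data": [{"object": "embedding", "embedding": [...], "index": 0}]}
```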
@@ -521,4 +530,3 @@ def run_server():
 
 def setup():
     Thread(target=run_server, daemon=True).start()
-