text-generation-webui/extensions/api/script.py

103 lines
3.6 KiB
Python
Raw Normal View History

2023-03-24 20:53:56 +01:00
import json
2023-03-15 21:52:46 +01:00
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
2023-03-24 20:53:56 +01:00
2023-03-15 21:52:46 +01:00
from modules import shared
2023-03-24 20:53:56 +01:00
from modules.text_generation import encode, generate_reply
2023-03-15 21:52:46 +01:00
# Extension settings. 'port' is the TCP port that run_server() binds the API to.
params = dict(port=5000)
2023-03-15 21:52:46 +01:00
class Handler(BaseHTTPRequestHandler):
    """KoboldAI-compatible HTTP request handler.

    Routes:
        GET  /api/v1/model     -> {"result": <shared.model_name>}
        POST /api/v1/generate  -> {"results": [{"text": <generated continuation>}]}
    Any other path is answered with 404. Malformed generate requests
    (missing/invalid JSON body, missing or non-string "prompt", non-numeric
    generation parameters) are answered with 400 before any headers for a
    successful response are sent.
    """

    def _send_json(self, status, payload):
        """Serialize *payload* as JSON and send it with HTTP *status*."""
        data = json.dumps(payload).encode('utf-8')
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def do_GET(self):
        """Report the currently loaded model name, or 404 for unknown paths."""
        if self.path == '/api/v1/model':
            self._send_json(200, {'result': shared.model_name})
        else:
            self.send_error(404)

    def _read_json_body(self):
        """Read and decode the JSON request body.

        Returns:
            The decoded JSON value, or None when the body is absent, the
            Content-Length header is missing/malformed, or the bytes are not
            valid UTF-8 JSON.
        """
        try:
            length = int(self.headers.get('Content-Length'))
        except (TypeError, ValueError):
            return None
        if length <= 0:
            return None
        try:
            return json.loads(self.rfile.read(length).decode('utf-8'))
        except (UnicodeDecodeError, json.JSONDecodeError):
            return None

    @staticmethod
    def _build_generate_params(body):
        """Map the request fields onto the parameter dict passed to generate_reply.

        Raises:
            TypeError, ValueError: if a numeric field cannot be converted.
        """
        return {
            'max_new_tokens': int(body.get('max_length', 200)),
            'do_sample': bool(body.get('do_sample', True)),
            'temperature': float(body.get('temperature', 0.5)),
            'top_p': float(body.get('top_p', 1)),
            'typical_p': float(body.get('typical', 1)),
            'repetition_penalty': float(body.get('rep_pen', 1.1)),
            'encoder_repetition_penalty': 1,
            'top_k': int(body.get('top_k', 0)),
            'min_length': int(body.get('min_length', 0)),
            'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
            'num_beams': int(body.get('num_beams', 1)),
            'penalty_alpha': float(body.get('penalty_alpha', 0)),
            'length_penalty': float(body.get('length_penalty', 1)),
            'early_stopping': bool(body.get('early_stopping', False)),
            'seed': int(body.get('seed', -1)),
            'add_bos_token': bool(body.get('add_bos_token', True)),
            'custom_stopping_strings': body.get('custom_stopping_strings', []),
        }

    @staticmethod
    def _truncate_prompt(prompt, max_context):
        """Strip every prompt line and drop the oldest lines until it fits.

        Lines are removed from the front while len(encode(...)) of the
        remaining text exceeds *max_context*. If no suffix fits, the result
        is an empty prompt instead of an error.
        """
        lines = [line.strip() for line in prompt.split('\n')]
        start = 0
        # NOTE(review): len(encode(...)) is used as the token count; if encode
        # returns a batched (1, n) tensor this measures the batch, not the
        # tokens — confirm against modules.text_generation.encode.
        while start < len(lines) and len(encode('\n'.join(lines[start:]))) > max_context:
            start += 1
        return '\n'.join(lines[start:])

    def do_POST(self):
        """Generate a continuation for the posted prompt."""
        if self.path != '/api/v1/generate':
            self.send_error(404)
            return

        body = self._read_json_body()
        if not isinstance(body, dict) or not isinstance(body.get('prompt'), str):
            self.send_error(400, 'Request body must be a JSON object with a string prompt')
            return

        # Validate every numeric field before doing any (expensive) work, so a
        # bad request gets a clean 400 instead of a half-sent 200 response.
        try:
            generate_params = self._build_generate_params(body)
            max_context = int(body.get('max_context_length', 2048))
        except (TypeError, ValueError):
            self.send_error(400, 'Invalid generation parameter')
            return

        prompt = self._truncate_prompt(body['prompt'], max_context)

        # generate_reply yields progressively longer outputs, either as plain
        # strings or as sequences whose first item is the text; keep the last.
        answer = ''
        for chunk in generate_reply(prompt, generate_params):
            answer = chunk if isinstance(chunk, str) else chunk[0]

        # The output is assumed to start with the prompt (hence the slice) —
        # NOTE(review): confirm generate_reply always echoes the prompt.
        self._send_json(200, {'results': [{'text': answer[len(prompt):]}]})
def run_server():
    """Start the KoboldAI-compatible API server and serve requests forever.

    Binds to 0.0.0.0 when ``shared.args.listen`` is set, otherwise to
    127.0.0.1, on ``params['port']``. When ``shared.args.share`` is set, it
    also tries to open a public tunnel via ``flask_cloudflared``. If that
    package is missing, it falls back to announcing the local address. This
    call blocks in ``serve_forever()``.
    """
    host = '0.0.0.0' if shared.args.listen else '127.0.0.1'
    server_addr = (host, params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)

    public_url = None
    if shared.args.share:
        try:
            from flask_cloudflared import _run_cloudflared
        except ImportError:
            print('You should install flask_cloudflared manually')
        else:
            # NOTE(review): _run_cloudflared is a private flask_cloudflared
            # helper; the second argument (port + 1) is passed through as in
            # the original — confirm its meaning for the installed version.
            public_url = _run_cloudflared(params['port'], params['port'] + 1)

    if public_url is not None:
        print(f'Starting KoboldAI compatible api at {public_url}/api')
    else:
        # Also reached when sharing was requested but the tunnel package is
        # missing, so the user still learns where the server is listening.
        print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')

    server.serve_forever()
2023-03-19 14:22:24 +01:00
def setup():
    """Launch run_server() on a background daemon thread and return immediately.

    run_server() blocks in serve_forever(), so it must not run on the
    caller's thread. Because the thread is a daemon, it does not keep the
    interpreter alive at exit.
    """
    api_thread = Thread(target=run_server, daemon=True)
    api_thread.start()