import asyncio
import json
import ssl
from threading import Thread
from typing import Optional

from websockets.server import serve

from extensions.api.util import (
    build_parameters,
    try_start_cloudflared,
    with_api_lock
)
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
from modules.logging_colors import logger

PATH = '/api/v1/stream'


async def _send_text_stream(websocket, message_num: int, text: str):
    """Send one incremental ``text_stream`` event of the completion stream."""
    await websocket.send(json.dumps({
        'event': 'text_stream',
        'message_num': message_num,
        'text': text
    }))


@with_api_lock
async def _handle_stream_message(websocket, message):
    """Stream a plain-text completion for one request over ``websocket``.

    ``message`` is a JSON string containing at least a ``prompt`` key; the
    whole parsed object is turned into generation parameters by
    ``build_parameters``. Newly generated text is sent as ``text_stream``
    events that carry only the text added since the previous event, and the
    stream is terminated by a ``stream_end`` event.
    """
    body = json.loads(message)
    prompt = body['prompt']
    generate_params = build_parameters(body)
    stopping_strings = generate_params.pop('stopping_strings')
    generate_params['stream'] = True

    generator = generate_reply(
        prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

    # The code treats each yielded value as the full reply so far: only the
    # part past skip_index is new and gets sent.
    skip_index = 0
    message_num = 0
    reply = ''
    for reply in generator:
        to_send = reply[skip_index:]
        if not to_send or '\ufffd' in to_send:
            # Nothing new yet, or a partial unicode character (shown as
            # U+FFFD) that later tokens should complete: hold it back.
            continue

        await _send_text_stream(websocket, message_num, to_send)
        # The generator is iterated synchronously and blocks the event loop
        # while it runs; yield here so the send above can be flushed.
        await asyncio.sleep(0)
        skip_index += len(to_send)
        message_num += 1

    # Generation is finished, so held-back text can no longer be completed
    # (e.g. a U+FFFD that is genuinely part of the output). Send it rather
    # than silently dropping the end of the reply.
    remaining = reply[skip_index:]
    if remaining:
        await _send_text_stream(websocket, message_num, remaining)
        message_num += 1

    await websocket.send(json.dumps({
        'event': 'stream_end',
        'message_num': message_num
    }))


@with_api_lock
async def _handle_chat_stream_message(websocket, message):
    """Stream a chat reply for one request over ``websocket``.

    ``message`` is a JSON string with a ``user_input`` key and optional
    ``regenerate`` / ``_continue`` flags (both default to False). Every value
    yielded by ``generate_chat_reply`` is sent whole as the ``history`` field
    of a ``text_stream`` event; a final ``stream_end`` event ends the stream.
    """
    body = json.loads(message)

    user_input = body['user_input']
    generate_params = build_parameters(body, chat=True)
    generate_params['stream'] = True
    regenerate = body.get('regenerate', False)
    _continue = body.get('_continue', False)

    generator = generate_chat_reply(
        user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

    message_num = 0
    for history in generator:
        await websocket.send(json.dumps({
            'event': 'text_stream',
            'message_num': message_num,
            'history': history
        }))
        # Let the event loop flush the send between synchronous generator steps.
        await asyncio.sleep(0)
        message_num += 1

    await websocket.send(json.dumps({
        'event': 'stream_end',
        'message_num': message_num
    }))


# Request path -> handler invoked for every message received on that path.
_HANDLERS = {
    PATH: _handle_stream_message,
    '/api/v1/chat-stream': _handle_chat_stream_message,
}


async def _handle_connection(websocket, path):
    """Dispatch each message of a websocket connection by request ``path``.

    Connections to an unknown path are logged and the handler returns
    without reading any messages.
    """
    handler = _HANDLERS.get(path)
    if handler is None:
        logger.warning(f'Streaming api: unknown path: {path}')
        return

    async for message in websocket:
        await handler(websocket, message)


def _ssl_files():
    """Return ``(certfile, keyfile)`` when both are configured, else None."""
    certfile = shared.args.ssl_certfile
    keyfile = shared.args.ssl_keyfile
    if certfile and keyfile:
        return certfile, keyfile
    return None


async def _run(host: str, port: int):
    """Serve the streaming API on ``host``:``port`` until the task is cancelled.

    TLS is enabled only when both an SSL cert file and key file are set in
    ``shared.args``.
    """
    context = None
    ssl_files = _ssl_files()
    if ssl_files:
        certfile, keyfile = ssl_files
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile, keyfile)

    async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
        await asyncio.Future()  # Run the server forever


def _run_server(port: int, share: bool = False, tunnel_id: Optional[str] = None):
    """Log the streaming API URL, then run the websocket server (blocking).

    Listens on all interfaces when ``shared.args.listen`` is set, otherwise
    on localhost only. When ``share`` is true, ``try_start_cloudflared`` is
    asked for a public tunnel and the resulting ``wss://`` URL is logged from
    its ``on_start`` callback; otherwise the local URL is logged. The local
    server is started in both cases.

    ``tunnel_id`` is passed through to ``try_start_cloudflared`` unchanged.
    NOTE(review): None is assumed to mean "no named tunnel"; confirm against
    try_start_cloudflared.
    """
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'

    def on_start(public_url: str):
        public_url = public_url.replace('https://', 'wss://')
        logger.info(f'Streaming API URL: \n\n{public_url}{PATH}\n')

    if share:
        try:
            try_start_cloudflared(port, tunnel_id, max_attempts=3, on_start=on_start)
        except Exception as e:
            # Best effort: a failed tunnel must not prevent the local server.
            logger.error(f'Streaming API: could not start cloudflared tunnel: {e}')
    else:
        scheme = 'wss' if _ssl_files() else 'ws'
        logger.info(f'Streaming API URL: \n\n{scheme}://{address}:{port}{PATH}\n')

    asyncio.run(_run(host=address, port=port))


def start_server(port: int, share: bool = False, tunnel_id: Optional[str] = None):
    """Start the streaming API server on a daemon background thread."""
    Thread(target=_run_server, args=[port, share, tunnel_id], daemon=True).start()