From ec17a5d2b7f0fb3b4cc75f84c71393fe8cb9bd14 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 6 Nov 2023 02:38:29 -0300 Subject: [PATCH] Make OpenAI API the default API (#4430) --- README.md | 6 +- api-examples/api-example-chat-stream.py | 114 --- api-examples/api-example-chat.py | 94 --- api-examples/api-example-stream.py | 88 --- api-examples/api-example.py | 65 -- .../README.md => docs/12 - OpenAI API.md | 204 ++--- extensions/api/script.py | 2 + extensions/openai/completions.py | 744 ++++++++---------- extensions/openai/defaults.py | 78 -- extensions/openai/edits.py | 101 --- extensions/openai/embeddings.py | 12 +- extensions/openai/images.py | 1 + extensions/openai/requirements.txt | 3 +- extensions/openai/script.py | 454 +++++------ extensions/openai/typing.py | 125 +++ extensions/openai/utils.py | 36 +- modules/chat.py | 7 +- modules/models.py | 4 +- modules/presets.py | 23 +- modules/shared.py | 32 +- modules/text_generation.py | 5 +- modules/ui_model_menu.py | 3 - 22 files changed, 769 insertions(+), 1432 deletions(-) delete mode 100644 api-examples/api-example-chat-stream.py delete mode 100644 api-examples/api-example-chat.py delete mode 100644 api-examples/api-example-stream.py delete mode 100644 api-examples/api-example.py rename extensions/openai/README.md => docs/12 - OpenAI API.md (56%) delete mode 100644 extensions/openai/defaults.py delete mode 100644 extensions/openai/edits.py create mode 100644 extensions/openai/typing.py diff --git a/README.md b/README.md index 79d86a15..f8a691a0 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. * [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/03-%E2%80%90-Parameters-Tab#character) * Very efficient text streaming * Markdown output with LaTeX rendering, to use for instance with [GALACTICA](https://github.com/paperswithcode/galai) -* API, including endpoints for websocket streaming ([see the examples](https://github.com/oobabooga/text-generation-webui/blob/main/api-examples)) +* OpenAI-compatible API server ## Documentation @@ -412,8 +412,8 @@ Optionally, you can use the following command-line flags: | `--api` | Enable the API extension. | | `--public-api` | Create a public URL for the API using Cloudfare. | | `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. | -| `--api-blocking-port BLOCKING_PORT` | The listening port for the blocking API. | -| `--api-streaming-port STREAMING_PORT` | The listening port for the streaming API. | +| `--api-port API_PORT` | The listening port for the API. | +| `--api-key API_KEY` | API authentication key. | #### Multimodal diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py deleted file mode 100644 index 3a1502dd..00000000 --- a/api-examples/api-example-chat-stream.py +++ /dev/null @@ -1,114 +0,0 @@ -import asyncio -import html -import json -import sys - -try: - import websockets -except ImportError: - print("Websockets package not found. 
Make sure it's installed.") - -# For local streaming, the websockets are hosted without ssl - ws:// -HOST = 'localhost:5005' -URI = f'ws://{HOST}/api/v1/chat-stream' - -# For reverse-proxied streaming, the remote will likely host with ssl - wss:// -# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' - - -async def run(user_input, history): - # Note: the selected defaults change from time to time. - request = { - 'user_input': user_input, - 'max_new_tokens': 250, - 'auto_max_new_tokens': False, - 'max_tokens_second': 0, - 'history': history, - 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' - 'character': 'Example', - 'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset - 'your_name': 'You', - # 'name1': 'name of user', # Optional - # 'name2': 'name of character', # Optional - # 'context': 'character context', # Optional - # 'greeting': 'greeting', # Optional - # 'name1_instruct': 'You', # Optional - # 'name2_instruct': 'Assistant', # Optional - # 'context_instruct': 'context_instruct', # Optional - # 'turn_template': 'turn_template', # Optional - 'regenerate': False, - '_continue': False, - 'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', - - # Generation params. If 'preset' is set to different than 'None', the values - # in presets/preset-name.yaml are used instead of the individual numbers. - 'preset': 'None', - 'do_sample': True, - 'temperature': 0.7, - 'top_p': 0.1, - 'typical_p': 1, - 'epsilon_cutoff': 0, # In units of 1e-4 - 'eta_cutoff': 0, # In units of 1e-4 - 'tfs': 1, - 'top_a': 0, - 'repetition_penalty': 1.18, - 'presence_penalty': 0, - 'frequency_penalty': 0, - 'repetition_penalty_range': 0, - 'top_k': 40, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'length_penalty': 1, - 'early_stopping': False, - 'mirostat_mode': 0, - 'mirostat_tau': 5, - 'mirostat_eta': 0.1, - 'grammar_string': '', - 'guidance_scale': 1, - 'negative_prompt': '', - - 'seed': -1, - 'add_bos_token': True, - 'truncation_length': 2048, - 'ban_eos_token': False, - 'custom_token_bans': '', - 'skip_special_tokens': True, - 'stopping_strings': [] - } - - async with websockets.connect(URI, ping_interval=None) as websocket: - await websocket.send(json.dumps(request)) - - while True: - incoming_data = await websocket.recv() - incoming_data = json.loads(incoming_data) - - match incoming_data['event']: - case 'text_stream': - yield incoming_data['history'] - case 'stream_end': - return - - -async def print_response_stream(user_input, history): - cur_len = 0 - async for new_history in run(user_input, history): - cur_message = new_history['visible'][-1][1][cur_len:] - cur_len += len(cur_message) - print(html.unescape(cur_message), end='') - sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. - - -if __name__ == '__main__': - user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard." - - # Basic example - history = {'internal': [], 'visible': []} - - # "Continue" example. 
Make sure to set '_continue' to True above - # arr = [user_input, 'Surely, here is'] - # history = {'internal': [arr], 'visible': [arr]} - - asyncio.run(print_response_stream(user_input, history)) diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py deleted file mode 100644 index 0f7a44aa..00000000 --- a/api-examples/api-example-chat.py +++ /dev/null @@ -1,94 +0,0 @@ -import html -import json - -import requests - -# For local streaming, the websockets are hosted without ssl - http:// -HOST = 'localhost:5000' -URI = f'http://{HOST}/api/v1/chat' - -# For reverse-proxied streaming, the remote will likely host with ssl - https:// -# URI = 'https://your-uri-here.trycloudflare.com/api/v1/chat' - - -def run(user_input, history): - request = { - 'user_input': user_input, - 'max_new_tokens': 250, - 'auto_max_new_tokens': False, - 'max_tokens_second': 0, - 'history': history, - 'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct' - 'character': 'Example', - 'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset - 'your_name': 'You', - # 'name1': 'name of user', # Optional - # 'name2': 'name of character', # Optional - # 'context': 'character context', # Optional - # 'greeting': 'greeting', # Optional - # 'name1_instruct': 'You', # Optional - # 'name2_instruct': 'Assistant', # Optional - # 'context_instruct': 'context_instruct', # Optional - # 'turn_template': 'turn_template', # Optional - 'regenerate': False, - '_continue': False, - 'chat_instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', - - # Generation params. If 'preset' is set to different than 'None', the values - # in presets/preset-name.yaml are used instead of the individual numbers. - 'preset': 'None', - 'do_sample': True, - 'temperature': 0.7, - 'top_p': 0.1, - 'typical_p': 1, - 'epsilon_cutoff': 0, # In units of 1e-4 - 'eta_cutoff': 0, # In units of 1e-4 - 'tfs': 1, - 'top_a': 0, - 'repetition_penalty': 1.18, - 'presence_penalty': 0, - 'frequency_penalty': 0, - 'repetition_penalty_range': 0, - 'top_k': 40, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'length_penalty': 1, - 'early_stopping': False, - 'mirostat_mode': 0, - 'mirostat_tau': 5, - 'mirostat_eta': 0.1, - 'grammar_string': '', - 'guidance_scale': 1, - 'negative_prompt': '', - - 'seed': -1, - 'add_bos_token': True, - 'truncation_length': 2048, - 'ban_eos_token': False, - 'custom_token_bans': '', - 'skip_special_tokens': True, - 'stopping_strings': [] - } - - response = requests.post(URI, json=request) - - if response.status_code == 200: - result = response.json()['results'][0]['history'] - print(json.dumps(result, indent=4)) - print() - print(html.unescape(result['visible'][-1][1])) - - -if __name__ == '__main__': - user_input = "Please give me a step-by-step guide on how to plant a tree in my backyard." - - # Basic example - history = {'internal': [], 'visible': []} - - # "Continue" example. Make sure to set '_continue' to True above - # arr = [user_input, 'Surely, here is'] - # history = {'internal': [arr], 'visible': [arr]} - - run(user_input, history) diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py deleted file mode 100644 index 4d5cb725..00000000 --- a/api-examples/api-example-stream.py +++ /dev/null @@ -1,88 +0,0 @@ -import asyncio -import json -import sys - -try: - import websockets -except ImportError: - print("Websockets package not found. 
Make sure it's installed.") - -# For local streaming, the websockets are hosted without ssl - ws:// -HOST = 'localhost:5005' -URI = f'ws://{HOST}/api/v1/stream' - -# For reverse-proxied streaming, the remote will likely host with ssl - wss:// -# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' - - -async def run(context): - # Note: the selected defaults change from time to time. - request = { - 'prompt': context, - 'max_new_tokens': 250, - 'auto_max_new_tokens': False, - 'max_tokens_second': 0, - - # Generation params. If 'preset' is set to different than 'None', the values - # in presets/preset-name.yaml are used instead of the individual numbers. - 'preset': 'None', - 'do_sample': True, - 'temperature': 0.7, - 'top_p': 0.1, - 'typical_p': 1, - 'epsilon_cutoff': 0, # In units of 1e-4 - 'eta_cutoff': 0, # In units of 1e-4 - 'tfs': 1, - 'top_a': 0, - 'repetition_penalty': 1.18, - 'presence_penalty': 0, - 'frequency_penalty': 0, - 'repetition_penalty_range': 0, - 'top_k': 40, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'length_penalty': 1, - 'early_stopping': False, - 'mirostat_mode': 0, - 'mirostat_tau': 5, - 'mirostat_eta': 0.1, - 'grammar_string': '', - 'guidance_scale': 1, - 'negative_prompt': '', - - 'seed': -1, - 'add_bos_token': True, - 'truncation_length': 2048, - 'ban_eos_token': False, - 'custom_token_bans': '', - 'skip_special_tokens': True, - 'stopping_strings': [] - } - - async with websockets.connect(URI, ping_interval=None) as websocket: - await websocket.send(json.dumps(request)) - - yield context # Remove this if you just want to see the reply - - while True: - incoming_data = await websocket.recv() - incoming_data = json.loads(incoming_data) - - match incoming_data['event']: - case 'text_stream': - yield incoming_data['text'] - case 'stream_end': - return - - -async def print_response_stream(prompt): - async for response in run(prompt): - print(response, end='') - sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. - - -if __name__ == '__main__': - prompt = "In order to make homemade bread, follow these steps:\n1)" - asyncio.run(print_response_stream(prompt)) diff --git a/api-examples/api-example.py b/api-examples/api-example.py deleted file mode 100644 index bdcfcea3..00000000 --- a/api-examples/api-example.py +++ /dev/null @@ -1,65 +0,0 @@ -import requests - -# For local streaming, the websockets are hosted without ssl - http:// -HOST = 'localhost:5000' -URI = f'http://{HOST}/api/v1/generate' - -# For reverse-proxied streaming, the remote will likely host with ssl - https:// -# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate' - - -def run(prompt): - request = { - 'prompt': prompt, - 'max_new_tokens': 250, - 'auto_max_new_tokens': False, - 'max_tokens_second': 0, - - # Generation params. If 'preset' is set to different than 'None', the values - # in presets/preset-name.yaml are used instead of the individual numbers. 
- 'preset': 'None', - 'do_sample': True, - 'temperature': 0.7, - 'top_p': 0.1, - 'typical_p': 1, - 'epsilon_cutoff': 0, # In units of 1e-4 - 'eta_cutoff': 0, # In units of 1e-4 - 'tfs': 1, - 'top_a': 0, - 'repetition_penalty': 1.18, - 'presence_penalty': 0, - 'frequency_penalty': 0, - 'repetition_penalty_range': 0, - 'top_k': 40, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'length_penalty': 1, - 'early_stopping': False, - 'mirostat_mode': 0, - 'mirostat_tau': 5, - 'mirostat_eta': 0.1, - 'grammar_string': '', - 'guidance_scale': 1, - 'negative_prompt': '', - - 'seed': -1, - 'add_bos_token': True, - 'truncation_length': 2048, - 'ban_eos_token': False, - 'custom_token_bans': '', - 'skip_special_tokens': True, - 'stopping_strings': [] - } - - response = requests.post(URI, json=request) - - if response.status_code == 200: - result = response.json()['results'][0]['text'] - print(prompt + result) - - -if __name__ == '__main__': - prompt = "In order to make homemade bread, follow these steps:\n1)" - run(prompt) diff --git a/extensions/openai/README.md b/docs/12 - OpenAI API.md similarity index 56% rename from extensions/openai/README.md rename to docs/12 - OpenAI API.md index 82026ac1..ec424a3c 100644 --- a/extensions/openai/README.md +++ b/docs/12 - OpenAI API.md @@ -1,124 +1,64 @@ -# An OpenedAI API (openai like) +## OpenAI compatible API -This extension creates an API that works kind of like openai (ie. api.openai.com). +This project includes an API compatible with multiple OpenAI endpoints, including Chat and Completions. -## Setup & installation - -Install the requirements: +If you did not use the one-click installers, you may need to install the requirements first: ``` -pip3 install -r requirements.txt +pip install -r extensions/openai/requirements.txt ``` -It listens on `tcp port 5001` by default. You can use the `OPENEDAI_PORT` environment variable to change this. +### Starting the API -Make sure you enable it in server launch parameters, it should include: +Add `--extensions openai` to your command-line flags. + +* To create a public Cloudflare URL, add the `--public-api` flag. +* To listen on your local network, add the `--listen` flag. +* To change the port, which is 5000 by default, use `--port 1234` (change 1234 to your desired port number). +* To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. Note that it doesn't work with `--public-api`. + +#### Environment variables + +The following environment variables can be used (they take precendence over everything else): + +| Variable Name | Description | Example Value | +|------------------------|------------------------------------|----------------------------| +| `OPENEDAI_PORT` | Port number | 5000 | +| `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem | +| `OPENEDAI_KEY_PATH` | SSL key file path | key.pem | +| `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 | +| `SD_WEBUI_URL` | WebUI URL (used by endpoint) | http://127.0.0.1:7861 | +| `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | all-mpnet-base-v2 | +| `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda | + +#### Persistent settings with `settings.yaml` + +You can also set default values by adding these lines to your `settings.yaml` file: ``` ---extensions openai -``` - -You can also use the `--listen` argument to make the server available on the networ, and/or the `--share` argument to enable a public Cloudflare endpoint. 
- -To enable the basic image generation support (txt2img) set the environment variable `SD_WEBUI_URL` to point to your Stable Diffusion API ([Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui)). - -For example: - -``` -SD_WEBUI_URL=http://127.0.0.1:7861 -``` - -## Quick start - -1. Install the requirements.txt (pip) -2. Enable the `openeai` module (--extensions openai), restart the server. -3. Configure the openai client - -Most openai application can be configured to connect the API if you set the following environment variables: - -```shell -# Sample .env file: -OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111 -OPENAI_API_BASE=http://0.0.0.0:5001/v1 -``` - -If needed, replace 0.0.0.0 with the IP/port of your server. - - -### Settings - -To adjust your default settings, you can add the following to your `settings.yaml` file. - -``` -openai-port: 5002 openai-embedding_device: cuda +openai-embedding_model: all-mpnet-base-v2 openai-sd_webui_url: http://127.0.0.1:7861 openai-debug: 1 ``` -If you've configured the environment variables, please note that settings from `settings.yaml` won't take effect. For instance, if you set `openai-port: 5002` in `settings.yaml` but `OPENEDAI_PORT=5001` in the environment variables, the extension will use `5001` as the port number. - -When using `cache_embedding_model.py` to preload the embedding model during Docker image building, consider the following: - -- If you wish to use the default settings, leave the environment variables unset. -- If you intend to change the default embedding model, ensure that you configure the environment variable `OPENEDAI_EMBEDDING_MODEL` to the desired model. Avoid setting `openai-embedding_model` in `settings.yaml` because those settings only take effect after the server starts. - -### Models - -This has been successfully tested with Alpaca, Koala, Vicuna, WizardLM and their variants, (ex. gpt4-x-alpaca, GPT4all-snoozy, stable-vicuna, wizard-vicuna, etc.) and many others. Models that have been trained for **Instruction Following** work best. If you test with other models please let me know how it goes. Less than satisfying results (so far) from: RWKV-4-Raven, llama, mpt-7b-instruct/chat. - -For best results across all API endpoints, a model like [vicuna-13b-v1.3-GPTQ](https://huggingface.co/TheBloke/vicuna-13b-v1.3-GPTQ), [stable-vicuna-13B-GPTQ](https://huggingface.co/TheBloke/stable-vicuna-13B-GPTQ) or [airoboros-13B-gpt4-1.3-GPTQ](https://huggingface.co/TheBloke/airoboros-13B-gpt4-1.3-GPTQ) is a good start. - -For good results with the [Completions](https://platform.openai.com/docs/api-reference/completions) API endpoint, in addition to the above models, you can also try using a base model like [falcon-7b](https://huggingface.co/tiiuae/falcon-7b) or Llama. - -For good results with the [ChatCompletions](https://platform.openai.com/docs/api-reference/chat) or [Edits](https://platform.openai.com/docs/api-reference/edits) API endpoints you can use almost any model trained for instruction following. Be sure that the proper instruction template is detected and loaded or the results will not be good. - -For the proper instruction format to be detected you need to have a matching model entry in your `models/config.yaml` file. Be sure to keep this file up to date. -A matching instruction template file in the characters/instruction-following/ folder will loaded and applied to format messages correctly for the model - this is critical for good results. 
- -For example, the Wizard-Vicuna family of models are trained with the Vicuna 1.1 format. In the models/config.yaml file there is this matching entry: - -``` -.*wizard.*vicuna: - mode: 'instruct' - instruction_template: 'Vicuna-v1.1' -``` - -This refers to `characters/instruction-following/Vicuna-v1.1.yaml`, which looks like this: - -``` -user: "USER:" -bot: "ASSISTANT:" -turn_template: "<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n" -context: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n" -``` - -For most common models this is already setup, but if you are using a new or uncommon model you may need add a matching entry to the models/config.yaml and possibly create your own instruction-following template and for best results. - -If you see this in your logs, it probably means that the correct format could not be loaded: - -``` -Warning: Loaded default instruction-following template for model. -``` - -### Embeddings (alpha) - -Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future. - -| model name | dimensions | input max tokens | speed | size | Avg. performance | -| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- | -| text-embedding-ada-002 | 1536 | 8192 | - | - | - | -| text-davinci-002 | 768 | 2046 | - | - | - | -| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 | -| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 | - -In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2". - -Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable. +### Examples ### Client Application Setup -Almost everything you use it with will require you to set a dummy OpenAI API key environment variable. + +You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables: + +```shell +OPENAI_API_HOST=http://127.0.0.1:5000 +``` + +or + +```shell +OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111 +OPENAI_API_BASE=http://127.0.0.1:500/v1 +``` With the [official python openai client](https://github.com/openai/openai-python), set the `OPENAI_API_BASE` environment variables: @@ -128,7 +68,7 @@ OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111 OPENAI_API_BASE=http://0.0.0.0:5001/v1 ``` -If needed, replace 0.0.0.0 with the IP/port of your server. +If needed, replace 127.0.0.1 with the IP/port of your server. 
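+
+For a quick test without any OpenAI client library, a minimal `requests` sketch along these lines should also work, assuming the server is running locally with `--extensions openai` on the default port 5000 (adjust the URL and the optional `mode` parameter to your setup):
+
+```python
+import requests
+
+url = "http://127.0.0.1:5000/v1/chat/completions"
+headers = {"Content-Type": "application/json"}
+body = {
+    "messages": [{"role": "user", "content": "Hello! Who are you?"}],
+    "mode": "instruct",  # extension-specific parameter: 'chat', 'chat-instruct' or 'instruct'
+    "max_tokens": 200
+}
+
+# POST the request and print the assistant's reply
+response = requests.post(url, headers=headers, json=body)
+print(response.json()["choices"][0]["message"]["content"])
+```
+
+If the server was started with `--api-key`, the key will also need to be sent; the usual OpenAI convention is an `Authorization: Bearer <key>` header.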
If using .env files to save the `OPENAI_API_BASE` and `OPENAI_API_KEY` variables, make sure the .env file is loaded before the openai module is imported: @@ -157,8 +97,22 @@ const api = new ChatGPTAPI({ apiBaseUrl: process.env.OPENAI_API_BASE }); ``` +### Embeddings (alpha) -## API Documentation & Examples +Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings (the same as the text-davinci-002 embeddings), which is different from OpenAI's current default `text-embedding-ada-002` model which produces 1536 dimensional embeddings. The model is small-ish and fast-ish. This model and embedding size may change in the future. + +| model name | dimensions | input max tokens | speed | size | Avg. performance | +| ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- | +| text-embedding-ada-002 | 1536 | 8192 | - | - | - | +| text-davinci-002 | 768 | 2046 | - | - | - | +| all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 | +| all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 | + +In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2". + +Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable. + +### API Documentation & Examples The OpenAI API is well documented, you can view the documentation here: https://platform.openai.com/docs/api-reference @@ -185,7 +139,7 @@ text = response['choices'][0]['message']['content'] print(text) ``` -## Compatibility & not so compatibility +### Compatibility & not so compatibility | API endpoint | tested with | notes | | ------------------------- | ---------------------------------- | --------------------------------------------------------------------------- | @@ -195,7 +149,7 @@ print(text) | /v1/moderations | openai.Moderation.create() | Basic initial support via embeddings | | /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options | | /v1/models/{id} | openai.Model.get() | returns whatever you ask for | -| /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models | +| /v1/edits | openai.Edit.create() | Removed, use /v1/chat/completions instead | | /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model | | /v1/completions | openai api completions.create | Legacy endpoint (v0.25) | | /v1/engines/\*/embeddings | python-openai v0.25 | Legacy endpoint | @@ -209,28 +163,8 @@ print(text) | /v1/fine-tunes\* | openai.FineTune.\* | not yet supported | | /v1/search | openai.search, engines.search | not yet supported | -Because of the differences in OpenAI model context sizes (2k, 4k, 8k, 16k, etc,) you may need to adjust the max_tokens to fit into the context of the model you choose. -Streaming, temperature, top_p, max_tokens, stop, should all work as expected, but not all parameters are mapped correctly. 
- -Some hacky mappings: - -| OpenAI | text-generation-webui | note | -| ----------------------- | -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| model | - | Ignored, the model is not changed | -| frequency_penalty | encoder_repetition_penalty | this seems to operate with a different scale and defaults, I tried to scale it based on range & defaults, but the results are terrible. hardcoded to 1.18 until there is a better way | -| presence_penalty | repetition_penalty | same issues as frequency_penalty, hardcoded to 1.0 | -| best_of | top_k | default is 1 (top_k is 20 for chat, which doesn't support best_of) | -| n | 1 | variations are not supported yet. | -| 1 | num_beams | hardcoded to 1 | -| 1.0 | typical_p | hardcoded to 1.0 | -| logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' | -| messages.name | - | not supported yet | -| suffix | - | not supported yet | -| user | - | not supported yet | -| functions/function_call | - | function calls are not supported yet | - -### Applications +#### Applications Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variable set, but there are some exceptions. @@ -249,15 +183,3 @@ Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment v | ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context | | ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 | | ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported | - -## Future plans - -- better error handling -- model changing, esp. something for swapping loras or embedding models -- consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard) - -## Bugs? Feedback? Comments? Pull requests? - -To enable debugging and get copious output you can set the `OPENEDAI_DEBUG=1` environment variable. - -Are all appreciated, please @matatonic and I'll try to get back to you as soon as possible. diff --git a/extensions/api/script.py b/extensions/api/script.py index 12fd9cad..3c2fc898 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -3,9 +3,11 @@ import time import extensions.api.blocking_api as blocking_api import extensions.api.streaming_api as streaming_api from modules import shared +from modules.logging_colors import logger def setup(): + logger.warning("The current API is deprecated and will be replaced with the OpenAI compatible API on November xxth. 
To test the new API, use \"--extensions openai\" instead of \"--api\".") blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api, tunnel_id=shared.args.public_api_id) if shared.args.public_api: time.sleep(5) diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index 40d96c1f..5eb0e291 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -1,18 +1,23 @@ +import copy import time +from collections import deque import tiktoken import torch import torch.nn.functional as F -import yaml -from extensions.openai.defaults import clamp, default, get_default_req_params from extensions.openai.errors import InvalidRequestError -from extensions.openai.utils import debug_msg, end_line +from extensions.openai.utils import debug_msg from modules import shared +from modules.chat import ( + generate_chat_prompt, + generate_chat_reply, + load_character_memoized +) +from modules.presets import load_preset_memoized from modules.text_generation import decode, encode, generate_reply from transformers import LogitsProcessor, LogitsProcessorList -# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic class LogitsBiasProcessor(LogitsProcessor): def __init__(self, logit_bias={}): self.logit_bias = logit_bias @@ -28,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor): logits[0, self.keys] += self.values debug_msg(" --> ", logits[0, self.keys]) debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0]))) + return logits def __repr__(self): @@ -47,6 +53,7 @@ class LogprobProcessor(LogitsProcessor): top_probs = [float(x) for x in top_values[0]] self.token_alternatives = dict(zip(top_tokens, top_probs)) debug_msg(repr(self)) + return logits def __repr__(self): @@ -66,43 +73,28 @@ def convert_logprobs_to_tiktoken(model, logprobs): return logprobs -def marshal_common_params(body): - # Request Parameters - # Try to use openai defaults or map them to something with the same intent +def process_parameters(body, is_legacy=False): + generate_params = body + max_tokens_str = 'length' if is_legacy else 'max_tokens' + generate_params['max_new_tokens'] = body.pop(max_tokens_str) + if generate_params['truncation_length'] == 0: + if shared.args.loader and shared.args.loader.lower().startswith('exllama'): + generate_params['truncation_length'] = shared.args.max_seq_len + elif shared.args.loader and shared.args.loader in ['llama.cpp', 'llamacpp_HF', 'ctransformers']: + generate_params['truncation_length'] = shared.args.n_ctx + else: + generate_params['truncation_length'] = shared.settings['truncation_length'] - req_params = get_default_req_params() - - # Common request parameters - req_params['truncation_length'] = shared.settings['truncation_length'] - req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token']) - req_params['seed'] = shared.settings.get('seed', req_params['seed']) - req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings'] - - # OpenAI API Parameters - # model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this - req_params['requested_model'] = body.get('model', shared.model_name) - - req_params['suffix'] = default(body, 'suffix', req_params['suffix']) - req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99) # fixup absolute 0.0/2.0 - req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0) - n = default(body, 
'n', 1) - if n != 1: - raise InvalidRequestError(message="Only n = 1 is supported.", param='n') + if body['preset'] is not None: + preset = load_preset_memoized(body['preset']) + generate_params.update(preset) + generate_params['custom_stopping_strings'] = [] if 'stop' in body: # str or array, max len 4 (ignored) if isinstance(body['stop'], str): - req_params['stopping_strings'] = [body['stop']] # non-standard parameter + generate_params['custom_stopping_strings'] = [body['stop']] elif isinstance(body['stop'], list): - req_params['stopping_strings'] = body['stop'] - - # presence_penalty - ignored - # frequency_penalty - ignored - - # pass through unofficial params - req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty']) - req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty']) - - # user - ignored + generate_params['custom_stopping_strings'] = body['stop'] logits_processor = [] logit_bias = body.get('logit_bias', None) @@ -110,12 +102,13 @@ def marshal_common_params(body): # XXX convert tokens from tiktoken based on requested model # Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100} try: - encoder = tiktoken.encoding_for_model(req_params['requested_model']) + encoder = tiktoken.encoding_for_model(generate_params['model']) new_logit_bias = {} for logit, bias in logit_bias.items(): for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]: if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens continue + new_logit_bias[str(int(x))] = bias debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias) logit_bias = new_logit_bias @@ -126,238 +119,129 @@ def marshal_common_params(body): logprobs = None # coming to chat eventually if 'logprobs' in body: - logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5. - req_params['logprob_proc'] = LogprobProcessor(logprobs) - logits_processor.extend([req_params['logprob_proc']]) + logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5. + generate_params['logprob_proc'] = LogprobProcessor(logprobs) + logits_processor.extend([generate_params['logprob_proc']]) else: logprobs = None if logits_processor: # requires logits_processor support - req_params['logits_processor'] = LogitsProcessorList(logits_processor) + generate_params['logits_processor'] = LogitsProcessorList(logits_processor) - return req_params + return generate_params -def messages_to_prompt(body: dict, req_params: dict, max_tokens): - # functions - if body.get('functions', []): # chat only +def convert_history(history): + ''' + Chat histories in this program are in the format [message, reply]. + This function converts OpenAI histories to that format. 
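+
+    For example, the OpenAI-style messages
+        [{'role': 'user', 'content': 'Hi'},
+         {'role': 'assistant', 'content': 'Hello!'},
+         {'role': 'user', 'content': 'How are you?'}]
+    become
+        user_input = 'How are you?'
+        history = {'internal': [['Hi', 'Hello!']], 'visible': [['Hi', 'Hello!']]}
+    i.e. the latest user message is returned separately rather than being appended to the history.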
+ ''' + chat_dialogue = [] + current_message = "" + current_reply = "" + user_input = "" + + for entry in history: + content = entry["content"] + role = entry["role"] + + if role == "user": + user_input = content + if current_message: + chat_dialogue.append([current_message, '']) + current_message = "" + current_message = content + elif role == "assistant": + current_reply = content + if current_message: + chat_dialogue.append([current_message, current_reply]) + current_message = "" + current_reply = "" + else: + chat_dialogue.append(['', current_reply]) + + # if current_message: + # chat_dialogue.append([current_message, '']) + + return user_input, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)} + + +def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict: + if body.get('functions', []): raise InvalidRequestError(message="functions is not supported.", param='functions') - if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'} + + if body.get('function_call', ''): raise InvalidRequestError(message="function_call is not supported.", param='function_call') if 'messages' not in body: raise InvalidRequestError(message="messages is required", param='messages') messages = body['messages'] - - role_formats = { - 'user': 'User: {message}\n', - 'assistant': 'Assistant: {message}\n', - 'system': '{message}', - 'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?', - 'prompt': 'Assistant:', - } - - if 'stopping_strings' not in req_params: - req_params['stopping_strings'] = [] - - # Instruct models can be much better - if shared.settings['instruction_template']: - try: - instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r')) - - template = instruct['turn_template'] - system_message_template = "{message}" - system_message_default = instruct.get('context', '') # can be missing - bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token - user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', '')) - bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', '')) - bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ') - - role_formats = { - 'user': user_message_template, - 'assistant': bot_message_template, - 'system': system_message_template, - 'context': system_message_default, - 'prompt': bot_prompt, - } - - if 'Alpaca' in shared.settings['instruction_template']: - req_params['stopping_strings'].extend(['\n###']) - elif instruct['user']: # WizardLM and some others have no user prompt. 
- req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']]) - - debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}") - - except Exception as e: - req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also - - print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}") - print("Warning: Loaded default instruction-following template for model.") - - else: - req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also - print("Warning: Loaded default instruction-following template for model.") - - system_msgs = [] - chat_msgs = [] - - # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date} - context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else '' - context_msg = end_line(context_msg) - - # Maybe they sent both? This is not documented in the API, but some clients seem to do this. - if 'prompt' in body: - context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg - for m in messages: if 'role' not in m: raise InvalidRequestError(message="messages: missing role", param='messages') + elif m['role'] == 'function': + raise InvalidRequestError(message="role: function is not supported.", param='messages') if 'content' not in m: raise InvalidRequestError(message="messages: missing content", param='messages') - role = m['role'] - content = m['content'] - # name = m.get('name', None) - # function_call = m.get('function_call', None) # user name or function name with output in content - msg = role_formats[role].format(message=content) - if role == 'system': - system_msgs.extend([msg]) - elif role == 'function': - raise InvalidRequestError(message="role: function is not supported.", param='messages') - else: - chat_msgs.extend([msg]) - - system_msg = '\n'.join(system_msgs) - system_msg = end_line(system_msg) - - prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt'] - - token_count = len(encode(prompt)[0]) - - if token_count >= req_params['truncation_length']: - err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens." - raise InvalidRequestError(message=err_msg, param='messages') - - if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']: - err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}." - print(f"Warning: ${err_msg}") - # raise InvalidRequestError(message=err_msg, params='max_tokens') - - return prompt, token_count - - -def chat_completions(body: dict, is_legacy: bool = False) -> dict: # Chat Completions - object_type = 'chat.completions' + object_type = 'chat.completions' if not stream else 'chat.completions.chunk' created_time = int(time.time()) cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000)) resp_list = 'data' if is_legacy else 'choices' - # common params - req_params = marshal_common_params(body) - req_params['stream'] = False - requested_model = req_params.pop('requested_model') - logprob_proc = req_params.pop('logprob_proc', None) - req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k. 
+ # generation parameters + generate_params = process_parameters(body, is_legacy=is_legacy) + continue_ = body['continue_'] - # chat default max_tokens is 'inf', but also flexible - max_tokens = 0 - max_tokens_str = 'length' if is_legacy else 'max_tokens' - if max_tokens_str in body: - max_tokens = default(body, max_tokens_str, req_params['truncation_length']) - req_params['max_new_tokens'] = max_tokens - else: - req_params['max_new_tokens'] = req_params['truncation_length'] + # Instruction template + instruction_template = body['instruction_template'] or shared.settings['instruction_template'] + name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True) + name1_instruct = body['name1_instruct'] or name1_instruct + name2_instruct = body['name2_instruct'] or name2_instruct + context_instruct = body['context_instruct'] or context_instruct + turn_template = body['turn_template'] or turn_template - # format the prompt from messages - prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings'] + # Chat character + character = body['character'] or shared.settings['character'] + name1 = body['name1'] or shared.settings['name1'] + name1, name2, _, greeting, context, _ = load_character_memoized(character, name1, '', instruct=False) + name2 = body['name2'] or name2 + context = body['context'] or context + greeting = body['greeting'] or greeting - # set real max, avoid deeper errors - if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']: - req_params['max_new_tokens'] = req_params['truncation_length'] - token_count + # History + user_input, history = convert_history(messages) - stopping_strings = req_params.pop('stopping_strings', []) + generate_params.update({ + 'mode': body['mode'], + 'name1': name1, + 'name2': name2, + 'context': context, + 'greeting': greeting, + 'name1_instruct': name1_instruct, + 'name2_instruct': name2_instruct, + 'context_instruct': context_instruct, + 'turn_template': turn_template, + 'chat-instruct_command': body['chat_instruct_command'], + 'history': history, + 'stream': stream + }) - # generate reply ####################################### - debug_msg({'prompt': prompt, 'req_params': req_params}) - generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) + max_tokens = generate_params['max_new_tokens'] + if max_tokens in [None, 0]: + generate_params['max_new_tokens'] = 200 + generate_params['auto_max_new_tokens'] = True - answer = '' - for a in generator: - answer = a - - # strip extra leading space off new generated content - if answer and answer[0] == ' ': - answer = answer[1:] - - completion_token_count = len(encode(answer)[0]) - stop_reason = "stop" - if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']: - stop_reason = "length" - - resp = { - "id": cmpl_id, - "object": object_type, - "created": created_time, - "model": shared.model_name, # TODO: add Lora info? 
- resp_list: [{ - "index": 0, - "finish_reason": stop_reason, - "message": {"role": "assistant", "content": answer} - }], - "usage": { - "prompt_tokens": token_count, - "completion_tokens": completion_token_count, - "total_tokens": token_count + completion_token_count - } - } - if logprob_proc: # not official for chat yet - top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives) - resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]} - # else: - # resp[resp_list][0]["logprobs"] = None - - return resp - - -# generator -def stream_chat_completions(body: dict, is_legacy: bool = False): - - # Chat Completions - stream_object_type = 'chat.completions.chunk' - created_time = int(time.time()) - cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000)) - resp_list = 'data' if is_legacy else 'choices' - - # common params - req_params = marshal_common_params(body) - req_params['stream'] = True - requested_model = req_params.pop('requested_model') - logprob_proc = req_params.pop('logprob_proc', None) - req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k. - - # chat default max_tokens is 'inf', but also flexible - max_tokens = 0 - max_tokens_str = 'length' if is_legacy else 'max_tokens' - if max_tokens_str in body: - max_tokens = default(body, max_tokens_str, req_params['truncation_length']) - req_params['max_new_tokens'] = max_tokens - else: - req_params['max_new_tokens'] = req_params['truncation_length'] - - # format the prompt from messages - prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings'] - - # set real max, avoid deeper errors - if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']: - req_params['max_new_tokens'] = req_params['truncation_length'] - token_count + requested_model = generate_params.pop('model') + logprob_proc = generate_params.pop('logprob_proc', None) def chat_streaming_chunk(content): # begin streaming chunk = { "id": cmpl_id, - "object": stream_object_type, + "object": object_type, "created": created_time, "model": shared.model_name, resp_list: [{ @@ -376,262 +260,262 @@ def stream_chat_completions(body: dict, is_legacy: bool = False): # chunk[resp_list][0]["logprobs"] = None return chunk - yield chat_streaming_chunk('') + if stream: + yield chat_streaming_chunk('') # generate reply ####################################### - debug_msg({'prompt': prompt, 'req_params': req_params}) + prompt = generate_chat_prompt(user_input, generate_params) + token_count = len(encode(prompt)[0]) + debug_msg({'prompt': prompt, 'generate_params': generate_params}) - stopping_strings = req_params.pop('stopping_strings', []) - - generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) + generator = generate_chat_reply( + user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False) answer = '' seen_content = '' completion_token_count = 0 for a in generator: - answer = a + answer = a['internal'][-1][1] + if stream: + len_seen = len(seen_content) + new_content = answer[len_seen:] - len_seen = len(seen_content) - new_content = answer[len_seen:] + if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. + continue - if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. 
- continue + seen_content = answer - seen_content = answer + # strip extra leading space off new generated content + if len_seen == 0 and new_content[0] == ' ': + new_content = new_content[1:] - # strip extra leading space off new generated content - if len_seen == 0 and new_content[0] == ' ': - new_content = new_content[1:] + chunk = chat_streaming_chunk(new_content) - chunk = chat_streaming_chunk(new_content) - - yield chunk - - # to get the correct token_count, strip leading space if present - if answer and answer[0] == ' ': - answer = answer[1:] + yield chunk completion_token_count = len(encode(answer)[0]) stop_reason = "stop" - if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']: + if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']: stop_reason = "length" - chunk = chat_streaming_chunk('') - chunk[resp_list][0]['finish_reason'] = stop_reason - chunk['usage'] = { - "prompt_tokens": token_count, - "completion_tokens": completion_token_count, - "total_tokens": token_count + completion_token_count - } + if stream: + chunk = chat_streaming_chunk('') + chunk[resp_list][0]['finish_reason'] = stop_reason + chunk['usage'] = { + "prompt_tokens": token_count, + "completion_tokens": completion_token_count, + "total_tokens": token_count + completion_token_count + } - yield chunk + yield chunk + else: + resp = { + "id": cmpl_id, + "object": object_type, + "created": created_time, + "model": shared.model_name, + resp_list: [{ + "index": 0, + "finish_reason": stop_reason, + "message": {"role": "assistant", "content": answer} + }], + "usage": { + "prompt_tokens": token_count, + "completion_tokens": completion_token_count, + "total_tokens": token_count + completion_token_count + } + } + if logprob_proc: # not official for chat yet + top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives) + resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]} + # else: + # resp[resp_list][0]["logprobs"] = None + + yield resp -def completions(body: dict, is_legacy: bool = False): - # Legacy - # Text Completions - object_type = 'text_completion' +def completions_common(body: dict, is_legacy: bool = False, stream=False): + object_type = 'text_completion.chunk' if stream else 'text_completion' created_time = int(time.time()) cmpl_id = "conv-%d" % (int(time.time() * 1000000000)) resp_list = 'data' if is_legacy else 'choices' - # ... encoded as a string, array of strings, array of tokens, or array of token arrays. prompt_str = 'context' if is_legacy else 'prompt' + + # ... encoded as a string, array of strings, array of tokens, or array of token arrays. 
if prompt_str not in body: raise InvalidRequestError("Missing required input", param=prompt_str) - prompt_arg = body[prompt_str] - if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)): - prompt_arg = [prompt_arg] - # common params - req_params = marshal_common_params(body) - req_params['stream'] = False - max_tokens_str = 'length' if is_legacy else 'max_tokens' - max_tokens = default(body, max_tokens_str, req_params['max_new_tokens']) - req_params['max_new_tokens'] = max_tokens - requested_model = req_params.pop('requested_model') - logprob_proc = req_params.pop('logprob_proc', None) - stopping_strings = req_params.pop('stopping_strings', []) - # req_params['suffix'] = default(body, 'suffix', req_params['suffix']) - req_params['echo'] = default(body, 'echo', req_params['echo']) - req_params['top_k'] = default(body, 'best_of', req_params['top_k']) + generate_params = process_parameters(body, is_legacy=is_legacy) + max_tokens = generate_params['max_new_tokens'] + generate_params['stream'] = stream + requested_model = generate_params.pop('model') + logprob_proc = generate_params.pop('logprob_proc', None) + # generate_params['suffix'] = body.get('suffix', generate_params['suffix']) + generate_params['echo'] = body.get('echo', generate_params['echo']) - resp_list_data = [] - total_completion_token_count = 0 - total_prompt_token_count = 0 + if not stream: + prompt_arg = body[prompt_str] + if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)): + prompt_arg = [prompt_arg] - for idx, prompt in enumerate(prompt_arg, start=0): - if isinstance(prompt[0], int): - # token lists - if requested_model == shared.model_name: - prompt = decode(prompt)[0] - else: + resp_list_data = [] + total_completion_token_count = 0 + total_prompt_token_count = 0 + + for idx, prompt in enumerate(prompt_arg, start=0): + if isinstance(prompt[0], int): + # token lists + if requested_model == shared.model_name: + prompt = decode(prompt)[0] + else: + try: + encoder = tiktoken.encoding_for_model(requested_model) + prompt = encoder.decode(prompt) + except KeyError: + prompt = decode(prompt)[0] + + token_count = len(encode(prompt)[0]) + total_prompt_token_count += token_count + + # generate reply ####################################### + debug_msg({'prompt': prompt, 'generate_params': generate_params}) + generator = generate_reply(prompt, generate_params, is_chat=False) + answer = '' + + for a in generator: + answer = a + + # strip extra leading space off new generated content + if answer and answer[0] == ' ': + answer = answer[1:] + + completion_token_count = len(encode(answer)[0]) + total_completion_token_count += completion_token_count + stop_reason = "stop" + if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens: + stop_reason = "length" + + respi = { + "index": idx, + "finish_reason": stop_reason, + "text": answer, + "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None, + } + + resp_list_data.extend([respi]) + + resp = { + "id": cmpl_id, + "object": object_type, + "created": created_time, + "model": shared.model_name, + resp_list: resp_list_data, + "usage": { + "prompt_tokens": total_prompt_token_count, + "completion_tokens": total_completion_token_count, + "total_tokens": total_prompt_token_count + total_completion_token_count + } + } + + yield resp + else: + prompt = body[prompt_str] + if isinstance(prompt, list): + if prompt and 
isinstance(prompt[0], int): try: encoder = tiktoken.encoding_for_model(requested_model) prompt = encoder.decode(prompt) except KeyError: prompt = decode(prompt)[0] + else: + raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str) token_count = len(encode(prompt)[0]) - total_prompt_token_count += token_count - if token_count + max_tokens > req_params['truncation_length']: - err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})." - # print(f"Warning: ${err_msg}") - raise InvalidRequestError(message=err_msg, param=max_tokens_str) + def text_streaming_chunk(content): + # begin streaming + chunk = { + "id": cmpl_id, + "object": object_type, + "created": created_time, + "model": shared.model_name, + resp_list: [{ + "index": 0, + "finish_reason": None, + "text": content, + "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None, + }], + } + + return chunk + + yield text_streaming_chunk('') # generate reply ####################################### - debug_msg({'prompt': prompt, 'req_params': req_params}) - generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) + debug_msg({'prompt': prompt, 'generate_params': generate_params}) + generator = generate_reply(prompt, generate_params, is_chat=False) + answer = '' + seen_content = '' + completion_token_count = 0 for a in generator: answer = a - # strip extra leading space off new generated content + len_seen = len(seen_content) + new_content = answer[len_seen:] + + if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. + continue + + seen_content = answer + + # strip extra leading space off new generated content + if len_seen == 0 and new_content[0] == ' ': + new_content = new_content[1:] + + chunk = text_streaming_chunk(new_content) + + yield chunk + + # to get the correct count, we strip the leading space if present if answer and answer[0] == ' ': answer = answer[1:] completion_token_count = len(encode(answer)[0]) - total_completion_token_count += completion_token_count stop_reason = "stop" - if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens: + if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens: stop_reason = "length" - respi = { - "index": idx, - "finish_reason": stop_reason, - "text": answer, - "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None, + chunk = text_streaming_chunk('') + chunk[resp_list][0]["finish_reason"] = stop_reason + chunk["usage"] = { + "prompt_tokens": token_count, + "completion_tokens": completion_token_count, + "total_tokens": token_count + completion_token_count } - resp_list_data.extend([respi]) - - resp = { - "id": cmpl_id, - "object": object_type, - "created": created_time, - "model": shared.model_name, # TODO: add Lora info? 
- resp_list: resp_list_data, - "usage": { - "prompt_tokens": total_prompt_token_count, - "completion_tokens": total_completion_token_count, - "total_tokens": total_prompt_token_count + total_completion_token_count - } - } - - return resp - - -# generator -def stream_completions(body: dict, is_legacy: bool = False): - # Legacy - # Text Completions - # object_type = 'text_completion' - stream_object_type = 'text_completion.chunk' - created_time = int(time.time()) - cmpl_id = "conv-%d" % (int(time.time() * 1000000000)) - resp_list = 'data' if is_legacy else 'choices' - - # ... encoded as a string, array of strings, array of tokens, or array of token arrays. - prompt_str = 'context' if is_legacy else 'prompt' - if prompt_str not in body: - raise InvalidRequestError("Missing required input", param=prompt_str) - - prompt = body[prompt_str] - req_params = marshal_common_params(body) - requested_model = req_params.pop('requested_model') - if isinstance(prompt, list): - if prompt and isinstance(prompt[0], int): - try: - encoder = tiktoken.encoding_for_model(requested_model) - prompt = encoder.decode(prompt) - except KeyError: - prompt = decode(prompt)[0] - else: - raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str) - - # common params - req_params['stream'] = True - max_tokens_str = 'length' if is_legacy else 'max_tokens' - max_tokens = default(body, max_tokens_str, req_params['max_new_tokens']) - req_params['max_new_tokens'] = max_tokens - logprob_proc = req_params.pop('logprob_proc', None) - stopping_strings = req_params.pop('stopping_strings', []) - # req_params['suffix'] = default(body, 'suffix', req_params['suffix']) - req_params['echo'] = default(body, 'echo', req_params['echo']) - req_params['top_k'] = default(body, 'best_of', req_params['top_k']) - - token_count = len(encode(prompt)[0]) - - if token_count + max_tokens > req_params['truncation_length']: - err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})." - # print(f"Warning: ${err_msg}") - raise InvalidRequestError(message=err_msg, param=max_tokens_str) - - def text_streaming_chunk(content): - # begin streaming - chunk = { - "id": cmpl_id, - "object": stream_object_type, - "created": created_time, - "model": shared.model_name, - resp_list: [{ - "index": 0, - "finish_reason": None, - "text": content, - "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None, - }], - } - - return chunk - - yield text_streaming_chunk('') - - # generate reply ####################################### - debug_msg({'prompt': prompt, 'req_params': req_params}) - generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False) - - answer = '' - seen_content = '' - completion_token_count = 0 - - for a in generator: - answer = a - - len_seen = len(seen_content) - new_content = answer[len_seen:] - - if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet. 
- continue - - seen_content = answer - - # strip extra leading space off new generated content - if len_seen == 0 and new_content[0] == ' ': - new_content = new_content[1:] - - chunk = text_streaming_chunk(new_content) - yield chunk - # to get the correct count, we strip the leading space if present - if answer and answer[0] == ' ': - answer = answer[1:] - completion_token_count = len(encode(answer)[0]) - stop_reason = "stop" - if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens: - stop_reason = "length" +def chat_completions(body: dict, is_legacy: bool = False) -> dict: + generator = chat_completions_common(body, is_legacy, stream=False) + return deque(generator, maxlen=1).pop() - chunk = text_streaming_chunk('') - chunk[resp_list][0]["finish_reason"] = stop_reason - chunk["usage"] = { - "prompt_tokens": token_count, - "completion_tokens": completion_token_count, - "total_tokens": token_count + completion_token_count - } - yield chunk +def stream_chat_completions(body: dict, is_legacy: bool = False): + for resp in chat_completions_common(body, is_legacy, stream=True): + yield resp + + +def completions(body: dict, is_legacy: bool = False) -> dict: + generator = completions_common(body, is_legacy, stream=False) + return deque(generator, maxlen=1).pop() + + +def stream_completions(body: dict, is_legacy: bool = False): + for resp in completions_common(body, is_legacy, stream=True): + yield resp diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py deleted file mode 100644 index dc400588..00000000 --- a/extensions/openai/defaults.py +++ /dev/null @@ -1,78 +0,0 @@ -import copy - -# Slightly different defaults for OpenAI's API -# Data type is important, Ex. use 0.0 for a float 0 -default_req_params = { - 'max_new_tokens': 16, # 'Inf' for chat - 'auto_max_new_tokens': False, - 'max_tokens_second': 0, - 'temperature': 1.0, - 'temperature_last': False, - 'top_p': 1.0, - 'min_p': 0, - 'top_k': 1, # choose 20 for chat in absence of another default - 'repetition_penalty': 1.18, - 'presence_penalty': 0, - 'frequency_penalty': 0, - 'repetition_penalty_range': 0, - 'encoder_repetition_penalty': 1.0, - 'suffix': None, - 'stream': False, - 'echo': False, - 'seed': -1, - # 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map - 'truncation_length': 2048, # first use shared.settings value - 'add_bos_token': True, - 'do_sample': True, - 'typical_p': 1.0, - 'epsilon_cutoff': 0.0, # In units of 1e-4 - 'eta_cutoff': 0.0, # In units of 1e-4 - 'tfs': 1.0, - 'top_a': 0.0, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0.0, - 'length_penalty': 1.0, - 'early_stopping': False, - 'mirostat_mode': 0, - 'mirostat_tau': 5.0, - 'mirostat_eta': 0.1, - 'grammar_string': '', - 'guidance_scale': 1, - 'negative_prompt': '', - 'ban_eos_token': False, - 'custom_token_bans': '', - 'skip_special_tokens': True, - 'custom_stopping_strings': '', - # 'logits_processor' - conditionally passed - # 'stopping_strings' - temporarily used - # 'logprobs' - temporarily used - # 'requested_model' - temporarily used -} - - -def get_default_req_params(): - return copy.deepcopy(default_req_params) - - -def default(dic, key, default): - ''' - little helper to get defaults if arg is present but None and should be the same type as default. 
- ''' - val = dic.get(key, default) - if not isinstance(val, type(default)): - # maybe it's just something like 1 instead of 1.0 - try: - v = type(default)(val) - if type(val)(v) == val: # if it's the same value passed in, it's ok. - return v - except: - pass - - val = default - return val - - -def clamp(value, minvalue, maxvalue): - return max(minvalue, min(value, maxvalue)) diff --git a/extensions/openai/edits.py b/extensions/openai/edits.py deleted file mode 100644 index edf4e6c0..00000000 --- a/extensions/openai/edits.py +++ /dev/null @@ -1,101 +0,0 @@ -import time - -import yaml -from extensions.openai.defaults import get_default_req_params -from extensions.openai.errors import InvalidRequestError -from extensions.openai.utils import debug_msg -from modules import shared -from modules.text_generation import encode, generate_reply - - -def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict: - - created_time = int(time.time() * 1000) - - # Request parameters - req_params = get_default_req_params() - stopping_strings = [] - - # Alpaca is verbose so a good default prompt - default_template = ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - ) - - instruction_template = default_template - - # Use the special instruction/input/response template for anything trained like Alpaca - if shared.settings['instruction_template']: - if 'Alpaca' in shared.settings['instruction_template']: - stopping_strings.extend(['\n###']) - else: - try: - instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r')) - - template = instruct['turn_template'] - template = template\ - .replace('<|user|>', instruct.get('user', ''))\ - .replace('<|bot|>', instruct.get('bot', ''))\ - .replace('<|user-message|>', '{instruction}\n{input}') - - instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ') - if instruct['user']: - stopping_strings.extend(['\n' + instruct['user'], instruct['user']]) - - except Exception as e: - instruction_template = default_template - print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}") - print("Warning: Loaded default instruction-following template (Alpaca) for model.") - else: - stopping_strings.extend(['\n###']) - print("Warning: Loaded default instruction-following template (Alpaca) for model.") - - edit_task = instruction_template.format(instruction=instruction, input=input) - - truncation_length = shared.settings['truncation_length'] - - token_count = len(encode(edit_task)[0]) - max_tokens = truncation_length - token_count - - if max_tokens < 1: - err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens." 
- raise InvalidRequestError(err_msg, param='input') - - req_params['max_new_tokens'] = max_tokens - req_params['truncation_length'] = truncation_length - req_params['temperature'] = temperature - req_params['top_p'] = top_p - req_params['seed'] = shared.settings.get('seed', req_params['seed']) - req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token']) - req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings'] - - debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count}) - - generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False) - - answer = '' - for a in generator: - answer = a - - # some reply's have an extra leading space to fit the instruction template, just clip it off from the reply. - if edit_task[-1] != '\n' and answer and answer[0] == ' ': - answer = answer[1:] - - completion_token_count = len(encode(answer)[0]) - - resp = { - "object": "edit", - "created": created_time, - "choices": [{ - "text": answer, - "index": 0, - }], - "usage": { - "prompt_tokens": token_count, - "completion_tokens": completion_token_count, - "total_tokens": token_count + completion_token_count - } - } - - return resp diff --git a/extensions/openai/embeddings.py b/extensions/openai/embeddings.py index d6d30721..88ab1c30 100644 --- a/extensions/openai/embeddings.py +++ b/extensions/openai/embeddings.py @@ -6,9 +6,13 @@ from extensions.openai.utils import debug_msg, float_list_to_base64 from sentence_transformers import SentenceTransformer embeddings_params_initialized = False -# using 'lazy loading' to avoid circular import -# so this function will be executed only once + + def initialize_embedding_params(): + ''' + using 'lazy loading' to avoid circular import + so this function will be executed only once + ''' global embeddings_params_initialized if not embeddings_params_initialized: global st_model, embeddings_model, embeddings_device @@ -26,7 +30,7 @@ def load_embedding_model(model: str) -> SentenceTransformer: initialize_embedding_params() global embeddings_device, embeddings_model try: - print(f"\Try embedding model: {model} on {embeddings_device}") + print(f"Try embedding model: {model} on {embeddings_device}") # see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer embeddings_model = SentenceTransformer(model, device=embeddings_device) # ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM @@ -54,7 +58,7 @@ def get_embeddings(input: list) -> np.ndarray: model = get_embeddings_model() debug_msg(f"embedding model : {model}") embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False) - debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at you own will + debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at you own will return embedding diff --git a/extensions/openai/images.py b/extensions/openai/images.py index 350ea617..1c8ea3a0 100644 --- a/extensions/openai/images.py +++ b/extensions/openai/images.py @@ -50,6 +50,7 @@ def generations(prompt: str, size: str, response_format: str, n: int): 'data': [] } from extensions.openai.script import params + # TODO: support SD_WEBUI_AUTH username:password pair. 
sd_url = f"{os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', ''))}/sdapi/v1/txt2img" diff --git a/extensions/openai/requirements.txt b/extensions/openai/requirements.txt index 8c63b5e1..61ef984f 100644 --- a/extensions/openai/requirements.txt +++ b/extensions/openai/requirements.txt @@ -1,4 +1,5 @@ SpeechRecognition==3.10.0 -flask_cloudflared==0.0.12 +flask_cloudflared==0.0.14 sentence-transformers +sse-starlette==1.6.5 tiktoken diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 72fd1610..20fc94ed 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -1,351 +1,255 @@ import json import os -import ssl -import traceback -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from threading import Thread import extensions.openai.completions as OAIcompletions -import extensions.openai.edits as OAIedits import extensions.openai.embeddings as OAIembeddings import extensions.openai.images as OAIimages import extensions.openai.models as OAImodels import extensions.openai.moderations as OAImoderations -from extensions.openai.defaults import clamp, default, get_default_req_params -from extensions.openai.errors import ( - InvalidRequestError, - OpenAIError, - ServiceUnavailableError -) -from extensions.openai.tokens import token_count, token_decode, token_encode -from extensions.openai.utils import debug_msg -from modules import shared - -import cgi import speech_recognition as sr +import uvicorn +from extensions.openai.errors import ServiceUnavailableError +from extensions.openai.tokens import token_count, token_decode, token_encode +from extensions.openai.utils import _start_cloudflared +from fastapi import Depends, FastAPI, Header, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.requests import Request +from fastapi.responses import JSONResponse +from modules import shared +from modules.logging_colors import logger from pydub import AudioSegment +from sse_starlette import EventSourceResponse + +from .typing import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + to_dict +) params = { - # default params - 'port': 5001, 'embedding_device': 'cpu', 'embedding_model': 'all-mpnet-base-v2', - - # optional params 'sd_webui_url': '', 'debug': 0 } -class Handler(BaseHTTPRequestHandler): - def send_access_control_headers(self): - self.send_header("Access-Control-Allow-Origin", "*") - self.send_header("Access-Control-Allow-Credentials", "true") - self.send_header( - "Access-Control-Allow-Methods", - "GET,HEAD,OPTIONS,POST,PUT" - ) - self.send_header( - "Access-Control-Allow-Headers", - "Origin, Accept, X-Requested-With, Content-Type, " - "Access-Control-Request-Method, Access-Control-Request-Headers, " - "Authorization" - ) - def do_OPTIONS(self): - self.send_response(200) - self.send_access_control_headers() - self.send_header('Content-Type', 'application/json') - self.end_headers() - self.wfile.write("OK".encode('utf-8')) +def verify_api_key(authorization: str = Header(None)) -> None: + expected_api_key = shared.args.api_key + if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"): + raise HTTPException(status_code=401, detail="Unauthorized") - def start_sse(self): - self.send_response(200) - self.send_access_control_headers() - self.send_header('Content-Type', 'text/event-stream') - self.send_header('Cache-Control', 'no-cache') - # self.send_header('Connection', 'keep-alive') - self.end_headers() - def 
send_sse(self, chunk: dict): - response = 'data: ' + json.dumps(chunk) + '\r\n\r\n' - debug_msg(response[:-4]) - self.wfile.write(response.encode('utf-8')) +app = FastAPI(dependencies=[Depends(verify_api_key)]) - def end_sse(self): - response = 'data: [DONE]\r\n\r\n' - debug_msg(response[:-4]) - self.wfile.write(response.encode('utf-8')) +# Configure CORS settings to allow all origins, methods, and headers +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "HEAD", "OPTIONS", "POST", "PUT"], + allow_headers=[ + "Origin", + "Accept", + "X-Requested-With", + "Content-Type", + "Access-Control-Request-Method", + "Access-Control-Request-Headers", + "Authorization", + ], +) - def return_json(self, ret: dict, code: int = 200, no_debug=False): - self.send_response(code) - self.send_access_control_headers() - self.send_header('Content-Type', 'application/json') - response = json.dumps(ret) - r_utf8 = response.encode('utf-8') +@app.options("/") +async def options_route(): + return JSONResponse(content="OK") - self.send_header('Content-Length', str(len(r_utf8))) - self.end_headers() - self.wfile.write(r_utf8) - if not no_debug: - debug_msg(r_utf8) +@app.post('/v1/completions', response_model=CompletionResponse) +@app.post('/v1/generate', response_model=CompletionResponse) +async def openai_completions(request: Request, request_data: CompletionRequest): + path = request.url.path + is_legacy = "/generate" in path - def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''): + if request_data.stream: + async def generator(): + response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy) + for resp in response: + yield {"data": json.dumps(resp)} - error_resp = { - 'error': { - 'message': message, - 'code': code, - 'type': error_type, - 'param': param, - } - } - if internal_message: - print(error_type, message) - print(internal_message) - # error_resp['internal_message'] = internal_message + return EventSourceResponse(generator()) # SSE streaming - self.return_json(error_resp, code) + else: + response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy) + return JSONResponse(response) - def openai_error_handler(func): - def wrapper(self): - try: - func(self) - except InvalidRequestError as e: - self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message) - except OpenAIError as e: - self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message) - except Exception as e: - self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc()) - return wrapper +@app.post('/v1/chat/completions', response_model=ChatCompletionResponse) +async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest): + path = request.url.path + is_legacy = "/generate" in path - @openai_error_handler - def do_GET(self): - debug_msg(self.requestline) - debug_msg(self.headers) + if request_data.stream: + async def generator(): + response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy) + for resp in response: + yield {"data": json.dumps(resp)} - if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'): - is_legacy = 'engines' in self.path - is_list = self.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models'] - if is_legacy and not is_list: - model_name = self.path[self.path.find('/v1/engines/') + 
len('/v1/engines/'):] - resp = OAImodels.load_model(model_name) - elif is_list: - resp = OAImodels.list_models(is_legacy) - else: - model_name = self.path[len('/v1/models/'):] - resp = OAImodels.model_info(model_name) + return EventSourceResponse(generator()) # SSE streaming - self.return_json(resp) + else: + response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy) + return JSONResponse(response) - elif '/billing/usage' in self.path: - # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31 - self.return_json({"total_usage": 0}, no_debug=True) - else: - self.send_error(404) +@app.get("/v1/models") +@app.get("/v1/engines") +async def handle_models(request: Request): + path = request.url.path + is_legacy = 'engines' in path + is_list = request.url.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models'] - @openai_error_handler - def do_POST(self): + if is_legacy and not is_list: + model_name = path[path.find('/v1/engines/') + len('/v1/engines/'):] + resp = OAImodels.load_model(model_name) + elif is_list: + resp = OAImodels.list_models(is_legacy) + else: + model_name = path[len('/v1/models/'):] + resp = OAImodels.model_info(model_name) - if '/v1/audio/transcriptions' in self.path: - r = sr.Recognizer() + return JSONResponse(content=resp) - # Parse the form data - form = cgi.FieldStorage( - fp=self.rfile, - headers=self.headers, - environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']} - ) - - audio_file = form['file'].file - audio_data = AudioSegment.from_file(audio_file) - - # Convert AudioSegment to raw data - raw_data = audio_data.raw_data - - # Create AudioData object - audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width) - whipser_language = form.getvalue('language', None) - whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny - transcription = {"text": ""} - - try: - transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model) - except sr.UnknownValueError: - print("Whisper could not understand audio") - transcription["text"] = "Whisper could not understand audio UnknownValueError" - except sr.RequestError as e: - print("Could not request results from Whisper", e) - transcription["text"] = "Whisper could not understand audio RequestError" - - self.return_json(transcription, no_debug=True) - return - - debug_msg(self.requestline) - debug_msg(self.headers) +@app.get('/v1/billing/usage') +def handle_billing_usage(): + ''' + Ex. 
/v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31 + ''' + return JSONResponse(content={"total_usage": 0}) - content_length = self.headers.get('Content-Length') - transfer_encoding = self.headers.get('Transfer-Encoding') - if content_length: - body = json.loads(self.rfile.read(int(content_length)).decode('utf-8')) - elif transfer_encoding == 'chunked': - chunks = [] - while True: - chunk_size = int(self.rfile.readline(), 16) # Read the chunk size - if chunk_size == 0: - break # End of chunks - chunks.append(self.rfile.read(chunk_size)) - self.rfile.readline() # Consume the trailing newline after each chunk - body = json.loads(b''.join(chunks).decode('utf-8')) - else: - self.send_response(400, "Bad Request: Either Content-Length or Transfer-Encoding header expected.") - self.end_headers() - return +@app.post('/v1/audio/transcriptions') +async def handle_audio_transcription(request: Request): + r = sr.Recognizer() - debug_msg(body) + form = await request.form() + audio_file = await form["file"].read() + audio_data = AudioSegment.from_file(audio_file) - if '/completions' in self.path or '/generate' in self.path: + # Convert AudioSegment to raw data + raw_data = audio_data.raw_data - if not shared.model: - raise ServiceUnavailableError("No model loaded.") + # Create AudioData object + audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width) + whipser_language = form.getvalue('language', None) + whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny - is_legacy = '/generate' in self.path - is_streaming = body.get('stream', False) + transcription = {"text": ""} - if is_streaming: - self.start_sse() + try: + transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model) + except sr.UnknownValueError: + print("Whisper could not understand audio") + transcription["text"] = "Whisper could not understand audio UnknownValueError" + except sr.RequestError as e: + print("Could not request results from Whisper", e) + transcription["text"] = "Whisper could not understand audio RequestError" - response = [] - if 'chat' in self.path: - response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy) - else: - response = OAIcompletions.stream_completions(body, is_legacy=is_legacy) + return JSONResponse(content=transcription) - for resp in response: - self.send_sse(resp) - self.end_sse() +@app.post('/v1/images/generations') +async def handle_image_generation(request: Request): - else: - response = '' - if 'chat' in self.path: - response = OAIcompletions.chat_completions(body, is_legacy=is_legacy) - else: - response = OAIcompletions.completions(body, is_legacy=is_legacy) + if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')): + raise ServiceUnavailableError("Stable Diffusion not available. 
SD_WEBUI_URL not set.") - self.return_json(response) + body = await request.json() + prompt = body['prompt'] + size = body.get('size', '1024x1024') + response_format = body.get('response_format', 'url') # or b64_json + n = body.get('n', 1) # ignore the batch limits of max 10 - elif '/edits' in self.path: - # deprecated + response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n) + return JSONResponse(response) - if not shared.model: - raise ServiceUnavailableError("No model loaded.") - req_params = get_default_req_params() +@app.post("/v1/embeddings") +async def handle_embeddings(request: Request): + body = await request.json() + encoding_format = body.get("encoding_format", "") - instruction = body['instruction'] - input = body.get('input', '') - temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0 - top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0) + input = body.get('input', body.get('text', '')) + if not input: + raise HTTPException(status_code=400, detail="Missing required argument input") - response = OAIedits.edits(instruction, input, temperature, top_p) + if type(input) is str: + input = [input] - self.return_json(response) + response = OAIembeddings.embeddings(input, encoding_format) + return JSONResponse(response) - elif '/images/generations' in self.path: - if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')): - raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.") - prompt = body['prompt'] - size = default(body, 'size', '1024x1024') - response_format = default(body, 'response_format', 'url') # or b64_json - n = default(body, 'n', 1) # ignore the batch limits of max 10 +@app.post("/v1/moderations") +async def handle_moderations(request: Request): + body = await request.json() + input = body["input"] + if not input: + raise HTTPException(status_code=400, detail="Missing required argument input") - response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n) + response = OAImoderations.moderations(input) + return JSONResponse(response) - self.return_json(response, no_debug=True) - elif '/embeddings' in self.path: - encoding_format = body.get('encoding_format', '') +@app.post("/api/v1/token-count") +async def handle_token_count(request: Request): + body = await request.json() + response = token_count(body['prompt']) + return JSONResponse(response) - input = body.get('input', body.get('text', '')) - if not input: - raise InvalidRequestError("Missing required argument input", params='input') - if type(input) is str: - input = [input] +@app.post("/api/v1/token/encode") +async def handle_token_encode(request: Request): + body = await request.json() + encoding_format = body.get("encoding_format", "") + response = token_encode(body["input"], encoding_format) + return JSONResponse(response) - response = OAIembeddings.embeddings(input, encoding_format) - self.return_json(response, no_debug=True) - - elif '/moderations' in self.path: - input = body['input'] - if not input: - raise InvalidRequestError("Missing required argument input", params='input') - - response = OAImoderations.moderations(input) - - self.return_json(response, no_debug=True) - - elif self.path == '/api/v1/token-count': - # NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side. 
- response = token_count(body['prompt']) - - self.return_json(response, no_debug=True) - - elif self.path == '/api/v1/token/encode': - # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models - encoding_format = body.get('encoding_format', '') - - response = token_encode(body['input'], encoding_format) - - self.return_json(response, no_debug=True) - - elif self.path == '/api/v1/token/decode': - # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models - encoding_format = body.get('encoding_format', '') - - response = token_decode(body['input'], encoding_format) - - self.return_json(response, no_debug=True) - - else: - self.send_error(404) +@app.post("/api/v1/token/decode") +async def handle_token_decode(request: Request): + body = await request.json() + encoding_format = body.get("encoding_format", "") + response = token_decode(body["input"], encoding_format) + return JSONResponse(response, no_debug=True) def run_server(): - port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001))) - server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port) - server = ThreadingHTTPServer(server_addr, Handler) - - ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile) - ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile) - ssl_verify=True if (ssl_keyfile and ssl_certfile) else False - if ssl_verify: - context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - context.load_cert_chain(ssl_certfile, ssl_keyfile) - server.socket = context.wrap_socket(server.socket, server_side=True) - - if shared.args.share: - try: - from flask_cloudflared import _run_cloudflared - public_url = _run_cloudflared(port, port + 1) - print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1') - except ImportError: - print('You should install flask_cloudflared manually') + server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1' + port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port)) + + ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile) + ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile) + + if shared.args.public_api: + def on_start(public_url: str): + logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n') + + _start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start) else: - if ssl_verify: - print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1') + if ssl_keyfile and ssl_certfile: + logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n') else: - print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1') - - server.serve_forever() + logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n') + + if shared.args.api_key: + logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n') + + uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile) def setup(): diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py new file mode 100644 index 00000000..2f249879 --- /dev/null +++ b/extensions/openai/typing.py @@ -0,0 +1,125 @@ +import json +import time +from typing import List + +from pydantic import BaseModel, Field + + +class GenerationOptions(BaseModel): + preset: str | None = None + temperature: float = 1 + temperature_last: bool = False + top_p: float = 1 + min_p: float = 0 + top_k: int = 0 + repetition_penalty: 
float = 1 + presence_penalty: float = 0 + frequency_penalty: float = 0 + repetition_penalty_range: int = 0 + typical_p: float = 1 + tfs: float = 1 + top_a: float = 0 + epsilon_cutoff: float = 0 + eta_cutoff: float = 0 + guidance_scale: float = 1 + negative_prompt: str = '' + penalty_alpha: float = 0 + mirostat_mode: int = 0 + mirostat_tau: float = 5 + mirostat_eta: float = 0.1 + do_sample: bool = True + seed: int = -1 + encoder_repetition_penalty: float = 1 + no_repeat_ngram_size: int = 0 + min_length: int = 0 + num_beams: int = 1 + length_penalty: float = 1 + early_stopping: bool = False + truncation_length: int = 0 + max_tokens_second: int = 0 + custom_token_bans: str = "" + auto_max_new_tokens: bool = False + ban_eos_token: bool = False + add_bos_token: bool = True + skip_special_tokens: bool = True + grammar_string: str = "" + + +class CompletionRequest(GenerationOptions): + model: str | None = None + prompt: str | List[str] + best_of: int | None = 1 + echo: bool | None = False + frequency_penalty: float | None = 0 + logit_bias: dict | None = None + logprobs: int | None = None + max_tokens: int | None = 16 + n: int | None = 1 + presence_penalty: int | None = 0 + stop: str | List[str] | None = None + stream: bool | None = False + suffix: str | None = None + temperature: float | None = 1 + top_p: float | None = 1 + user: str | None = None + + +class CompletionResponse(BaseModel): + id: str + choices: List[dict] + created: int = int(time.time()) + model: str + object: str = "text_completion" + usage: dict + + +class ChatCompletionRequest(GenerationOptions): + messages: List[dict] + model: str | None = None + frequency_penalty: float | None = 0 + function_call: str | dict | None = None + functions: List[dict] | None = None + logit_bias: dict | None = None + max_tokens: int | None = None + n: int | None = 1 + presence_penalty: int | None = 0 + stop: str | List[str] | None = None + stream: bool | None = False + temperature: float | None = 1 + top_p: float | None = 1 + user: str | None = None + + mode: str = Field(default='instruct', description="Valid options: instruct, chat, chat-instruct.") + + instruction_template: str | None = Field(default=None, description="An instruction template defined under text-generation-webui/instruction-templates. If not set, the correct template will be guessed using the regex expressions in models/config.yaml.") + name1_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.") + name2_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.") + context_instruct: str | None = Field(default=None, description="Overwrites the value set by instruction_template.") + turn_template: str | None = Field(default=None, description="Overwrites the value set by instruction_template.") + + character: str | None = Field(default=None, description="A character defined under text-generation-webui/characters. 
If not set, the default \"Assistant\" character will be used.") + name1: str | None = Field(default=None, description="Overwrites the value set by character.") + name2: str | None = Field(default=None, description="Overwrites the value set by character.") + context: str | None = Field(default=None, description="Overwrites the value set by character.") + greeting: str | None = Field(default=None, description="Overwrites the value set by character.") + + chat_instruct_command: str | None = None + + continue_: bool = Field(default=False, description="Makes the last bot message in the history be continued instead of starting a new message.") + + +class ChatCompletionResponse(BaseModel): + id: str + choices: List[dict] + created: int = int(time.time()) + model: str + object: str = "chat.completion" + usage: dict + + +def to_json(obj): + return json.dumps(obj.__dict__, indent=4) + + +def to_dict(obj): + return obj.__dict__ diff --git a/extensions/openai/utils.py b/extensions/openai/utils.py index 49fc9510..2b414769 100644 --- a/extensions/openai/utils.py +++ b/extensions/openai/utils.py @@ -1,8 +1,12 @@ import base64 import os +import time +import traceback +from typing import Callable, Optional import numpy as np + def float_list_to_base64(float_array: np.ndarray) -> str: # Convert the list to a float32 array that the OpenAPI client expects # float_array = np.array(float_list, dtype="float32") @@ -18,13 +22,33 @@ def float_list_to_base64(float_array: np.ndarray) -> str: return ascii_string -def end_line(s): - if s and s[-1] != '\n': - s = s + '\n' - return s - - def debug_msg(*args, **kwargs): from extensions.openai.script import params if os.environ.get("OPENEDAI_DEBUG", params.get('debug', 0)): print(*args, **kwargs) + + +def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): + try: + from flask_cloudflared import _run_cloudflared + except ImportError: + print('You should install flask_cloudflared manually') + raise Exception( + 'flask_cloudflared not installed. 
Make sure you installed the requirements.txt for this extension.') + + for _ in range(max_attempts): + try: + if tunnel_id is not None: + public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id) + else: + public_url = _run_cloudflared(port, port + 1) + + if on_start: + on_start(public_url) + + return + except Exception: + traceback.print_exc() + time.sleep(3) + + raise Exception('Could not start cloudflared.') diff --git a/modules/chat.py b/modules/chat.py index 334693ab..82976479 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -81,7 +81,7 @@ def generate_chat_prompt(user_input, state, **kwargs): # Find the maximum prompt size max_length = get_max_prompt_length(state) all_substrings = { - 'chat': get_turn_substrings(state, instruct=False), + 'chat': get_turn_substrings(state, instruct=False) if state['mode'] in ['chat', 'chat-instruct'] else None, 'instruct': get_turn_substrings(state, instruct=True) } @@ -237,7 +237,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)): # Extract the reply - visible_reply = re.sub("(||{{user}})", state['name1'], reply) + visible_reply = reply + if state['mode'] in ['chat', 'chat-instruct']: + visible_reply = re.sub("(||{{user}})", state['name1'], reply) + visible_reply = html.escape(visible_reply) if shared.stop_everything: diff --git a/modules/models.py b/modules/models.py index e9005fee..d0392485 100644 --- a/modules/models.py +++ b/modules/models.py @@ -71,11 +71,12 @@ def load_model(model_name, loader=None): 'AutoAWQ': AutoAWQ_loader, } + metadata = get_model_metadata(model_name) if loader is None: if shared.args.loader is not None: loader = shared.args.loader else: - loader = get_model_metadata(model_name)['loader'] + loader = metadata['loader'] if loader is None: logger.error('The path to the model does not exist. 
Exiting.') return None, None @@ -95,6 +96,7 @@ def load_model(model_name, loader=None): if any((shared.args.xformers, shared.args.sdp_attention)): llama_attn_hijack.hijack_llama_attention() + shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings}) logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.") return model, tokenizer diff --git a/modules/presets.py b/modules/presets.py index 62c2c90c..5082678b 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -6,33 +6,32 @@ import yaml def default_preset(): return { - 'do_sample': True, 'temperature': 1, 'temperature_last': False, 'top_p': 1, 'min_p': 0, 'top_k': 0, - 'typical_p': 1, - 'epsilon_cutoff': 0, - 'eta_cutoff': 0, - 'tfs': 1, - 'top_a': 0, 'repetition_penalty': 1, 'presence_penalty': 0, 'frequency_penalty': 0, 'repetition_penalty_range': 0, + 'typical_p': 1, + 'tfs': 1, + 'top_a': 0, + 'epsilon_cutoff': 0, + 'eta_cutoff': 0, + 'guidance_scale': 1, + 'penalty_alpha': 0, + 'mirostat_mode': 0, + 'mirostat_tau': 5, + 'mirostat_eta': 0.1, + 'do_sample': True, 'encoder_repetition_penalty': 1, 'no_repeat_ngram_size': 0, 'min_length': 0, - 'guidance_scale': 1, - 'mirostat_mode': 0, - 'mirostat_tau': 5.0, - 'mirostat_eta': 0.1, - 'penalty_alpha': 0, 'num_beams': 1, 'length_penalty': 1, 'early_stopping': False, - 'custom_token_bans': '', } diff --git a/modules/shared.py b/modules/shared.py index 1dd6841d..c9cd385b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -39,21 +39,21 @@ settings = { 'max_new_tokens': 200, 'max_new_tokens_min': 1, 'max_new_tokens_max': 4096, - 'seed': -1, 'negative_prompt': '', + 'seed': -1, 'truncation_length': 2048, 'truncation_length_min': 0, 'truncation_length_max': 32768, - 'custom_stopping_strings': '', - 'auto_max_new_tokens': False, 'max_tokens_second': 0, - 'ban_eos_token': False, + 'custom_stopping_strings': '', 'custom_token_bans': '', + 'auto_max_new_tokens': False, + 'ban_eos_token': False, 'add_bos_token': True, 'skip_special_tokens': True, 'stream': True, - 'name1': 'You', 'character': 'Assistant', + 'name1': 'You', 'instruction_template': 'Alpaca', 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>', 'autoload_model': False, @@ -167,8 +167,8 @@ parser.add_argument('--ssl-certfile', type=str, help='The path to the SSL certif parser.add_argument('--api', action='store_true', help='Enable the API extension.') parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.') parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None) -parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.') -parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.') +parser.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.') +parser.add_argument('--api-key', type=str, default='', help='API authentication key.') # Multimodal parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. 
Examples: llava-7b, llava-13b.') @@ -178,6 +178,8 @@ parser.add_argument('--notebook', action='store_true', help='DEPRECATED') parser.add_argument('--chat', action='store_true', help='DEPRECATED') parser.add_argument('--no-stream', action='store_true', help='DEPRECATED') parser.add_argument('--mul_mat_q', action='store_true', help='DEPRECATED') +parser.add_argument('--api-blocking-port', type=int, default=5000, help='DEPRECATED') +parser.add_argument('--api-streaming-port', type=int, default=5005, help='DEPRECATED') args = parser.parse_args() args_defaults = parser.parse_args([]) @@ -233,10 +235,13 @@ def fix_loader_name(name): return 'AutoAWQ' -def add_extension(name): +def add_extension(name, last=False): if args.extensions is None: args.extensions = [name] - elif 'api' not in args.extensions: + elif last: + args.extensions = [x for x in args.extensions if x != name] + args.extensions.append(name) + elif name not in args.extensions: args.extensions.append(name) @@ -246,14 +251,15 @@ def is_chat(): args.loader = fix_loader_name(args.loader) -# Activate the API extension -if args.api or args.public_api: - add_extension('api') - # Activate the multimodal extension if args.multimodal_pipeline is not None: add_extension('multimodal') +# Activate the API extension +if args.api: + # add_extension('openai', last=True) + add_extension('api', last=True) + # Load model-specific settings with Path(f'{args.model_dir}/config.yaml') as p: if p.exists(): diff --git a/modules/text_generation.py b/modules/text_generation.py index e2efa41d..310525d2 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -56,7 +56,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap # Find the stopping strings all_stop_strings = [] - for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")): + for st in (stopping_strings, state['custom_stopping_strings']): + if type(st) is str: + st = ast.literal_eval(f"[{st}]") + if type(st) is list and len(st) > 0: all_stop_strings += st diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 0d82ee8f..588386ac 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -215,9 +215,6 @@ def load_model_wrapper(selected_model, loader, autoload=False): if 'instruction_template' in settings: output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template']) - # Applying the changes to the global shared settings (in-memory) - shared.settings.update({k: v for k, v in settings.items() if k in shared.settings}) - yield output else: yield f"Failed to load `{selected_model}`."
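
For reference, here is a minimal client sketch for the OpenAI-compatible endpoints added by this patch, along the lines of the api-examples scripts it removes. It is illustrative rather than part of the change: it assumes the server was started with --api (default --api-port 5000, overridable via OPENEDAI_PORT) and optionally --api-key, and that the requests package is available on the client side.

import json

import requests

HOST = 'http://127.0.0.1:5000'
API_KEY = ''  # set this if the server was launched with --api-key

headers = {'Content-Type': 'application/json'}
if API_KEY:
    # Matches the check in verify_api_key() above
    headers['Authorization'] = f'Bearer {API_KEY}'

# Non-streaming chat completion; the request fields follow ChatCompletionRequest
payload = {
    'messages': [{'role': 'user', 'content': 'How do I plant a tree in my backyard?'}],
    'mode': 'instruct',
    'max_tokens': 250,
}
r = requests.post(f'{HOST}/v1/chat/completions', headers=headers, json=payload)
print(r.json()['choices'][0])

# Streaming text completion: the server replies with Server-Sent Events,
# one 'data: {...}' line per chunk (see text_streaming_chunk above)
payload = {'prompt': 'Once upon a time', 'max_tokens': 200, 'stream': True}
with requests.post(f'{HOST}/v1/completions', headers=headers, json=payload, stream=True) as r:
    for line in r.iter_lines():
        if not line or not line.startswith(b'data: '):
            continue
        data = line[len(b'data: '):]
        if data == b'[DONE]':  # defensive; the generator above simply closes the stream when done
            break
        chunk = json.loads(data)
        print(chunk['choices'][0]['text'], end='', flush=True)

print()

Since the paths mirror the OpenAI API, an off-the-shelf OpenAI client pointed at http://127.0.0.1:5000/v1 should work as well.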