diff --git a/.gitignore b/.gitignore index 80244a30..3cfbbb22 100644 --- a/.gitignore +++ b/.gitignore @@ -4,12 +4,15 @@ extensions/silero_tts/outputs/* extensions/elevenlabs_tts/outputs/* extensions/sd_api_pictures/outputs/* logs/* +loras/* models/* softprompts/* torch-dumps/* *pycache* */*pycache* */*/pycache* +venv/ +.venv/ settings.json img_bot* @@ -17,6 +20,7 @@ img_me* !characters/Example.json !characters/Example.png +!loras/place-your-loras-here.txt !models/place-your-models-here.txt !softprompts/place-your-softprompts-here.txt !torch-dumps/place-your-pt-models-here.txt diff --git a/README.md b/README.md index c9834558..ded9b351 100644 --- a/README.md +++ b/README.md @@ -19,52 +19,76 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. * Generate Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX support. * Support for [Pygmalion](https://huggingface.co/models?search=pygmalionai/pygmalion) and custom characters in JSON or TavernAI Character Card formats ([FAQ](https://github.com/oobabooga/text-generation-webui/wiki/Pygmalion-chat-model-FAQ)). * Advanced chat features (send images, get audio responses with TTS). -* Stream the text output in real time. +* Stream the text output in real time very efficiently. * Load parameter presets from text files. -* Load large models in 8-bit mode (see [here](https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134), [here](https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652) and [here](https://www.reddit.com/r/PygmalionAI/comments/1115gom/running_pygmalion_6b_with_8gb_of_vram/) if you are on Windows). +* Load large models in 8-bit mode. * Split large models across your GPU(s), CPU, and disk. * CPU mode. * [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen). * [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed). * Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming. -* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model). -* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model). +* [LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model). +* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model). +* [Supports LoRAs](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs). * Supports softprompts. * [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions). * [Works on Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab). -## Installation option 1: conda +## Installation -Open a terminal and copy and paste these commands one at a time ([install conda](https://docs.conda.io/en/latest/miniconda.html) first if you don't have it already): +The recommended installation methods are the following: + +* Linux and MacOS: using conda natively. +* Windows: using conda on WSL ([WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-Subsystem-for-Linux-(Ubuntu)-Installation-Guide)). + +Conda can be downloaded here: https://docs.conda.io/en/latest/miniconda.html + +On Linux or WSL, it can be automatically installed with these two commands: ``` -conda create -n textgen +curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh" +bash Miniconda3.sh +``` + +Source: https://educe-ubc.github.io/conda.html + +#### 1. Create a new conda environment + +``` +conda create -n textgen python=3.10.9 conda activate textgen -conda install torchvision torchaudio pytorch-cuda=11.7 git -c pytorch -c nvidia +``` + +#### 2. Install Pytorch + +| System | GPU | Command | +|--------|---------|---------| +| Linux/WSL | NVIDIA | `pip3 install torch torchvision torchaudio` | +| Linux | AMD | `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2` | +| MacOS + MPS (untested) | Any | `pip3 install torch torchvision torchaudio` | + +The up to date commands can be found here: https://pytorch.org/get-started/locally/. + +MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393 + + +#### 3. Install the web UI + +``` git clone https://github.com/oobabooga/text-generation-webui cd text-generation-webui pip install -r requirements.txt ``` -The third line assumes that you have an NVIDIA GPU. - -* If you have an AMD GPU, replace the third command with this one: - -``` -pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2 -``` - -* If you are running it in CPU mode, replace the third command with this one: - -``` -conda install pytorch torchvision torchaudio git -c pytorch -``` - > **Note** -> 1. If you are on Windows, it may be easier to run the commands above in a WSL environment. The performance may also be better. -> 2. For a more detailed, user-contributed guide, see: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings). +> +> For bitsandbytes and `--load-in-8bit` to work on Linux/WSL, this dirty fix is currently necessary: https://github.com/oobabooga/text-generation-webui/issues/400#issuecomment-1474876859 -## Installation option 2: one-click installers +### Alternative: native Windows installation + +As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings). + +### Alternative: one-click installers [oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip) @@ -75,19 +99,25 @@ Just download the zip above, extract it, and double click on "install". The web * To download a model, double click on "download-model" * To start the web UI, double click on "start-webui" +Source codes: https://github.com/oobabooga/one-click-installers + +This method lags behind the newest developments and does not support 8-bit mode on Windows without additional set up: https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134, https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652 + +### Alternative: Docker + +https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87 + ## Downloading models -Models should be placed under `models/model-name`. For instance, `models/gpt-j-6B` for [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main). - -#### Hugging Face +Models should be placed inside the `models` folder. [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads) is the main place to download models. These are some noteworthy examples: -* [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main) -* [GPT-Neo](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=eleutherai+%2F+gpt-neo) * [Pythia](https://huggingface.co/models?search=eleutherai/pythia) * [OPT](https://huggingface.co/models?search=facebook/opt) * [GALACTICA](https://huggingface.co/models?search=facebook/galactica) +* [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main) +* [GPT-Neo](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=eleutherai+%2F+gpt-neo) * [\*-Erebus](https://huggingface.co/models?search=erebus) (NSFW) * [Pygmalion](https://huggingface.co/models?search=pygmalion) (NSFW) @@ -101,7 +131,7 @@ For instance: If you want to download a model manually, note that all you need are the json, txt, and pytorch\*.bin (or model*.safetensors) files. The remaining files are not necessary. -#### GPT-4chan +### GPT-4chan [GPT-4chan](https://huggingface.co/ykilcher/gpt-4chan) has been shut down from Hugging Face, so you need to download it elsewhere. You have two options: @@ -123,6 +153,7 @@ python download-model.py EleutherAI/gpt-j-6B --text-only ## Starting the web UI conda activate textgen + cd text-generation-webui python server.py Then browse to @@ -133,41 +164,42 @@ Then browse to Optionally, you can use the following command-line flags: -| Flag | Description | -|-------------|-------------| -| `-h`, `--help` | show this help message and exit | -| `--model MODEL` | Name of the model to load by default. | -| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | -| `--chat` | Launch the web UI in chat mode.| -| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. | -| `--cpu` | Use the CPU to generate text.| -| `--load-in-8bit` | Load the model with 8-bit precision.| -| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. | -| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. | -| `--gptq-model-type MODEL_TYPE` | Model type of pre-quantized model. Currently only LLaMa and OPT are supported. | -| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | +| Flag | Description | +|------------------|-------------| +| `-h`, `--help` | show this help message and exit | +| `--model MODEL` | Name of the model to load by default. | +| `--lora LORA` | Name of the LoRA to apply to the model by default. | +| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | +| `--chat` | Launch the web UI in chat mode.| +| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. | +| `--cpu` | Use the CPU to generate text.| +| `--load-in-8bit` | Load the model with 8-bit precision.| +| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. | +| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. | +| `--gptq-model-type MODEL_TYPE` | Model type of pre-quantized model. Currently only LLaMa and OPT are supported. | +| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. | | `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.| -| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. | +| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. | | `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. | | `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. | -| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.| -| `--flexgen` | Enable the use of FlexGen offloading. | -| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). | -| `--compress-weight` | FlexGen: Whether to compress weight (default: False).| -| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). | +| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.| +| `--flexgen` | Enable the use of FlexGen offloading. | +| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). | +| `--compress-weight` | FlexGen: Whether to compress weight (default: False).| +| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). | | `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. | -| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. | -| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. | -| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". | -| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | -| `--no-stream` | Don't stream the text output in real time. | +| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. | +| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. | +| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". | +| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | +| `--no-stream` | Don't stream the text output in real time. | | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.| | `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | -| `--listen` | Make the web UI reachable from your local network.| +| `--listen` | Make the web UI reachable from your local network.| | `--listen-port LISTEN_PORT` | The listening port that the server will use. | -| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. | -| `--auto-launch` | Open the web UI in the default browser upon launch. | -| `--verbose` | Print the prompts to the terminal. | +| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. | +| `--auto-launch` | Open the web UI in the default browser upon launch. | +| `--verbose` | Print the prompts to the terminal. | Out of memory errors? [Check this guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide). @@ -192,7 +224,7 @@ Before reporting a bug, make sure that you have: ## Credits -- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui +- Gradio dropdown menu refresh button, code for reloading the interface: https://github.com/AUTOMATIC1111/stable-diffusion-webui - Verbose preset: Anonymous 4chan user. - NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets - Pygmalion preset, code for early stopping in chat mode, code for some of the sliders, --chat mode colors: https://github.com/PygmalionAI/gradio-ui/ diff --git a/api-example-stream.py b/api-example-stream.py index a5ed4202..055d605b 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -26,6 +26,7 @@ async def run(context): 'top_p': 0.9, 'typical_p': 1, 'repetition_penalty': 1.05, + 'encoder_repetition_penalty': 1.0, 'top_k': 0, 'min_length': 0, 'no_repeat_ngram_size': 0, @@ -43,14 +44,14 @@ async def run(context): case "send_hash": await websocket.send(json.dumps({ "session_hash": session, - "fn_index": 7 + "fn_index": 9 })) case "estimation": pass case "send_data": await websocket.send(json.dumps({ "session_hash": session, - "fn_index": 7, + "fn_index": 9, "data": [ context, params['max_new_tokens'], @@ -59,6 +60,7 @@ async def run(context): params['top_p'], params['typical_p'], params['repetition_penalty'], + params['encoder_repetition_penalty'], params['top_k'], params['min_length'], params['no_repeat_ngram_size'], diff --git a/api-example.py b/api-example.py index 0306b7ab..a6f0c10e 100644 --- a/api-example.py +++ b/api-example.py @@ -24,6 +24,7 @@ params = { 'top_p': 0.9, 'typical_p': 1, 'repetition_penalty': 1.05, + 'encoder_repetition_penalty': 1.0, 'top_k': 0, 'min_length': 0, 'no_repeat_ngram_size': 0, @@ -45,6 +46,7 @@ response = requests.post(f"http://{server}:7860/run/textgen", json={ params['top_p'], params['typical_p'], params['repetition_penalty'], + params['encoder_repetition_penalty'], params['top_k'], params['min_length'], params['no_repeat_ngram_size'], diff --git a/css/chat.css b/css/chat.css new file mode 100644 index 00000000..8d9d88a6 --- /dev/null +++ b/css/chat.css @@ -0,0 +1,25 @@ +.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx { + height: 66.67vh +} + +.gradio-container { + margin-left: auto !important; + margin-right: auto !important; +} + +.w-screen { + width: unset +} + +div.svelte-362y77>*, div.svelte-362y77>.form>* { + flex-wrap: nowrap +} + +/* fixes the API documentation in chat mode */ +.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h { + display: grid; +} + +.pending.svelte-1ed2p3z { + opacity: 1; +} diff --git a/css/chat.js b/css/chat.js new file mode 100644 index 00000000..e304f125 --- /dev/null +++ b/css/chat.js @@ -0,0 +1,4 @@ +document.getElementById("main").childNodes[0].style = "max-width: 800px; margin-left: auto; margin-right: auto"; +document.getElementById("extensions").style.setProperty("max-width", "800px"); +document.getElementById("extensions").style.setProperty("margin-left", "auto"); +document.getElementById("extensions").style.setProperty("margin-right", "auto"); diff --git a/css/html_4chan_style.css b/css/html_4chan_style.css new file mode 100644 index 00000000..843e8a97 --- /dev/null +++ b/css/html_4chan_style.css @@ -0,0 +1,103 @@ +#parent #container { + background-color: #eef2ff; + padding: 17px; +} +#parent #container .reply { + background-color: rgb(214, 218, 240); + border-bottom-color: rgb(183, 197, 217); + border-bottom-style: solid; + border-bottom-width: 1px; + border-image-outset: 0; + border-image-repeat: stretch; + border-image-slice: 100%; + border-image-source: none; + border-image-width: 1; + border-left-color: rgb(0, 0, 0); + border-left-style: none; + border-left-width: 0px; + border-right-color: rgb(183, 197, 217); + border-right-style: solid; + border-right-width: 1px; + border-top-color: rgb(0, 0, 0); + border-top-style: none; + border-top-width: 0px; + color: rgb(0, 0, 0); + display: table; + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; + margin-bottom: 4px; + margin-left: 0px; + margin-right: 0px; + margin-top: 4px; + overflow-x: hidden; + overflow-y: hidden; + padding-bottom: 4px; + padding-left: 2px; + padding-right: 2px; + padding-top: 4px; +} + +#parent #container .number { + color: rgb(0, 0, 0); + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; + width: 342.65px; + margin-right: 7px; +} + +#parent #container .op { + color: rgb(0, 0, 0); + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; + margin-bottom: 8px; + margin-left: 0px; + margin-right: 0px; + margin-top: 4px; + overflow-x: hidden; + overflow-y: hidden; +} + +#parent #container .op blockquote { + margin-left: 0px !important; +} + +#parent #container .name { + color: rgb(17, 119, 67); + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; + font-weight: 700; + margin-left: 7px; +} + +#parent #container .quote { + color: rgb(221, 0, 0); + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; + text-decoration-color: rgb(221, 0, 0); + text-decoration-line: underline; + text-decoration-style: solid; + text-decoration-thickness: auto; +} + +#parent #container .greentext { + color: rgb(120, 153, 34); + font-family: arial, helvetica, sans-serif; + font-size: 13.3333px; +} + +#parent #container blockquote { + margin: 0px !important; + margin-block-start: 1em; + margin-block-end: 1em; + margin-inline-start: 40px; + margin-inline-end: 40px; + margin-top: 13.33px !important; + margin-bottom: 13.33px !important; + margin-left: 40px !important; + margin-right: 40px !important; +} + +#parent #container .message { + color: black; + border: none; +} \ No newline at end of file diff --git a/css/html_cai_style.css b/css/html_cai_style.css new file mode 100644 index 00000000..3190b3d1 --- /dev/null +++ b/css/html_cai_style.css @@ -0,0 +1,73 @@ +.chat { + margin-left: auto; + margin-right: auto; + max-width: 800px; + height: 66.67vh; + overflow-y: auto; + padding-right: 20px; + display: flex; + flex-direction: column-reverse; +} + +.message { + display: grid; + grid-template-columns: 60px 1fr; + padding-bottom: 25px; + font-size: 15px; + font-family: Helvetica, Arial, sans-serif; + line-height: 1.428571429; +} + +.circle-you { + width: 50px; + height: 50px; + background-color: rgb(238, 78, 59); + border-radius: 50%; +} + +.circle-bot { + width: 50px; + height: 50px; + background-color: rgb(59, 78, 244); + border-radius: 50%; +} + +.circle-bot img, +.circle-you img { + border-radius: 50%; + width: 100%; + height: 100%; + object-fit: cover; +} + +.text {} + +.text p { + margin-top: 5px; +} + +.username { + font-weight: bold; +} + +.message-body {} + +.message-body img { + max-width: 300px; + max-height: 300px; + border-radius: 20px; +} + +.message-body p { + margin-bottom: 0 !important; + font-size: 15px !important; + line-height: 1.428571429 !important; +} + +.dark .message-body p em { + color: rgb(138, 138, 138) !important; +} + +.message-body p em { + color: rgb(110, 110, 110) !important; +} \ No newline at end of file diff --git a/css/html_readable_style.css b/css/html_readable_style.css new file mode 100644 index 00000000..d3f580a5 --- /dev/null +++ b/css/html_readable_style.css @@ -0,0 +1,14 @@ +.container { + max-width: 600px; + margin-left: auto; + margin-right: auto; + background-color: rgb(31, 41, 55); + padding:3em; +} + +.container p { + font-size: 16px !important; + color: white !important; + margin-bottom: 22px; + line-height: 1.4 !important; +} diff --git a/css/main.css b/css/main.css new file mode 100644 index 00000000..c6b0b07e --- /dev/null +++ b/css/main.css @@ -0,0 +1,52 @@ +.tabs.svelte-710i53 { + margin-top: 0 +} + +.py-6 { + padding-top: 2.5rem +} + +.dark #refresh-button { + background-color: #ffffff1f; +} + +#refresh-button { + flex: none; + margin: 0; + padding: 0; + min-width: 50px; + border: none; + box-shadow: none; + border-radius: 10px; + background-color: #0000000d; +} + +#download-label, #upload-label { + min-height: 0 +} + +#accordion { +} + +.dark svg { + fill: white; +} + +.dark a { + color: white !important; + text-decoration: none !important; +} + +svg { + display: unset !important; + vertical-align: middle !important; + margin: 5px; +} + +ol li p, ul li p { + display: inline-block; +} + +#main, #parameters, #chat-settings, #interface-mode, #lora { + border: 0; +} diff --git a/css/main.js b/css/main.js new file mode 100644 index 00000000..9db3fe8b --- /dev/null +++ b/css/main.js @@ -0,0 +1,18 @@ +document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px"; +document.getElementById("main").parentNode.style = "padding: 0; margin: 0"; +document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0"; + +// Get references to the elements +let main = document.getElementById('main'); +let main_parent = main.parentNode; +let extensions = document.getElementById('extensions'); + +// Add an event listener to the main element +main_parent.addEventListener('click', function(e) { + // Check if the main element is visible + if (main.offsetHeight > 0 && main.offsetWidth > 0) { + extensions.style.display = 'block'; + } else { + extensions.style.display = 'none'; + } +}); diff --git a/download-model.py b/download-model.py index 8be398c4..808b9fc2 100644 --- a/download-model.py +++ b/download-model.py @@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch): classifications = [] has_pytorch = False has_safetensors = False + is_lora = False while True: content = requests.get(f"{base}{page}{cursor.decode()}").content @@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch): for i in range(len(dict)): fname = dict[i]['path'] + if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')): + is_lora = True - is_pytorch = re.match("pytorch_model.*\.bin", fname) + is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname) is_safetensors = re.match("model.*\.safetensors", fname) is_tokenizer = re.match("tokenizer.*\.model", fname) is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer @@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch): has_pytorch = True classifications.append('pytorch') + cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' cursor = base64.b64encode(cursor) cursor = cursor.replace(b'=', b'%3D') @@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch): if classifications[i] == 'pytorch': links.pop(i) - return links + return links, is_lora if __name__ == '__main__': model = args.MODEL @@ -159,15 +163,16 @@ if __name__ == '__main__': except ValueError as err_branch: print(f"Error: {err_branch}") sys.exit() + + links, is_lora = get_download_links_from_huggingface(model, branch) + base_folder = 'models' if not is_lora else 'loras' if branch != 'main': - output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}') + output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}') else: - output_folder = Path("models") / model.split('/')[-1] + output_folder = Path(base_folder) / model.split('/')[-1] if not output_folder.exists(): output_folder.mkdir() - links = get_download_links_from_huggingface(model, branch) - # Downloading the files print(f"Downloading the model to {output_folder}") pool = multiprocessing.Pool(processes=args.threads) diff --git a/extensions/api/requirements.txt b/extensions/api/requirements.txt new file mode 100644 index 00000000..ad788ab8 --- /dev/null +++ b/extensions/api/requirements.txt @@ -0,0 +1 @@ +flask_cloudflared==0.0.12 \ No newline at end of file diff --git a/extensions/api/script.py b/extensions/api/script.py new file mode 100644 index 00000000..53e47f3f --- /dev/null +++ b/extensions/api/script.py @@ -0,0 +1,90 @@ +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from threading import Thread +from modules import shared +from modules.text_generation import generate_reply, encode +import json + +params = { + 'port': 5000, +} + +class Handler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == '/api/v1/model': + self.send_response(200) + self.end_headers() + response = json.dumps({ + 'result': shared.model_name + }) + + self.wfile.write(response.encode('utf-8')) + else: + self.send_error(404) + + def do_POST(self): + content_length = int(self.headers['Content-Length']) + body = json.loads(self.rfile.read(content_length).decode('utf-8')) + + if self.path == '/api/v1/generate': + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + + prompt = body['prompt'] + prompt_lines = [l.strip() for l in prompt.split('\n')] + + max_context = body.get('max_context_length', 2048) + + while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context: + prompt_lines.pop(0) + + prompt = '\n'.join(prompt_lines) + + generator = generate_reply( + question = prompt, + max_new_tokens = body.get('max_length', 200), + do_sample=True, + temperature=body.get('temperature', 0.5), + top_p=body.get('top_p', 1), + typical_p=body.get('typical', 1), + repetition_penalty=body.get('rep_pen', 1.1), + encoder_repetition_penalty=1, + top_k=body.get('top_k', 0), + min_length=0, + no_repeat_ngram_size=0, + num_beams=1, + penalty_alpha=0, + length_penalty=1, + early_stopping=False, + ) + + answer = '' + for a in generator: + answer = a[0] + + response = json.dumps({ + 'results': [{ + 'text': answer[len(prompt):] + }] + }) + self.wfile.write(response.encode('utf-8')) + else: + self.send_error(404) + + +def run_server(): + server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port']) + server = ThreadingHTTPServer(server_addr, Handler) + if shared.args.share: + try: + from flask_cloudflared import _run_cloudflared + public_url = _run_cloudflared(params['port'], params['port'] + 1) + print(f'Starting KoboldAI compatible api at {public_url}/api') + except ImportError: + print('You should install flask_cloudflared manually') + else: + print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api') + server.serve_forever() + +def ui(): + Thread(target=run_server, daemon=True).start() \ No newline at end of file diff --git a/extensions/elevenlabs_tts/script.py b/extensions/elevenlabs_tts/script.py index 90d61efc..b8171063 100644 --- a/extensions/elevenlabs_tts/script.py +++ b/extensions/elevenlabs_tts/script.py @@ -1,8 +1,8 @@ from pathlib import Path import gradio as gr -from elevenlabslib import * -from elevenlabslib.helpers import * +from elevenlabslib import ElevenLabsUser +from elevenlabslib.helpers import save_bytes_to_path params = { 'activate': True, diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py index 8a2d7cf9..fbf23bc9 100644 --- a/extensions/gallery/script.py +++ b/extensions/gallery/script.py @@ -76,7 +76,7 @@ def generate_html(): return container_html def ui(): - with gr.Accordion("Character gallery"): + with gr.Accordion("Character gallery", open=False): update = gr.Button("Refresh") gallery = gr.HTML(value=generate_html()) update.click(generate_html, [], gallery) diff --git a/extensions/whisper_stt/requirements.txt b/extensions/whisper_stt/requirements.txt new file mode 100644 index 00000000..770c38bb --- /dev/null +++ b/extensions/whisper_stt/requirements.txt @@ -0,0 +1,4 @@ +git+https://github.com/Uberi/speech_recognition.git@010382b +openai-whisper +soundfile +ffmpeg diff --git a/extensions/whisper_stt/script.py b/extensions/whisper_stt/script.py new file mode 100644 index 00000000..6ef60c57 --- /dev/null +++ b/extensions/whisper_stt/script.py @@ -0,0 +1,54 @@ +import gradio as gr +import speech_recognition as sr + +input_hijack = { + 'state': False, + 'value': ["", ""] +} + + +def do_stt(audio, text_state=""): + transcription = "" + r = sr.Recognizer() + + # Convert to AudioData + audio_data = sr.AudioData(sample_rate=audio[0], frame_data=audio[1], sample_width=4) + + try: + transcription = r.recognize_whisper(audio_data, language="english", model="base.en") + except sr.UnknownValueError: + print("Whisper could not understand audio") + except sr.RequestError as e: + print("Could not request results from Whisper", e) + + input_hijack.update({"state": True, "value": [transcription, transcription]}) + + text_state += transcription + " " + return text_state, text_state + + +def update_hijack(val): + input_hijack.update({"state": True, "value": [val, val]}) + return val + + +def auto_transcribe(audio, audio_auto, text_state=""): + if audio is None: + return "", "" + if audio_auto: + return do_stt(audio, text_state) + return "", "" + + +def ui(): + tr_state = gr.State(value="") + output_transcription = gr.Textbox(label="STT-Input", + placeholder="Speech Preview. Click \"Generate\" to send", + interactive=True) + output_transcription.change(fn=update_hijack, inputs=[output_transcription], outputs=[tr_state]) + audio_auto = gr.Checkbox(label="Auto-Transcribe", value=True) + with gr.Row(): + audio = gr.Audio(source="microphone") + audio.change(fn=auto_transcribe, inputs=[audio, audio_auto, tr_state], outputs=[output_transcription, tr_state]) + transcribe_button = gr.Button(value="Transcribe") + transcribe_button.click(do_stt, inputs=[audio, tr_state], outputs=[output_transcription, tr_state]) diff --git a/loras/place-your-loras-here.txt b/loras/place-your-loras-here.txt new file mode 100644 index 00000000..e69de29b diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index c2723490..662182e7 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -61,7 +61,7 @@ def load_quantized(model_name): max_memory[i] = f"{shared.args.gpu_memory[i]}GiB" max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB" - device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"]) + device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"]) model = accelerate.dispatch_model(model, device_map=device_map) # Single GPU diff --git a/modules/LoRA.py b/modules/LoRA.py new file mode 100644 index 00000000..6915e157 --- /dev/null +++ b/modules/LoRA.py @@ -0,0 +1,22 @@ +from pathlib import Path + +import modules.shared as shared +from modules.models import load_model + + +def add_lora_to_model(lora_name): + + from peft import PeftModel + + # Is there a more efficient way of returning to the base model? + if lora_name == "None": + print("Reloading the model to remove the LoRA...") + shared.model, shared.tokenizer = load_model(shared.model_name) + else: + # Why doesn't this work in 16-bit mode? + print(f"Adding the LoRA {lora_name} to the model...") + + params = {} + params['device_map'] = {'': 0} + #params['dtype'] = shared.model.dtype + shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params) diff --git a/modules/callbacks.py b/modules/callbacks.py index faa4a5e9..12a90cc3 100644 --- a/modules/callbacks.py +++ b/modules/callbacks.py @@ -7,6 +7,7 @@ import transformers import modules.shared as shared + # Copied from https://github.com/PygmalionAI/gradio-ui/ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): diff --git a/modules/chat.py b/modules/chat.py index bd45b879..36265990 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -11,17 +11,11 @@ from PIL import Image import modules.extensions as extensions_module import modules.shared as shared from modules.extensions import apply_extensions -from modules.html_generator import generate_chat_html -from modules.text_generation import encode, generate_reply, get_max_prompt_length +from modules.html_generator import fix_newlines, generate_chat_html +from modules.text_generation import (encode, generate_reply, + get_max_prompt_length) -# This gets the new line characters right. -def clean_chat_message(text): - text = text.replace('\n', '\n\n') - text = re.sub(r"\n{3,}", "\n\n", text) - text = text.strip() - return text - def generate_chat_output(history, name1, name2, character): if shared.args.cai_chat: return generate_chat_html(history, name1, name2, character) @@ -29,7 +23,7 @@ def generate_chat_output(history, name1, name2, character): return history def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False): - user_input = clean_chat_message(user_input) + user_input = fix_newlines(user_input) rows = [f"{context.strip()}\n"] if shared.soft_prompt: @@ -82,7 +76,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate if idx != -1: reply = reply[:idx] next_character_found = True - reply = clean_chat_message(reply) + reply = fix_newlines(reply) # If something like "\nYo" is generated just before "\nYou:" # is completed, trim it @@ -97,7 +91,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate def stop_everything_event(): shared.stop_everything = True -def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False): +def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False): shared.stop_everything = False just_started = True eos_token = '\n' if check else None @@ -133,7 +127,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical # Generate reply = '' for i in range(chat_generation_attempts): - for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): + for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): # Extracting the reply reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check) @@ -160,7 +154,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical yield shared.history['visible'] -def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): +def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): eos_token = '\n' if check else None if 'pygmalion' in shared.model_name.lower(): @@ -172,18 +166,18 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ # Yield *Is typing...* yield shared.processing_message for i in range(chat_generation_attempts): - for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): + for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True) yield reply if next_character_found: break yield reply -def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): - for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): +def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): + for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): yield generate_chat_html(_history, name1, name2, shared.character) -def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): +def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) else: @@ -191,7 +185,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi last_internal = shared.history['internal'].pop() # Yield '*Is typing...*' yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character) - for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True): + for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] else: diff --git a/modules/extensions.py b/modules/extensions.py index c8de8a7b..836fbc60 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,3 +1,5 @@ +import gradio as gr + import extensions import modules.shared as shared @@ -9,9 +11,12 @@ def load_extensions(): for i, name in enumerate(shared.args.extensions): if name in available_extensions: print(f'Loading the extension "{name}"... ', end='') - exec(f"import extensions.{name}.script") - state[name] = [True, i] - print('Ok.') + try: + exec(f"import extensions.{name}.script") + state[name] = [True, i] + print('Ok.') + except: + print('Fail.') # This iterator returns the extensions in the order specified in the command-line def iterator(): @@ -40,6 +45,9 @@ def create_extensions_block(): extension.params[param] = shared.settings[_id] # Creating the extension ui elements - for extension, name in iterator(): - if hasattr(extension, "ui"): - extension.ui() + if len(state) > 0: + with gr.Box(elem_id="extensions"): + gr.Markdown("Extensions") + for extension, name in iterator(): + if hasattr(extension, "ui"): + extension.ui() diff --git a/modules/html_generator.py b/modules/html_generator.py index 162040ba..940d5486 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -1,6 +1,6 @@ ''' -This is a library for formatting GPT-4chan and chat outputs as nice HTML. +This is a library for formatting text outputs as nice HTML. ''' @@ -8,30 +8,39 @@ import os import re from pathlib import Path +import markdown from PIL import Image # This is to store the paths to the thumbnails of the profile pictures image_cache = {} -def generate_basic_html(s): - css = """ - .container { - max-width: 600px; - margin-left: auto; - margin-right: auto; - background-color: rgb(31, 41, 55); - padding:3em; - } - .container p { - font-size: 16px !important; - color: white !important; - margin-bottom: 22px; - line-height: 1.4 !important; - } - """ - s = '\n'.join([f'

{line}

' for line in s.split('\n')]) - s = f'
{s}
' - return s +with open(Path(__file__).resolve().parent / '../css/html_readable_style.css', 'r') as f: + readable_css = f.read() +with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r') as css_f: + _4chan_css = css_f.read() +with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f: + cai_css = f.read() + +def fix_newlines(string): + string = string.replace('\n', '\n\n') + string = re.sub(r"\n{3,}", "\n\n", string) + string = string.strip() + return string + +# This could probably be generalized and improved +def convert_to_markdown(string): + string = string.replace('\\begin{code}', '```') + string = string.replace('\\end{code}', '```') + string = string.replace('\\begin{blockquote}', '> ') + string = string.replace('\\end{blockquote}', '') + string = re.sub(r"(.)```", r"\1\n```", string) +# string = fix_newlines(string) + return markdown.markdown(string, extensions=['fenced_code']) + +def generate_basic_html(string): + string = convert_to_markdown(string) + string = f'
{string}
' + return string def process_post(post, c): t = post.split('\n') @@ -48,113 +57,6 @@ def process_post(post, c): return src def generate_4chan_html(f): - css = """ - - #parent #container { - background-color: #eef2ff; - padding: 17px; - } - #parent #container .reply { - background-color: rgb(214, 218, 240); - border-bottom-color: rgb(183, 197, 217); - border-bottom-style: solid; - border-bottom-width: 1px; - border-image-outset: 0; - border-image-repeat: stretch; - border-image-slice: 100%; - border-image-source: none; - border-image-width: 1; - border-left-color: rgb(0, 0, 0); - border-left-style: none; - border-left-width: 0px; - border-right-color: rgb(183, 197, 217); - border-right-style: solid; - border-right-width: 1px; - border-top-color: rgb(0, 0, 0); - border-top-style: none; - border-top-width: 0px; - color: rgb(0, 0, 0); - display: table; - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - margin-bottom: 4px; - margin-left: 0px; - margin-right: 0px; - margin-top: 4px; - overflow-x: hidden; - overflow-y: hidden; - padding-bottom: 4px; - padding-left: 2px; - padding-right: 2px; - padding-top: 4px; - } - - #parent #container .number { - color: rgb(0, 0, 0); - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - width: 342.65px; - margin-right: 7px; - } - - #parent #container .op { - color: rgb(0, 0, 0); - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - margin-bottom: 8px; - margin-left: 0px; - margin-right: 0px; - margin-top: 4px; - overflow-x: hidden; - overflow-y: hidden; - } - - #parent #container .op blockquote { - margin-left: 0px !important; - } - - #parent #container .name { - color: rgb(17, 119, 67); - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - font-weight: 700; - margin-left: 7px; - } - - #parent #container .quote { - color: rgb(221, 0, 0); - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - text-decoration-color: rgb(221, 0, 0); - text-decoration-line: underline; - text-decoration-style: solid; - text-decoration-thickness: auto; - } - - #parent #container .greentext { - color: rgb(120, 153, 34); - font-family: arial, helvetica, sans-serif; - font-size: 13.3333px; - } - - #parent #container blockquote { - margin: 0px !important; - margin-block-start: 1em; - margin-block-end: 1em; - margin-inline-start: 40px; - margin-inline-end: 40px; - margin-top: 13.33px !important; - margin-bottom: 13.33px !important; - margin-left: 40px !important; - margin-right: 40px !important; - } - - #parent #container .message { - color: black; - border: none; - } - """ - posts = [] post = '' c = -2 @@ -181,7 +83,7 @@ def generate_4chan_html(f): posts[i] = f'
{posts[i]}
\n' output = '' - output += f'
' + output += f'
' for post in posts: output += post output += '
' @@ -208,135 +110,39 @@ def get_image_cache(path): return image_cache[path][1] +def load_html_image(paths): + for str_path in paths: + path = Path(str_path) + if path.exists(): + return f'' + return '' + def generate_chat_html(history, name1, name2, character): - css = """ - .chat { - margin-left: auto; - margin-right: auto; - max-width: 800px; - height: 66.67vh; - overflow-y: auto; - padding-right: 20px; - display: flex; - flex-direction: column-reverse; - } - - .message { - display: grid; - grid-template-columns: 60px 1fr; - padding-bottom: 25px; - font-size: 15px; - font-family: Helvetica, Arial, sans-serif; - line-height: 1.428571429; - } - - .circle-you { - width: 50px; - height: 50px; - background-color: rgb(238, 78, 59); - border-radius: 50%; - } - - .circle-bot { - width: 50px; - height: 50px; - background-color: rgb(59, 78, 244); - border-radius: 50%; - } - - .circle-bot img, .circle-you img { - border-radius: 50%; - width: 100%; - height: 100%; - object-fit: cover; - } - - .text { - } - - .text p { - margin-top: 5px; - } - - .username { - font-weight: bold; - } - - .message-body { - } - - .message-body img { - max-width: 300px; - max-height: 300px; - border-radius: 20px; - } - - .message-body p { - margin-bottom: 0 !important; - font-size: 15px !important; - line-height: 1.428571429 !important; - } - - .dark .message-body p em { - color: rgb(138, 138, 138) !important; - } - - .message-body p em { - color: rgb(110, 110, 110) !important; - } - - """ - - output = '' - output += f'
' - img = '' - - for i in [ - f"characters/{character}.png", - f"characters/{character}.jpg", - f"characters/{character}.jpeg", - "img_bot.png", - "img_bot.jpg", - "img_bot.jpeg" - ]: - - path = Path(i) - if path.exists(): - img = f'' - break - - img_me = '' - for i in ["img_me.png", "img_me.jpg", "img_me.jpeg"]: - path = Path(i) - if path.exists(): - img_me = f'' - break + output = f'
' + + img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"]) + img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"]) for i,_row in enumerate(history[::-1]): - row = _row.copy() - row[0] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"\2", row[0]) - row[1] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"\2", row[1]) - row[0] = re.sub(r"(\*)([^\*\n]*)(\*)", r"\2", row[0]) - row[1] = re.sub(r"(\*)([^\*\n]*)(\*)", r"\2", row[1]) - p = '\n'.join([f"

{x}

" for x in row[1].split('\n')]) + row = [convert_to_markdown(entry) for entry in _row] + output += f"""
- {img} + {img_bot}
{name2}
- {p} + {row[1]}
""" if not (i == len(history)-1 and len(row[0]) == 0): - p = '\n'.join([f"

{x}

" for x in row[0].split('\n')]) output += f"""
@@ -347,7 +153,7 @@ def generate_chat_html(history, name1, name2, character): {name1}
- {p} + {row[0]}
diff --git a/modules/models.py b/modules/models.py index f4bb11fd..f07e738b 100644 --- a/modules/models.py +++ b/modules/models.py @@ -7,7 +7,9 @@ from pathlib import Path import numpy as np import torch import transformers -from transformers import AutoModelForCausalLM, AutoTokenizer +from accelerate import infer_auto_device_map, init_empty_weights +from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) import modules.shared as shared @@ -16,8 +18,7 @@ transformers.logging.set_verbosity_error() local_rank = None if shared.args.flexgen: - from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM, - Policy, str2bool) + from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy if shared.args.deepspeed: import deepspeed @@ -46,7 +47,12 @@ def load_model(model_name): if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) else: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16).cuda() + model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) + if torch.has_mps: + device = torch.device('mps') + model = model.to(device) + else: + model = model.cuda() # FlexGen elif shared.args.flexgen: @@ -95,39 +101,60 @@ def load_model(model_name): # Custom else: - command = "AutoModelForCausalLM.from_pretrained" - params = ["low_cpu_mem_usage=True"] - if not shared.args.cpu and not torch.cuda.is_available(): - print("Warning: no GPU has been detected.\nFalling back to CPU mode.\n") + params = {"low_cpu_mem_usage": True} + if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)): + print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n") shared.args.cpu = True if shared.args.cpu: - params.append("low_cpu_mem_usage=True") - params.append("torch_dtype=torch.float32") + params["torch_dtype"] = torch.float32 else: - params.append("device_map='auto'") - params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16") + params["device_map"] = 'auto' + if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)): + params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True) + elif shared.args.load_in_8bit: + params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True) + elif shared.args.bf16: + params["torch_dtype"] = torch.bfloat16 + else: + params["torch_dtype"] = torch.float16 if shared.args.gpu_memory: memory_map = shared.args.gpu_memory - max_memory = f"max_memory={{0: '{memory_map[0]}GiB'" - for i in range(1, len(memory_map)): - max_memory += (f", {i}: '{memory_map[i]}GiB'") - max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}") - params.append(max_memory) - elif not shared.args.load_in_8bit: - total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024)) - suggestion = round((total_mem-1000)/1000)*1000 - if total_mem-suggestion < 800: + max_memory = {} + for i in range(len(memory_map)): + max_memory[i] = f'{memory_map[i]}GiB' + max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB' + params['max_memory'] = max_memory + elif shared.args.auto_devices: + total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024)) + suggestion = round((total_mem-1000) / 1000) * 1000 + if total_mem - suggestion < 800: suggestion -= 1000 suggestion = int(round(suggestion/1000)) print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m") - params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}") - if shared.args.disk: - params.append(f"offload_folder='{shared.args.disk_cache_dir}'") + + max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'} + params['max_memory'] = max_memory - command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})" - model = eval(command) + if shared.args.disk: + params["offload_folder"] = shared.args.disk_cache_dir + + checkpoint = Path(f'models/{shared.model_name}') + + if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': + config = AutoConfig.from_pretrained(checkpoint) + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config) + model.tie_weights() + params['device_map'] = infer_auto_device_map( + model, + dtype=torch.int8, + max_memory=params['max_memory'], + no_split_module_classes = model._no_split_modules + ) + + model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) # Loading the tokenizer if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists(): diff --git a/modules/shared.py b/modules/shared.py index ea2eb50b..e3920f22 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -2,7 +2,8 @@ import argparse model = None tokenizer = None -model_name = "" +model_name = "None" +lora_name = "None" soft_prompt_tensor = None soft_prompt = False is_RWKV = False @@ -19,6 +20,9 @@ gradio = {} # Generation input parameters input_params = [] +# For restarting the interface +need_restart = False + settings = { 'max_new_tokens': 200, 'max_new_tokens_min': 1, @@ -26,7 +30,7 @@ settings = { 'name1': 'Person 1', 'name2': 'Person 2', 'context': 'This is a conversation between two people.', - 'stop_at_newline': True, + 'stop_at_newline': False, 'chat_prompt_size': 2048, 'chat_prompt_size_min': 0, 'chat_prompt_size_max': 2048, @@ -49,6 +53,10 @@ settings = { '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n', '(rosey|chip|joi)_.*_instruct.*': 'User: \n', 'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>' + }, + 'lora_prompts': { + 'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', + '(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n" } } @@ -64,6 +72,7 @@ def str2bool(v): parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54)) parser.add_argument('--model', type=str, help='Name of the model to load by default.') +parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.') parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.') diff --git a/modules/text_generation.py b/modules/text_generation.py index 70a51d91..1d11de12 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -33,12 +33,15 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True): return input_ids.numpy() elif shared.args.deepspeed: return input_ids.to(device=local_rank) + elif torch.has_mps: + device = torch.device('mps') + return input_ids.to(device) else: return input_ids.cuda() def decode(output_ids): # Open Assistant relies on special tokens like <|endoftext|> - if re.match('oasst-*', shared.model_name.lower()): + if re.match('(oasst|galactica)-*', shared.model_name.lower()): return shared.tokenizer.decode(output_ids, skip_special_tokens=False) else: reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True) @@ -89,7 +92,7 @@ def clear_torch_cache(): if not shared.args.cpu: torch.cuda.empty_cache() -def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): +def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None): clear_torch_cache() t0 = time.time() @@ -101,7 +104,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) yield formatted_outputs(reply, shared.model_name) else: - yield formatted_outputs(question, shared.model_name) + if not (shared.args.chat or shared.args.cai_chat): + yield formatted_outputs(question, shared.model_name) # RWKV has proper streaming, which is very nice. # No need to generate 8 tokens at a time. for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k): @@ -143,6 +147,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi "top_p": top_p, "typical_p": typical_p, "repetition_penalty": repetition_penalty, + "encoder_repetition_penalty": encoder_repetition_penalty, "top_k": top_k, "min_length": min_length if shared.args.no_stream else 0, "no_repeat_ngram_size": no_repeat_ngram_size, @@ -196,7 +201,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi def generate_with_streaming(**kwargs): return Iteratorize(generate_with_callback, kwargs, callback=None) - yield formatted_outputs(original_question, shared.model_name) + if not (shared.args.chat or shared.args.cai_chat): + yield formatted_outputs(original_question, shared.model_name) with generate_with_streaming(**generate_params) as generator: for output in generator: if shared.soft_prompt: diff --git a/modules/ui.py b/modules/ui.py index bb193e35..80bd7c1c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,68 +1,17 @@ +from pathlib import Path + import gradio as gr refresh_symbol = '\U0001f504' # 🔄 -css = """ -.tabs.svelte-710i53 { - margin-top: 0 -} -.py-6 { - padding-top: 2.5rem -} -.dark #refresh-button { - background-color: #ffffff1f; -} -#refresh-button { - flex: none; - margin: 0; - padding: 0; - min-width: 50px; - border: none; - box-shadow: none; - border-radius: 10px; - background-color: #0000000d; -} -#download-label, #upload-label { - min-height: 0 -} -#accordion { -} -.dark svg { - fill: white; -} -svg { - display: unset !important; - vertical-align: middle !important; - margin: 5px; -} -ol li p, ul li p { - display: inline-block; -} -""" - -chat_css = """ -.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx { - height: 66.67vh -} -.gradio-container { - max-width: 800px !important; - margin-left: auto !important; - margin-right: auto !important; -} -.w-screen { - width: unset -} -div.svelte-362y77>*, div.svelte-362y77>.form>* { - flex-wrap: nowrap -} -/* fixes the API documentation in chat mode */ -.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h { - display: grid; -} -.pending.svelte-1ed2p3z { - opacity: 1; -} -""" +with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f: + css = f.read() +with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f: + chat_css = f.read() +with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f: + main_js = f.read() +with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f: + chat_js = f.read() class ToolButton(gr.Button, gr.components.FormComponent): """Small button with single emoji as text, fits inside gradio forms""" diff --git a/presets/Default.txt b/presets/Default.txt index 9f0983ec..d5283836 100644 --- a/presets/Default.txt +++ b/presets/Default.txt @@ -1,12 +1,7 @@ do_sample=True -temperature=1 -top_p=1 -typical_p=1 -repetition_penalty=1 -top_k=50 -num_beams=1 -penalty_alpha=0 -min_length=0 -length_penalty=1 -no_repeat_ngram_size=0 +top_p=0.5 +top_k=40 +temperature=0.7 +repetition_penalty=1.2 +typical_p=1.0 early_stopping=False diff --git a/presets/Individual Today.txt b/presets/Individual Today.txt deleted file mode 100644 index f40b879c..00000000 --- a/presets/Individual Today.txt +++ /dev/null @@ -1,6 +0,0 @@ -do_sample=True -top_p=0.9 -top_k=50 -temperature=1.39 -repetition_penalty=1.08 -typical_p=0.2 diff --git a/requirements.txt b/requirements.txt index 9bb2b74f..e5b3de69 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,12 @@ accelerate==0.17.1 bitsandbytes==0.37.1 flexgen==0.1.7 gradio==3.18.0 +markdown numpy +peft==0.2.0 requests -rwkv==0.4.2 +rwkv==0.7.0 safetensors==0.3.0 sentencepiece tqdm -git+https://github.com/zphang/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176 +git+https://github.com/huggingface/transformers diff --git a/server.py b/server.py index a54e3b62..060f09d5 100644 --- a/server.py +++ b/server.py @@ -15,6 +15,7 @@ import modules.extensions as extensions_module import modules.shared as shared import modules.ui as ui from modules.html_generator import generate_chat_html +from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt from modules.text_generation import generate_reply @@ -34,7 +35,7 @@ def get_available_models(): if shared.args.flexgen: return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower) else: - return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower) + return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) def get_available_presets(): return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) @@ -48,6 +49,9 @@ def get_available_extensions(): def get_available_softprompts(): return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) +def get_available_loras(): + return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + def load_model_wrapper(selected_model): if selected_model != shared.model_name: shared.model_name = selected_model @@ -59,6 +63,17 @@ def load_model_wrapper(selected_model): return selected_model +def load_lora_wrapper(selected_lora): + shared.lora_name = selected_lora + default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')] + + if not shared.args.cpu: + gc.collect() + torch.cuda.empty_cache() + add_lora_to_model(selected_lora) + + return selected_lora, default_text + def load_preset_values(preset_menu, return_dict=False): generate_params = { 'do_sample': True, @@ -66,6 +81,7 @@ def load_preset_values(preset_menu, return_dict=False): 'top_p': 1, 'typical_p': 1, 'repetition_penalty': 1, + 'encoder_repetition_penalty': 1, 'top_k': 50, 'num_beams': 1, 'penalty_alpha': 0, @@ -86,7 +102,7 @@ def load_preset_values(preset_menu, return_dict=False): if return_dict: return generate_params else: - return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] + return preset_menu, generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] def upload_soft_prompt(file): with zipfile.ZipFile(io.BytesIO(file)) as zf: @@ -100,9 +116,7 @@ def upload_soft_prompt(file): return name -def create_settings_menus(default_preset): - generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) - +def create_model_and_preset_menus(): with gr.Row(): with gr.Column(): with gr.Row(): @@ -113,31 +127,48 @@ def create_settings_menus(default_preset): shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') - with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'): - with gr.Row(): - with gr.Column(): - shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') - shared.gradio['repetition_penalty'] = gr.Slider(1.0, 2.99, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty') - shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k') - shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p') - with gr.Column(): +def create_settings_menus(default_preset): + generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) + + with gr.Row(): + shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') + ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') + + with gr.Row(): + with gr.Column(): + with gr.Box(): + gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))') + with gr.Row(): + with gr.Column(): + shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature') + shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p') + shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k') + shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p') + with gr.Column(): + shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty') + shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty') + shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') + shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream) shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample') - shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p') - shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size') - shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream) + with gr.Column(): + with gr.Box(): + gr.Markdown('Contrastive search') + shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') - gr.Markdown('Contrastive search:') - shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') + with gr.Box(): + gr.Markdown('Beam search (uses a lot of VRAM)') + with gr.Row(): + with gr.Column(): + shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams') + with gr.Column(): + shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') + shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') - gr.Markdown('Beam search (uses a lot of VRAM):') - with gr.Row(): - with gr.Column(): - shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams') - with gr.Column(): - shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') - shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') + with gr.Row(): + shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') + ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button') - with gr.Accordion('Soft prompt', open=False, elem_id='accordion'): + with gr.Accordion('Soft prompt', open=False): with gr.Row(): shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt') ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button') @@ -147,14 +178,35 @@ def create_settings_menus(default_preset): shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) - shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']]) + shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) + shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) + shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True) shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']]) +def set_interface_arguments(interface_mode, extensions, cmd_active): + modes = ["default", "notebook", "chat", "cai_chat"] + cmd_list = vars(shared.args) + cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes] + + shared.args.extensions = extensions + for k in modes[1:]: + exec(f"shared.args.{k} = False") + if interface_mode != "default": + exec(f"shared.args.{interface_mode} = True") + + for k in cmd_list: + exec(f"shared.args.{k} = False") + for k in cmd_active: + exec(f"shared.args.{k} = True") + + shared.need_restart = True + available_models = get_available_models() available_presets = get_available_presets() available_characters = get_available_characters() available_softprompts = get_available_softprompts() +available_loras = get_available_loras() # Default extensions extensions_module.available_extensions = get_available_extensions() @@ -168,8 +220,6 @@ else: shared.args.extensions = shared.args.extensions or [] if extension not in shared.args.extensions: shared.args.extensions.append(extension) -if shared.args.extensions is not None and len(shared.args.extensions) > 0: - extensions_module.load_extensions() # Default model if shared.args.model is not None: @@ -189,191 +239,235 @@ else: print() shared.model_name = available_models[i] shared.model, shared.tokenizer = load_model(shared.model_name) +if shared.args.lora: + print(shared.args.lora) + shared.lora_name = shared.args.lora + add_lora_to_model(shared.lora_name) # Default UI settings -gen_events = [] default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')] -default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')] +default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')] +if default_text == '': + default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')] title ='Text generation web UI' description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n' suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else '' -if shared.args.chat or shared.args.cai_chat: - with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: - if shared.args.cai_chat: - shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) - else: - shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528")) - shared.gradio['textbox'] = gr.Textbox(label='Input') - with gr.Row(): - shared.gradio['Stop'] = gr.Button('Stop') - shared.gradio['Generate'] = gr.Button('Generate') - with gr.Row(): - shared.gradio['Impersonate'] = gr.Button('Impersonate') - shared.gradio['Regenerate'] = gr.Button('Regenerate') - with gr.Row(): - shared.gradio['Copy last reply'] = gr.Button('Copy last reply') - shared.gradio['Replace last reply'] = gr.Button('Replace last reply') - shared.gradio['Remove last'] = gr.Button('Remove last') +def create_interface(): - shared.gradio['Clear history'] = gr.Button('Clear history') - shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) - shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) - with gr.Tab('Chat settings'): - shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') - shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') - shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context') - with gr.Row(): - shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') - ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') + gen_events = [] + if shared.args.extensions is not None and len(shared.args.extensions) > 0: + extensions_module.load_extensions() - with gr.Row(): - shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') - with gr.Row(): - with gr.Tab('Chat history'): - with gr.Row(): - with gr.Column(): - gr.Markdown('Upload') - shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt']) - with gr.Column(): - gr.Markdown('Download') - shared.gradio['download'] = gr.File() - shared.gradio['download_button'] = gr.Button(value='Click me') - with gr.Tab('Upload character'): - with gr.Row(): - with gr.Column(): - gr.Markdown('1. Select the JSON file') - shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json']) - with gr.Column(): - gr.Markdown('2. Select your character\'s profile picture (optional)') - shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image']) - shared.gradio['Upload character'] = gr.Button(value='Submit') - with gr.Tab('Upload your profile picture'): - shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image']) - with gr.Tab('Upload TavernAI Character Card'): - shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image']) - - with gr.Tab('Generation settings'): - with gr.Row(): - with gr.Column(): - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - with gr.Column(): - shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) - shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') - create_settings_menus(default_preset) - - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] - if shared.args.extensions is not None: - with gr.Tab('Extensions'): - extensions_module.create_extensions_block() - - function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' - - gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) - gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) - gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) - gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events) - - shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) - shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream) - - # Clear history with confirmation - clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] - shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) - shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) - shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display']) - shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) - - shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) - shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']]) - shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) - - # Clearing stuff and saving the history - for i in ['Generate', 'Regenerate', 'Replace last reply']: - shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) - shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) - shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - - shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']]) - shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) - shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) - shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], []) - - reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] - reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] - shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) - shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) - shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) - - shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) - shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) - -elif shared.args.notebook: - with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']: - gr.Markdown(description) - with gr.Tab('Raw'): - shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23) - with gr.Tab('Markdown'): - shared.gradio['markdown'] = gr.Markdown() - with gr.Tab('HTML'): - shared.gradio['html'] = gr.HTML() - - shared.gradio['Generate'] = gr.Button('Generate') - shared.gradio['Stop'] = gr.Button('Stop') - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - - create_settings_menus(default_preset) - if shared.args.extensions is not None: - extensions_module.create_extensions_block() - - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] - output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] - gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) - gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(None, None, None, cancels=gen_events) - -else: - with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']: - gr.Markdown(description) - with gr.Row(): - with gr.Column(): - shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input') - shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) - shared.gradio['Generate'] = gr.Button('Generate') + with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: + if shared.args.chat or shared.args.cai_chat: + with gr.Tab("Text generation", elem_id="main"): + if shared.args.cai_chat: + shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) + else: + shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528")) + shared.gradio['textbox'] = gr.Textbox(label='Input') with gr.Row(): - with gr.Column(): - shared.gradio['Continue'] = gr.Button('Continue') - with gr.Column(): - shared.gradio['Stop'] = gr.Button('Stop') + shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop") + shared.gradio['Generate'] = gr.Button('Generate') + with gr.Row(): + shared.gradio['Impersonate'] = gr.Button('Impersonate') + shared.gradio['Regenerate'] = gr.Button('Regenerate') + with gr.Row(): + shared.gradio['Copy last reply'] = gr.Button('Copy last reply') + shared.gradio['Replace last reply'] = gr.Button('Replace last reply') + shared.gradio['Remove last'] = gr.Button('Remove last') + + shared.gradio['Clear history'] = gr.Button('Clear history') + shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) + shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) + + create_model_and_preset_menus() + + with gr.Tab("Character", elem_id="chat-settings"): + shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name') + shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name') + shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context') + with gr.Row(): + shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') + ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') + + with gr.Row(): + with gr.Tab('Chat history'): + with gr.Row(): + with gr.Column(): + gr.Markdown('Upload') + shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt']) + with gr.Column(): + gr.Markdown('Download') + shared.gradio['download'] = gr.File() + shared.gradio['download_button'] = gr.Button(value='Click me') + with gr.Tab('Upload character'): + with gr.Row(): + with gr.Column(): + gr.Markdown('1. Select the JSON file') + shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json']) + with gr.Column(): + gr.Markdown('2. Select your character\'s profile picture (optional)') + shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image']) + shared.gradio['Upload character'] = gr.Button(value='Submit') + with gr.Tab('Upload your profile picture'): + shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image']) + with gr.Tab('Upload TavernAI Character Card'): + shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image']) + + with gr.Tab("Parameters", elem_id="parameters"): + with gr.Box(): + gr.Markdown("Chat parameters") + with gr.Row(): + with gr.Column(): + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) + with gr.Column(): + shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') + shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?') create_settings_menus(default_preset) - if shared.args.extensions is not None: - extensions_module.create_extensions_block() - with gr.Column(): + function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] + + gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) + shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events) + + shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) + shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream) + + # Clear history with confirmation + clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] + shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) + shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) + shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display']) + shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) + + shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) + shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']]) + shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) + + # Clearing stuff and saving the history + for i in ['Generate', 'Regenerate', 'Replace last reply']: + shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) + shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) + shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) + shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) + shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) + + shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']]) + shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) + shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) + shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], []) + + reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] + reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] + shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) + shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) + shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) + + shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") + shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) + shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) + + elif shared.args.notebook: + with gr.Tab("Text generation", elem_id="main"): with gr.Tab('Raw'): - shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output') + shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25) with gr.Tab('Markdown'): shared.gradio['markdown'] = gr.Markdown() with gr.Tab('HTML'): shared.gradio['html'] = gr.HTML() - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] - output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] - gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) - gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) - gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + with gr.Row(): + shared.gradio['Stop'] = gr.Button('Stop') + shared.gradio['Generate'] = gr.Button('Generate') + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) -shared.gradio['interface'].queue() -if shared.args.listen: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) -else: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) + create_model_and_preset_menus() + with gr.Tab("Parameters", elem_id="parameters"): + create_settings_menus(default_preset) + + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] + gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) + shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") + + else: + with gr.Tab("Text generation", elem_id="main"): + with gr.Row(): + with gr.Column(): + shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input') + shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) + shared.gradio['Generate'] = gr.Button('Generate') + with gr.Row(): + with gr.Column(): + shared.gradio['Continue'] = gr.Button('Continue') + with gr.Column(): + shared.gradio['Stop'] = gr.Button('Stop') + + create_model_and_preset_menus() + + with gr.Column(): + with gr.Tab('Raw'): + shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output') + with gr.Tab('Markdown'): + shared.gradio['markdown'] = gr.Markdown() + with gr.Tab('HTML'): + shared.gradio['html'] = gr.HTML() + with gr.Tab("Parameters", elem_id="parameters"): + create_settings_menus(default_preset) + + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] + output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] + gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) + shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") + + with gr.Tab("Interface mode", elem_id="interface-mode"): + modes = ["default", "notebook", "chat", "cai_chat"] + current_mode = "default" + for mode in modes[1:]: + if eval(f"shared.args.{mode}"): + current_mode = mode + break + cmd_list = vars(shared.args) + cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes] + active_cmd_list = [k for k in cmd_list if vars(shared.args)[k]] + + gr.Markdown("*Experimental*") + shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode") + shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions") + shared.gradio['cmd_arguments_menu'] = gr.CheckboxGroup(choices=cmd_list, value=active_cmd_list, label="Boolean command-line flags") + shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary") + + shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None) + shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'

Reloading...

\'; setTimeout(function(){location.reload()},2500)}') + + if shared.args.extensions is not None: + extensions_module.create_extensions_block() + + # Launch the interface + shared.gradio['interface'].queue() + if shared.args.listen: + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) + else: + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) + +create_interface() -# I think that I will need this later while True: time.sleep(0.5) + if shared.need_restart: + shared.need_restart = False + shared.gradio['interface'].close() + create_interface() diff --git a/settings-template.json b/settings-template.json index 9da43970..7a7de7af 100644 --- a/settings-template.json +++ b/settings-template.json @@ -5,7 +5,7 @@ "name1": "Person 1", "name2": "Person 2", "context": "This is a conversation between two people.", - "stop_at_newline": true, + "stop_at_newline": false, "chat_prompt_size": 2048, "chat_prompt_size_min": 0, "chat_prompt_size_max": 2048, @@ -23,13 +23,16 @@ "presets": { "default": "NovelAI-Sphinx Moth", "pygmalion-*": "Pygmalion", - "RWKV-*": "Naive", - "(rosey|chip|joi)_.*_instruct.*": "Instruct Joi (Contrastive Search)" + "RWKV-*": "Naive" }, "prompts": { "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n", "(rosey|chip|joi)_.*_instruct.*": "User: \n", "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>" + }, + "lora_prompts": { + "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:", + "alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n" } }