SD Api Pics extension, v.1.1 (#596)

This commit is contained in:
Φφ 2023-04-08 03:36:04 +03:00 committed by GitHub
parent 5543a5089d
commit ffd102e5c0
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 282 additions and 102 deletions

View file

@ -0,0 +1,78 @@
## Description:
TL;DR: Lets the bot answer you with a picture!
Stable Diffusion API pictures for TextGen, v.1.1.0
An extension to [oobabooga's textgen-webui](https://github.com/oobabooga/text-generation-webui) allowing you to receive pics generated by [Automatic1111's SD-WebUI API](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
<details>
<summary>Interface overview</summary>
![Interface](https://raw.githubusercontent.com/Brawlence/texgen-webui-SD_api_pics/main/illust/Interface.jpg)
</details>
Load it in the `--chat` mode with `--extension sd_api_pictures` alongside `send_pictures` (it's not really required, but completes the picture, *pun intended*).
The image generation is triggered either:
- manually through the 'Force the picture response' button while in `Manual` or `Immersive/Interactive` modes OR
- automatically in `Immersive/Interactive` mode if the words `'send|main|message|me'` are followed by `'image|pic|picture|photo|snap|snapshot|selfie|meme'` in the user's prompt
- always on in Picturebook/Adventure mode (if not currently suppressed by 'Suppress the picture response')
## Prerequisites
One needs an available instance of Automatic1111's webui running with an `--api` flag. Ain't tested with a notebook / cloud hosted one but should be possible.
To run it locally in parallel on the same machine, specify custom `--listen-port` for either Auto1111's or ooba's webUIs.
## Features:
- API detection (press enter in the API box)
- VRAM management (model shuffling)
- Three different operation modes (manual, interactive, always-on)
- persistent settings via settings.json
The model input is modified only in the interactive mode; other two are unaffected. The output pic description is presented differently for Picture-book / Adventure mode.
Connection check (insert the Auto1111's address and press Enter):
![API-check](https://raw.githubusercontent.com/Brawlence/texgen-webui-SD_api_pics/main/illust/API-check.gif)
### Persistents settings
Create or modify the `settings.json` in the `text-generation-webui` root directory to override the defaults
present in script.py, ex:
```json
{
"sd_api_pictures-manage_VRAM": 1,
"sd_api_pictures-save_img": 1,
"sd_api_pictures-prompt_prefix": "(Masterpiece:1.1), detailed, intricate, colorful, (solo:1.1)",
"sd_api_pictures-sampler_name": "DPM++ 2M Karras"
}
```
will automatically set the `Manage VRAM` & `Keep original images` checkboxes and change the texts in `Prompt Prefix` and `Sampler name` on load.
---
## Demonstrations:
Those are examples of the version 1.0.0, but the core functionality is still the same
<details>
<summary>Conversation 1</summary>
![EXA1](https://user-images.githubusercontent.com/42910943/224866564-939a3bcb-e7cf-4ac0-a33f-b3047b55054d.jpg)
![EXA2](https://user-images.githubusercontent.com/42910943/224866566-38394054-1320-45cf-9515-afa76d9d7745.jpg)
![EXA3](https://user-images.githubusercontent.com/42910943/224866568-10ea47b7-0bac-4269-9ec9-22c387a13b59.jpg)
![EXA4](https://user-images.githubusercontent.com/42910943/224866569-326121ad-1ea1-4874-9f6b-4bca7930a263.jpg)
</details>
<details>
<summary>Conversation 2</summary>
![Hist1](https://user-images.githubusercontent.com/42910943/224865517-c6966b58-bc4d-4353-aab9-6eb97778d7bf.jpg)
![Hist2](https://user-images.githubusercontent.com/42910943/224865527-b2fe7c2e-0da5-4c2e-b705-42e233b07084.jpg)
![Hist3](https://user-images.githubusercontent.com/42910943/224865535-a38d94e7-8975-4a46-a655-1ae1de41f85d.jpg)
</details>

View file

@ -1,34 +1,78 @@
import base64 import base64
import io import io
import re import re
import time
from datetime import date
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import modules.shared as shared
import requests import requests
import torch import torch
from modules.models import reload_model, unload_model
from PIL import Image from PIL import Image
from modules import chat, shared
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui # parameters which can be customized in settings.json of webui
params = { params = {
'enable_SD_api': False,
'address': 'http://127.0.0.1:7860', 'address': 'http://127.0.0.1:7860',
'mode': 0, # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on)
'manage_VRAM': False,
'save_img': False, 'save_img': False,
'SD_model': 'NeverEndingDream', # not really used right now 'SD_model': 'NeverEndingDream', # not used right now
'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful', 'prompt_prefix': '(Masterpiece:1.1), detailed, intricate, colorful',
'negative_prompt': '(worst quality, low quality:1.3)', 'negative_prompt': '(worst quality, low quality:1.3)',
'side_length': 512, 'width': 512,
'restore_faces': False 'height': 512,
'restore_faces': False,
'seed': -1,
'sampler_name': 'DDIM',
'steps': 32,
'cfg_scale': 7
} }
def give_VRAM_priority(actor):
global shared, params
if actor == 'SD':
unload_model()
print("Requesting Auto1111 to re-load last checkpoint used...")
response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
response.raise_for_status()
elif actor == 'LLM':
print("Requesting Auto1111 to vacate VRAM...")
response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
response.raise_for_status()
reload_model()
elif actor == 'set':
print("VRAM mangement activated -- requesting Auto1111 to vacate VRAM...")
response = requests.post(url=f'{params["address"]}/sdapi/v1/unload-checkpoint', json='')
response.raise_for_status()
elif actor == 'reset':
print("VRAM mangement deactivated -- requesting Auto1111 to reload checkpoint")
response = requests.post(url=f'{params["address"]}/sdapi/v1/reload-checkpoint', json='')
response.raise_for_status()
else:
raise RuntimeError(f'Managing VRAM: "{actor}" is not a known state!')
response.raise_for_status()
del response
if params['manage_VRAM']:
give_VRAM_priority('set')
samplers = ['DDIM', 'DPM++ 2M Karras'] # TODO: get the availible samplers with http://{address}}/sdapi/v1/samplers
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
streaming_state = shared.args.no_stream # remember if chat streaming was enabled streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture picture_response = False # specifies if the next model response should appear as a picture
pic_id = 0
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
@ -36,7 +80,13 @@ def remove_surrounded_chars(string):
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)', '', string) return re.sub('\*[^\*]*?(\*|$)', '', string)
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
def triggers_are_in(string):
string = remove_surrounded_chars(string)
# regex searches for send|main|message|me (at the end of the word) followed by
# a whole word of image|pic|picture|photo|snap|snapshot|selfie|meme(s),
# (?aims) are regex parser flags
return bool(re.search('(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b', string))
def input_modifier(string): def input_modifier(string):
@ -44,75 +94,80 @@ def input_modifier(string):
This function is applied to your text inputs before This function is applied to your text inputs before
they are fed into the model. they are fed into the model.
""" """
global params, picture_response
if not params['enable_SD_api']: global params
if not params['mode'] == 1: # if not in immersive/interactive mode, do nothing
return string return string
commands = ['send', 'mail', 'me'] if triggers_are_in(string): # if we're in it, check for trigger words
mediums = ['image', 'pic', 'picture', 'photo'] toggle_generation(True)
subjects = ['yourself', 'own'] string = string.lower()
lowstr = string.lower() if "of" in string:
subject = string.split('of', 1)[1] # subdivide the string once by the first 'of' instance and get what's coming after it
# TODO: refactor out to separate handler and also replace detection with a regexp string = "Please provide a detailed and vivid description of " + subject
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found else:
picture_response = True string = "Please provide a detailed description of your appearance, your surroundings and what you are doing right now"
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
shared.processing_message = "*Is sending a picture...*"
string = "Please provide a detailed description of your surroundings, how you look and the situation you're in and what you are doing right now"
if any(target in lowstr for target in subjects): # the focus of the image should be on the sending character
string = "Please provide a detailed and vivid description of how you look and what you are wearing"
return string return string
# Get and save the Stable Diffusion-generated picture # Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description): def get_SD_pictures(description):
global params, pic_id global params
if params['manage_VRAM']:
give_VRAM_priority('SD')
payload = { payload = {
"prompt": params['prompt_prefix'] + description, "prompt": params['prompt_prefix'] + description,
"seed": -1, "seed": params['seed'],
"sampler_name": "DPM++ 2M Karras", "sampler_name": params['sampler_name'],
"steps": 32, "steps": params['steps'],
"cfg_scale": 7, "cfg_scale": params['cfg_scale'],
"width": params['side_length'], "width": params['width'],
"height": params['side_length'], "height": params['height'],
"restore_faces": params['restore_faces'], "restore_faces": params['restore_faces'],
"negative_prompt": params['negative_prompt'] "negative_prompt": params['negative_prompt']
} }
print(f'Prompting the image generator via the API on {params["address"]}...')
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload) response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
response.raise_for_status()
r = response.json() r = response.json()
visible_result = "" visible_result = ""
for img_str in r['images']: for img_str in r['images']:
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0]))) image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
if params['save_img']: if params['save_img']:
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png') variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}'
output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
output_file.parent.mkdir(parents=True, exist_ok=True)
image.save(output_file.as_posix()) image.save(output_file.as_posix())
pic_id += 1 visible_result = visible_result + f'<img src="/file/extensions/sd_api_pictures/outputs/{variadic}.png" alt="{description}" style="max-width: unset; max-height: unset;">\n'
# lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history else:
image.thumbnail((300, 300)) # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
buffered = io.BytesIO() image.thumbnail((300, 300))
image.save(buffered, format="JPEG") buffered = io.BytesIO()
buffered.seek(0) image.save(buffered, format="JPEG")
image_bytes = buffered.getvalue() buffered.seek(0)
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode() image_bytes = buffered.getvalue()
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n' img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
if params['manage_VRAM']:
give_VRAM_priority('LLM')
return visible_result return visible_result
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging? # and replace it with 'text' for the purposes of logging?
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
""" """
global pic_id, picture_response, streaming_state
global picture_response, params
if not picture_response: if not picture_response:
return string return string
@ -125,17 +180,18 @@ def output_modifier(string):
if string == '': if string == '':
string = 'no viable description in reply, try regenerating' string = 'no viable description in reply, try regenerating'
return string
# I can't for the love of all that's holy get the name from shared.gradio['name1'], so for now it will be like this text = ""
text = f'*Description: "{string}"*' if (params['mode'] < 2):
toggle_generation(False)
text = f'*Sends a picture which portrays: “{string}”*'
else:
text = string
image = get_SD_pictures(string) string = get_SD_pictures(string) + "\n" + text
picture_response = False return string
shared.processing_message = "*Is typing...*"
shared.args.no_stream = streaming_state
return image + "\n" + text
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
@ -148,42 +204,91 @@ def bot_prefix_modifier(string):
return string return string
def force_pic(): def toggle_generation(*args):
global picture_response global picture_response, shared, streaming_state
picture_response = True
if not args:
picture_response = not picture_response
else:
picture_response = args[0]
shared.args.no_stream = True if picture_response else streaming_state # Disable streaming cause otherwise the SD-generated picture would return as a dud
shared.processing_message = "*Is sending a picture...*" if picture_response else "*Is typing...*"
def filter_address(address):
address = address.strip()
# address = re.sub('http(s)?:\/\/|\/$','',address) # remove starting http:// OR https:// OR trailing slash
address = re.sub('\/$', '', address) # remove trailing /s
if not address.startswith('http'):
address = 'http://' + address
return address
def SD_api_address_update(address):
global params
msg = "✔️ SD API is found on:"
address = filter_address(address)
params.update({"address": address})
try:
response = requests.get(url=f'{params["address"]}/sdapi/v1/sd-models')
response.raise_for_status()
# r = response.json()
except:
msg = "❌ No SD API endpoint on:"
return gr.Textbox.update(label=msg)
def ui(): def ui():
# Gradio elements # Gradio elements
with gr.Accordion("Stable Diffusion api integration", open=True): # gr.Markdown('### Stable Diffusion API Pictures') # Currently the name of extension is shown as the title
with gr.Accordion("Parameters", open=True):
with gr.Row(): with gr.Row():
with gr.Column(): address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Auto1111\'s WebUI address')
enable = gr.Checkbox(value=params['enable_SD_api'], label='Activate SD Api integration') mode = gr.Dropdown(["Manual", "Immersive/Interactive", "Picturebook/Adventure"], value="Manual", label="Mode of operation", type="index")
save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir') with gr.Column(scale=1, min_width=300):
with gr.Column(): manage_VRAM = gr.Checkbox(value=params['manage_VRAM'], label='Manage VRAM')
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address') save_img = gr.Checkbox(value=params['save_img'], label='Keep original images and use them in chat')
with gr.Row(): force_pic = gr.Button("Force the picture response")
force_btn = gr.Button("Force the next response to be a picture") suppr_pic = gr.Button("Suppress the picture response")
generate_now_btn = gr.Button("Generate an image response to the input")
with gr.Accordion("Generation parameters", open=False): with gr.Accordion("Generation parameters", open=False):
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)') prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
with gr.Row(): with gr.Row():
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt') with gr.Column():
dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions') negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
# model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model') sampler_name = gr.Textbox(placeholder=params['sampler_name'], value=params['sampler_name'], label='Sampler')
with gr.Column():
width = gr.Slider(256, 768, value=params['width'], step=64, label='Width')
height = gr.Slider(256, 768, value=params['height'], step=64, label='Height')
with gr.Row():
steps = gr.Number(label="Steps:", value=params['steps'])
seed = gr.Number(label="Seed:", value=params['seed'])
cfg_scale = gr.Number(label="CFG Scale:", value=params['cfg_scale'])
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None) address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
mode.select(lambda x: params.update({"mode": x}), mode, None)
mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
manage_VRAM.change(lambda x: params.update({"manage_VRAM": x}), manage_VRAM, None)
manage_VRAM.change(lambda x: give_VRAM_priority('set' if x else 'reset'), inputs=manage_VRAM, outputs=None)
save_img.change(lambda x: params.update({"save_img": x}), save_img, None) save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
address.change(lambda x: params.update({"address": x}), address, None)
address.submit(fn=SD_api_address_update, inputs=address, outputs=address)
prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None) prompt_prefix.change(lambda x: params.update({"prompt_prefix": x}), prompt_prefix, None)
negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None) negative_prompt.change(lambda x: params.update({"negative_prompt": x}), negative_prompt, None)
dimensions.change(lambda x: params.update({"side_length": x}), dimensions, None) width.change(lambda x: params.update({"width": x}), width, None)
# model.change(lambda x: params.update({"SD_model": x}), model, None) height.change(lambda x: params.update({"height": x}), height, None)
force_btn.click(force_pic) sampler_name.change(lambda x: params.update({"sampler_name": x}), sampler_name, None)
generate_now_btn.click(force_pic) steps.change(lambda x: params.update({"steps": x}), steps, None)
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) seed.change(lambda x: params.update({"seed": x}), seed, None)
cfg_scale.change(lambda x: params.update({"cfg_scale": x}), cfg_scale, None)
force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)

View file

@ -4,14 +4,7 @@ import torch
from peft import PeftModel from peft import PeftModel
import modules.shared as shared import modules.shared as shared
from modules.models import load_model from modules.models import reload_model
from modules.text_generation import clear_torch_cache
def reload_model():
shared.model = shared.tokenizer = None
clear_torch_cache()
shared.model, shared.tokenizer = load_model(shared.model_name)
def add_lora_to_model(lora_name): def add_lora_to_model(lora_name):

View file

@ -1,3 +1,4 @@
import gc
import json import json
import os import os
import re import re
@ -16,11 +17,10 @@ import modules.shared as shared
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
local_rank = None
if shared.args.flexgen: if shared.args.flexgen:
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
local_rank = None
if shared.args.deepspeed: if shared.args.deepspeed:
import deepspeed import deepspeed
from transformers.deepspeed import (HfDeepSpeedConfig, from transformers.deepspeed import (HfDeepSpeedConfig,
@ -182,6 +182,23 @@ def load_model(model_name):
return model, tokenizer return model, tokenizer
def clear_torch_cache():
gc.collect()
if not shared.args.cpu:
torch.cuda.empty_cache()
def unload_model():
shared.model = shared.tokenizer = None
clear_torch_cache()
def reload_model():
shared.model = shared.tokenizer = None
clear_torch_cache()
shared.model, shared.tokenizer = load_model(shared.model_name)
def load_soft_prompt(name): def load_soft_prompt(name):
if name == 'None': if name == 'None':
shared.soft_prompt = False shared.soft_prompt = False

View file

@ -1,4 +1,3 @@
import gc
import re import re
import time import time
import traceback import traceback
@ -12,7 +11,7 @@ from modules.callbacks import (Iteratorize, Stream,
_SentinelTokenStoppingCriteria) _SentinelTokenStoppingCriteria)
from modules.extensions import apply_extensions from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank from modules.models import clear_torch_cache, local_rank
def get_max_prompt_length(tokens): def get_max_prompt_length(tokens):
@ -101,12 +100,6 @@ def formatted_outputs(reply, model_name):
return reply return reply
def clear_torch_cache():
gc.collect()
if not shared.args.cpu:
torch.cuda.empty_cache()
def set_manual_seed(seed): def set_manual_seed(seed):
if seed != -1: if seed != -1:
torch.manual_seed(seed) torch.manual_seed(seed)

View file

@ -18,9 +18,8 @@ import modules.extensions as extensions_module
from modules import api, chat, shared, training, ui from modules import api, chat, shared, training, ui
from modules.html_generator import chat_html_wrapper from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt from modules.models import load_model, load_soft_prompt, unload_model
from modules.text_generation import (clear_torch_cache, generate_reply, from modules.text_generation import generate_reply, stop_everything_event
stop_everything_event)
# Loading custom settings # Loading custom settings
settings_file = None settings_file = None
@ -79,11 +78,6 @@ def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def unload_model():
shared.model = shared.tokenizer = None
clear_torch_cache()
def load_model_wrapper(selected_model): def load_model_wrapper(selected_model):
if selected_model != shared.model_name: if selected_model != shared.model_name:
shared.model_name = selected_model shared.model_name = selected_model