Extensions performance & memory optimisations

Reworked remove_surrounded_chars() to use regular expression ( https://regexr.com/7alb5 ) instead of repeated string concatenations for elevenlab_tts, silero_tts, sd_api_pictures. This should be both faster and more robust in handling asterisks.

Reduced the memory footprint of send_pictures and sd_api_pictures by scaling the images in the chat to 300 pixels max-side wise. (The user already has the original in case of the sent picture and there's an option to save the SD generation).
This should fix history growing annoyingly large with multiple pictures present
This commit is contained in:
Φφ 2023-03-22 07:47:54 +03:00
parent 45b7e53565
commit 5389fce8e1
4 changed files with 39 additions and 27 deletions

View file

@ -4,6 +4,8 @@ import gradio as gr
from elevenlabslib import ElevenLabsUser from elevenlabslib import ElevenLabsUser
from elevenlabslib.helpers import save_bytes_to_path from elevenlabslib.helpers import save_bytes_to_path
import re
import modules.shared as shared import modules.shared as shared
params = { params = {
@ -52,14 +54,10 @@ def refresh_voices():
return return
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
new_string = "" # regexp is way faster than repeated string concatenation!
in_star = False # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
for char in string: # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
if char == '*': return re.sub('\*[^\*]*?(\*|$)','',string)
in_star = not in_star
elif not in_star:
new_string += char
return new_string
def input_modifier(string): def input_modifier(string):
""" """

View file

@ -1,5 +1,6 @@
import base64 import base64
import io import io
import re
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
@ -31,14 +32,10 @@ picture_response = False # specifies if the next model response should appear as
pic_id = 0 pic_id = 0
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
new_string = "" # regexp is way faster than repeated string concatenation!
in_star = False # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
for char in string: # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
if char == '*': return re.sub('\*[^\*]*?(\*|$)','',string)
in_star = not in_star
elif not in_star:
new_string += char
return new_string
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string # I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
def input_modifier(string): def input_modifier(string):
@ -54,6 +51,8 @@ def input_modifier(string):
mediums = ['image', 'pic', 'picture', 'photo'] mediums = ['image', 'pic', 'picture', 'photo']
subjects = ['yourself', 'own'] subjects = ['yourself', 'own']
lowstr = string.lower() lowstr = string.lower()
# TODO: refactor out to separate handler and also replace detection with a regexp
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
picture_response = True picture_response = True
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
@ -91,8 +90,15 @@ def get_SD_pictures(description):
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png') output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
image.save(output_file.as_posix()) image.save(output_file.as_posix())
pic_id += 1 pic_id += 1
# lower the resolution of received images for the chat, otherwise the history size gets out of control quickly with all the base64 values # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
newsize = (300, 300) width, height = image.size
if (width > 300):
height = int(height * (300 / width))
width = 300
elif (height > 300):
width = int(width * (300 / height))
height = 300
newsize = (width, height)
image = image.resize(newsize, Image.LANCZOS) image = image.resize(newsize, Image.LANCZOS)
buffered = io.BytesIO() buffered = io.BytesIO()
image.save(buffered, format="JPEG") image.save(buffered, format="JPEG")

View file

@ -4,6 +4,7 @@ from io import BytesIO
import gradio as gr import gradio as gr
import torch import torch
from transformers import BlipForConditionalGeneration, BlipProcessor from transformers import BlipForConditionalGeneration, BlipProcessor
from PIL import Image
import modules.chat as chat import modules.chat as chat
import modules.shared as shared import modules.shared as shared
@ -25,10 +26,20 @@ def caption_image(raw_image):
def generate_chat_picture(picture, name1, name2): def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*' text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
width, height = picture.size
if (width > 300):
height = int(height * (300 / width))
width = 300
elif (height > 300):
width = int(width * (300 / height))
height = 300
newsize = (width, height)
picture = picture.resize(newsize, Image.LANCZOS)
buffer = BytesIO() buffer = BytesIO()
picture.save(buffer, format="JPEG") picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
visible_text = f'<img src="data:image/jpeg;base64,{img_str}">' visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
return text, visible_text return text, visible_text
def ui(): def ui():

View file

@ -3,6 +3,7 @@ from pathlib import Path
import gradio as gr import gradio as gr
import torch import torch
import re
import modules.chat as chat import modules.chat as chat
import modules.shared as shared import modules.shared as shared
@ -46,14 +47,10 @@ def load_model():
model = load_model() model = load_model()
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
new_string = "" # regexp is way faster than repeated string concatenation!
in_star = False # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
for char in string: # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
if char == '*': return re.sub('\*[^\*]*?(\*|$)','',string)
in_star = not in_star
elif not in_star:
new_string += char
return new_string
def remove_tts_from_history(name1, name2): def remove_tts_from_history(name1, name2):
for i, entry in enumerate(shared.history['internal']): for i, entry in enumerate(shared.history['internal']):