Make the gallery extension work on colab

This commit is contained in:
oobabooga 2023-02-26 12:37:26 -03:00
parent 756cba2edc
commit 3333f94c30
2 changed files with 28 additions and 12 deletions

View file

@ -4,6 +4,8 @@ from pathlib import Path
import gradio as gr
from modules.html_generator import image_to_base64
def generate_html():
css = """
@ -29,10 +31,14 @@ def generate_html():
background-color: gray;
}
.character-gallery td {
text-align: center;
vertical-align: middle;
.character-gallery .image-td {
width: 150px;
}
.character-gallery .character-td {
text-align: center !important;
}
"""
table_html = f'<style>{css}</style><div class="character-gallery"><table>'
@ -41,15 +47,25 @@ def generate_html():
for file in Path("characters").glob("*"):
if file.name.endswith(".json"):
json_name = file.name
image_name = file.name.replace(".json", "")
character = file.name.replace(".json", "")
table_html += "<tr>"
if Path(f"characters/{image_name}.png").exists():
image_html = f'<img src="file/characters/{image_name}.png">'
elif Path(f"characters/{image_name}.jpg").exists():
image_html = f'<img src="file/characters/{image_name}.jpg">'
else:
image_html = "<div class='placeholder'></div>"
table_html += f"<td>{image_html}</td><td>{image_name}</td>"
image_html = "<div class='placeholder'></div>"
for i in [
f"characters/{character}.png",
f"characters/{character}.jpg",
f"characters/{character}.jpeg",
]:
path = Path(i)
if path.exists():
try:
image_html = f'<img src="data:image/png;base64,{image_to_base64(path)}">'
break
except:
continue
table_html += f'<td class="image-td"=>{image_html}</td><td class="character-td">{character}</td>'
table_html += "</tr>"
table_html += "</table></div>"

View file

@ -200,7 +200,7 @@ def image_to_base64(path):
if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
img = Image.open(path)
img.thumbnail((100, 100))
img.thumbnail((200, 200))
img_buffer = BytesIO()
img.convert('RGB').save(img_buffer, format='PNG')
image_cache[path] = [mtime, base64.b64encode(img_buffer.getvalue()).decode("utf-8")]