Add HTML support for gpt4chan

oobabooga 2023-01-06 23:14:08 -03:00
parent 3d6a3aac73
commit c7b29668a2
2 changed files with 170 additions and 6 deletions

html_generator.py (new file, +154)

@@ -0,0 +1,154 @@
'''
This is a library for formatting gpt4chan outputs as nice HTML.
'''
import re

def process_post(post, c):
    # Split the raw post into its header line ("--- <number>") and body.
    t = post.split('\n')
    number = t[0].split(' ')[1]
    if len(t) > 1:
        src = '\n'.join(t[1:])
    else:
        src = ''
    # Escape '>' so quote markers survive as text, then wrap >>number
    # references in a styled span and turn newlines into <br> tags.
    src = re.sub('>', '&gt;', src)
    src = re.sub('(&gt;&gt;[0-9]*)', '<span class="quote">\\1</span>', src)
    src = re.sub('\n', '<br>\n', src)
    # The blockquote is left unclosed; browsers auto-close it at the wrapping </div>.
    src = f'<blockquote class="message">{src}\n'
    src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
    return src

def generate_html(f):
    css = """
    #container {
        background-color: #eef2ff;
        padding: 17px;
    }
    .reply {
        background-color: rgb(214, 218, 240);
        border-bottom-color: rgb(183, 197, 217);
        border-bottom-style: solid;
        border-bottom-width: 1px;
        border-image-outset: 0;
        border-image-repeat: stretch;
        border-image-slice: 100%;
        border-image-source: none;
        border-image-width: 1;
        border-left-color: rgb(0, 0, 0);
        border-left-style: none;
        border-left-width: 0px;
        border-right-color: rgb(183, 197, 217);
        border-right-style: solid;
        border-right-width: 1px;
        border-top-color: rgb(0, 0, 0);
        border-top-style: none;
        border-top-width: 0px;
        color: rgb(0, 0, 0);
        display: table;
        font-family: arial, helvetica, sans-serif;
        font-size: 13.3333px;
        margin-bottom: 4px;
        margin-left: 0px;
        margin-right: 0px;
        margin-top: 4px;
        overflow-x: hidden;
        overflow-y: hidden;
        padding-bottom: 2px;
        padding-left: 2px;
        padding-right: 2px;
        padding-top: 2px;
    }
    .number {
        color: rgb(0, 0, 0);
        font-family: arial, helvetica, sans-serif;
        font-size: 13.3333px;
        width: 342.65px;
    }
    .op {
        color: rgb(0, 0, 0);
        font-family: arial, helvetica, sans-serif;
        font-size: 13.3333px;
        margin-bottom: 8px;
        margin-left: 0px;
        margin-right: 0px;
        margin-top: 4px;
        overflow-x: hidden;
        overflow-y: hidden;
    }
    .op blockquote {
        margin-left: 7px;
    }
    .name {
        color: rgb(17, 119, 67);
        font-family: arial, helvetica, sans-serif;
        font-size: 13.3333px;
        font-weight: 700;
        margin-left: 7px;
    }
    .quote {
        color: rgb(221, 0, 0);
        font-family: arial, helvetica, sans-serif;
        font-size: 13.3333px;
        text-decoration-color: rgb(221, 0, 0);
        text-decoration-line: underline;
        text-decoration-style: solid;
        text-decoration-thickness: auto;
    }
    .greentext {
        color: rgb(120, 153, 34);
        font-family: arial, helvetica, sans-serif;
        font-size: 13.3333px;
    }
    blockquote {
        margin-block-start: 1em;
        margin-block-end: 1em;
        margin-inline-start: 40px;
        margin-inline-end: 40px;
    }
    """

    # Split the raw text into posts: a "--- <number>" line starts a new post,
    # and "-----" separator lines are skipped.
    posts = []
    post = ''
    c = -2
    for line in f.splitlines():
        line += "\n"
        if line == '-----\n':
            continue
        elif line.startswith('--- '):
            c += 1
            if post != '':
                src = process_post(post, c)
                posts.append(src)
            post = line
        else:
            post += line
    if post != '':
        src = process_post(post, c)
        posts.append(src)

    # The first post is the OP; every later post is a reply.
    for i in range(len(posts)):
        if i == 0:
            posts[i] = f'<div class="op">{posts[i]}</div>\n'
        else:
            posts[i] = f'<div class="reply">{posts[i]}</div>\n'

    output = ''
    output += f'<style>{css}</style><div id="container">'
    for post in posts:
        output += post
    output += '</div>'

    # Wrap lines that still begin with an escaped '>' in a greentext span.
    output = output.split('\n')
    for i in range(len(output)):
        output[i] = re.sub('^(&gt;[^\n]*(<br>|</div>))', '<span class="greentext">\\1</span>\n', output[i])
    output = '\n'.join(output)

    return output
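
For readers who want to try the new module on its own, a minimal usage sketch follows. The input format is inferred from the parser above: each post begins with a "--- <number>" line, "-----" separator lines are ignored, and quote references use ">>number". The sample text and output filename are illustrative, not part of the commit.

from html_generator import generate_html

# Hypothetical gpt4chan-style output; the format is an assumption based on
# how generate_html() and process_post() parse their input.
sample = "\n".join([
    "--- 12345",
    "Has anyone tried the new webui?",
    ">it just works",
    "--- 12346",
    ">>12345",
    "Works on my machine.",
])

with open("sample_thread.html", "w") as f:
    f.write(generate_html(sample))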


@@ -7,15 +7,19 @@ import torch
 import argparse
 import gradio as gr
 import transformers
+from html_generator import *
 from transformers import AutoTokenizer
 from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel

 parser = argparse.ArgumentParser()
 parser.add_argument('--model', type=str, help='Name of the model to load by default.')
 parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.')
 args = parser.parse_args()

 loaded_preset = None
-available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]")+ glob.glob("torch-dumps/*[!\.][!t][!x][!t]"))))
+available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*")+ glob.glob("torch-dumps/*"))))
+available_models = [item for item in available_models if not item.endswith('.txt')]
+#available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]")+ glob.glob("torch-dumps/*[!\.][!t][!x][!t]"))))

 def load_model(model_name):
     print(f"Loading {model_name}...")
@@ -75,15 +79,17 @@ def generate_reply(question, temperature, max_length, inference_settings, select
     input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()
     output = eval(f"model.generate(input_ids, {preset}).cuda()")
     reply = tokenizer.decode(output[0], skip_special_tokens=True)

     if model_name.startswith('gpt4chan'):
         reply = fix_gpt4chan(reply)

     if model_name.lower().startswith('galactica'):
-        return reply, reply
+        return reply, reply, 'Only applicable for gpt4chan.'
+    elif model_name.lower().startswith('gpt4chan'):
+        return reply, 'Only applicable for galactica models.', generate_html(reply)
     else:
-        return reply, 'Only applicable for galactica models.'
+        return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'

 # Choosing the default model
 if args.model is not None:
@@ -121,6 +127,8 @@ if args.notebook:
             textbox = gr.Textbox(value=default_text, lines=23)
         with gr.Tab('Markdown'):
             markdown = gr.Markdown()
+        with gr.Tab('HTML'):
+            html = gr.HTML()

         btn = gr.Button("Generate")
         with gr.Row():
@@ -131,7 +139,7 @@ if args.notebook:
             preset_menu = gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default", label='Preset')
             model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')

-        btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [textbox, markdown], show_progress=False)
+        btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=False)
 else:
     with gr.Blocks() as interface:
         gr.Markdown(
@@ -154,7 +162,9 @@ else:
                 output_textbox = gr.Textbox(value=default_text, lines=15, label='Output')
             with gr.Tab('Markdown'):
                 markdown = gr.Markdown()
+            with gr.Tab('HTML'):
+                html = gr.HTML()

-        btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [output_textbox, markdown], show_progress=True)
+        btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
 interface.launch(share=False, server_name="0.0.0.0")
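
The pattern in the last two hunks: generate_reply() now always returns three values (raw text, markdown text, HTML), and each btn.click() lists three output components so the Markdown and HTML tabs update in the same call, with placeholder strings for whichever views do not apply. A minimal self-contained sketch of that wiring (the demo function and strings are illustrative, not the repository's code):

import gradio as gr

def demo_reply(prompt):
    reply = prompt + ' ...generated text'
    # One return value per output component: Textbox, Markdown, HTML.
    return reply, reply, f'<blockquote>{reply}</blockquote>'

with gr.Blocks() as demo:
    textbox = gr.Textbox(lines=5)
    with gr.Tab('Markdown'):
        markdown = gr.Markdown()
    with gr.Tab('HTML'):
        html = gr.HTML()
    btn = gr.Button('Generate')
    btn.click(demo_reply, [textbox], [textbox, markdown, html], show_progress=False)

demo.launch()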