Load default model with --model flag

This commit is contained in:
oobabooga 2023-01-06 19:56:44 -03:00
parent ec2973f596
commit f54a13929f

View file

@ -2,23 +2,19 @@ import os
import re
import time
import glob
from sys import exit
import torch
import argparse
import gradio as gr
import transformers
from transformers import AutoTokenizer
from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel
#model_name = "bloomz-7b1-p3"
#model_name = 'gpt-j-6B-float16'
#model_name = "opt-6.7b"
#model_name = 'opt-13b'
model_name = "gpt4chan_model_float16"
#model_name = 'galactica-6.7b'
#model_name = 'gpt-neox-20b'
#model_name = 'flan-t5'
#model_name = 'OPT-13B-Erebus'
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, help='Name of the model to load by default')
args = parser.parse_args()
loaded_preset = None
available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]")+ glob.glob("torch-dumps/*[!\.][!t][!x][!t]"))))
def load_model(model_name):
print(f"Loading {model_name}...")
@ -85,7 +81,24 @@ def generate_reply(question, temperature, max_length, inference_settings, select
return reply
# Choosing the default model
if args.model is not None:
model_name = args.model
else:
if len(available_models == 0):
print("No models are available! Please download at least one.")
exit(0)
elif len(available_models) == 1:
i = 0
else:
print("The following models are available:\n")
for i,model in enumerate(available_models):
print(f"{i+1}. {model}")
print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
i = int(input())-1
model_name = available_models[i]
model, tokenizer = load_model(model_name)
if model_name.startswith('gpt4chan'):
default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
else:
@ -98,7 +111,7 @@ interface = gr.Interface(
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"),
gr.Dropdown(choices=sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*") + glob.glob("torch-dumps/*")))), value=model_name),
gr.Dropdown(choices=available_models, value=model_name),
],
outputs=[
gr.Textbox(placeholder="", lines=15),