diff --git a/server.py b/server.py index 4154fe44..3070455a 100644 --- a/server.py +++ b/server.py @@ -12,8 +12,8 @@ from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2 #model_name = 'gpt-j-6B-float16' #model_name = "opt-6.7b" #model_name = 'opt-13b' -#model_name = "gpt4chan_model_float16" -model_name = 'galactica-6.7b' +model_name = "gpt4chan_model_float16" +#model_name = 'galactica-6.7b' #model_name = 'gpt-neox-20b' #model_name = 'flan-t5' #model_name = 'OPT-13B-Erebus' @@ -24,17 +24,24 @@ def load_model(model_name): print(f"Loading {model_name}...") t0 = time.time() + # Loading the model if os.path.exists(f"torch-dumps/{model_name}.pt"): print("Loading in .pt format...") model = torch.load(f"torch-dumps/{model_name}.pt").cuda() - elif model_name.lower().startswith(('gpt-neo', 'opt-')): - model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True) + elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')): + if any(size in model_name for size in ('13b', '20b', '30b')): + model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True) + else: + model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda() elif model_name in ['gpt-j-6B']: model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda() elif model_name in ['flan-t5', 't5-large']: model = T5ForConditionalGeneration.from_pretrained(f"models/{model_name}").cuda() + else: + model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda() - if model_name in ['gpt4chan_model_float16']: + # Loading the tokenizer + if model_name.startswith('gpt4chan'): tokenizer = AutoTokenizer.from_pretrained("models/gpt-j-6B/") elif model_name in ['flan-t5']: tokenizer = T5Tokenizer.from_pretrained(f"models/{model_name}/") diff --git a/torch-dumps/place-your-pt-models-here.txt b/torch-dumps/place-your-pt-models-here.txt deleted file mode 100644 index e69de29b..00000000