From 18ae08ef9176b9f9e460d2e8a396127cc5f2422a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 10 Jan 2023 23:41:35 -0300 Subject: [PATCH] Remove T5 support --- convert-to-torch.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/convert-to-torch.py b/convert-to-torch.py index 7dfe42fd..ab07bbcf 100644 --- a/convert-to-torch.py +++ b/convert-to-torch.py @@ -7,7 +7,7 @@ python convert-to-torch.py models/opt-1.3b The output will be written to torch-dumps/name-of-the-model.pt ''' -from transformers import AutoModelForCausalLM, T5ForConditionalGeneration +from transformers import AutoModelForCausalLM import torch from sys import argv from pathlib import Path @@ -16,10 +16,7 @@ path = Path(argv[1]) model_name = path.name print(f"Loading {model_name}...") -if model_name in ['flan-t5', 't5-large']: - model = T5ForConditionalGeneration.from_pretrained(path).cuda() -else: - model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda() +model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda() print("Model loaded.") print(f"Saving to torch-dumps/{model_name}.pt")